from collections import defaultdict
from brokilon.ccd.types import CladeSplitInfo
from brokilon.ccd.domain.transmission.transmission_ccd import get_transmission_maps
from brokilon.core import Tree
[docs]
def _ancestor_label(node):
    """Return ``node.transm_ancest`` as an int, or None when it starts with "Unknown"."""
    ancestor = node.transm_ancest
    return None if ancestor.startswith("Unknown") else int(ancestor)


def compatible(split, parent) -> bool:
    """Return whether ``split`` is a transmission-consistent way to divide ``parent``.

    ``split`` is a ``(left, right)`` pair of clades. Every clade exposes
    ``transm_ancest`` (an integer label as a string, or a value starting with
    "Unknown") and ``clade`` (its taxon labels, compared against the integer
    ancestor). The rules, in order:

    * reject if each child's ancestor lies inside its sibling;
    * accept if all three ancestors are unknown;
    * if the parent's ancestor lies in one child, that child must carry the same
      ancestor and the other child must carry it too or be unknown;
    * accept if both children are unknown;
    * if a child's ancestor lies inside that child, the sibling must share it or
      be unknown;
    * reject everything else.
    """
    left, right = split
    # None (not False) marks an unknown ancestor: False == 0 and hashes like 0,
    # so a False sentinel would be "found" in any clade containing taxon 0 and a
    # genuine ancestor 0 would be mistaken for "unknown".
    p_anc = _ancestor_label(parent)
    l_anc = _ancestor_label(left)
    r_anc = _ancestor_label(right)

    # Each child claims an ancestor living in the other child: contradictory.
    if l_anc in right.clade and r_anc in left.clade:
        return False
    if p_anc is None and l_anc is None and r_anc is None:
        return True
    if p_anc in left.clade:
        return l_anc == p_anc and r_anc in (p_anc, None)
    if p_anc in right.clade:
        return r_anc == p_anc and l_anc in (p_anc, None)
    if l_anc is None and r_anc is None:
        return True
    # The parent's ancestor lies in neither child; the children must still agree.
    if l_anc in left.clade:
        return r_anc in (l_anc, None)
    if r_anc in right.clade:
        return l_anc in (r_anc, None)
    return False
[docs]
def expand_tccd0(clade_partitions):
    """Expand observed clade splits CCD0-style with compatible unobserved splits.

    Every split is stored as ``(left, right)``, ordered so that ``left`` holds the
    smaller minimum taxon label. Besides its observed splits, a parent clade
    receives every split ``(A, B)`` where ``A`` and ``B`` are observed clades whose
    taxa partition the parent's taxa and ``compatible((A, B), parent)`` is true.

    Args:
        clade_partitions: mapping parent clade -> iterable of observed
            ``(child1, child2)`` splits. Clades must expose ``clade`` (their taxon
            set) and ``len()`` (their number of taxa).

    Returns:
        ``defaultdict(set)`` mapping each parent clade to its set of splits. The
        input mapping is not modified. Progress is printed every 100 clades and
        for the last one.
    """
    def by_min_taxon(c):
        return min(c.clade)

    expanded_clade_partitions = defaultdict(set)
    # Candidate building blocks: every parent clade plus every child clade. The
    # children matter because leaves (single taxa) can never be split, so they
    # never occur as parents; without them no split with a leaf child (and hence
    # no cherry) could ever be generated by the expansion.
    candidates = set(clade_partitions)
    for clade, splits in clade_partitions.items():
        for left, right in splits:
            candidates.update((left, right))
            if min(left.clade) < min(right.clade):
                expanded_clade_partitions[clade].add((left, right))
            else:
                expanded_clade_partitions[clade].add((right, left))

    # number of taxa -> clades
    clade_buckets = defaultdict(set)
    # taxon set -> clades; several clades can share a taxon set (e.g. differing
    # only in transm_ancest), so this holds lists
    clades_by_taxa = defaultdict(list)
    for candidate in candidates:
        clade_buckets[len(candidate)].add(candidate)
        clades_by_taxa[frozenset(candidate.clade)].append(candidate)

    total = len(clade_partitions)
    for i, clade in enumerate(clade_partitions.keys(), 1):
        n = len(clade)
        # progress prints
        if i % 100 == 0 or i == total:
            print(f"[{i}/{total}] Working on clade of size {n}")
        if n < 2:
            continue
        existing_splits = set(expanded_clade_partitions.get(clade, set()))
        for left_size in range(1, n // 2 + 1):
            # early stop: nothing of the complementary size to combine with
            if not clade_buckets.get(n - left_size):
                continue
            for left in clade_buckets.get(left_size, ()):
                if not left.clade.issubset(clade.clade):
                    continue
                right_taxa = frozenset(clade.clade - left.clade)
                for right in clades_by_taxa.get(right_taxa, ()):
                    split = tuple(sorted((left, right), key=by_min_taxon))
                    if split not in existing_splits and compatible(split, clade):
                        expanded_clade_partitions[clade].add(split)
    return expanded_clade_partitions
[docs]
def get_transmission_ccd0(trees, expansion: bool = True):
    """Compute the transmission-aware CCD0 (conditional clade probabilities).

    Clade and clade-split counts come from
    ``get_transmission_maps(trees, type_str="Ancestry")``. When ``expansion`` is
    true the observed splits are first extended with compatible unobserved ones
    by ``expand_tccd0``.

    Each non-leaf clade C gets the value
    ``f(C) * sum(v(A) * v(B) for (A, B) in splits(C))``, where ``f(C)`` is the
    observed count of C divided by ``len(trees)`` and ``v(leaf) = 1``. The
    conditional probability of a split is its share of that sum.

    Args:
        trees: the sampled trees; beyond the ``get_transmission_maps`` call only
            ``len(trees)`` is used here.
        expansion: whether to add compatible unobserved splits (default True).

    Returns:
        dict mapping each non-leaf clade reached from an observed clade to
        ``{(left, right): conditional clade probability}``. Each inner dict sums
        to 1 up to rounding.

    Raises:
        ValueError: if all splits of some clade have weight zero (for example
            because the products underflowed to 0.0; there is no log-space
            fallback yet).
    """
    m1_observed_clade_counts, m2_observed_clade_split_counts, _, _ = (
        get_transmission_maps(trees, type_str="Ancestry"))
    n_trees = len(trees)

    # parent clade -> observed (child1, child2) splits; the split counts
    # themselves are not needed for CCD0
    clade_partitions = defaultdict(list)
    for parent, child1, child2 in m2_observed_clade_split_counts:
        clade_partitions[parent].append((child1, child2))
    if expansion:
        print("Starting expansion")
        clade_partitions = expand_tccd0(clade_partitions)
        print("Expansion has finished...")

    transmission_ccd0_map = {}
    ccd0_probabilities = {}  # memo: clade -> CCD0 value (1.0 for leaves)

    def recursive_prob_computer(clade):
        """Return the CCD0 value of ``clade``, filling ``transmission_ccd0_map``."""
        if clade in ccd0_probabilities:
            return ccd0_probabilities[clade]
        if len(clade) == 1:
            ccd0_probabilities[clade] = 1.0
            return 1.0
        # Cherries take the general path: their leaves can carry different
        # transmission ancestors, so a cherry may have several splits.
        # The clade is registered before recursing so the returned map lists
        # parents ahead of their descendants.
        split_ccps = transmission_ccd0_map[clade] = {}
        # Snapshot: the splits are iterated twice and may be stored in a set.
        splits = list(clade_partitions[clade])
        part_products = [
            recursive_prob_computer(left) * recursive_prob_computer(right)
            for left, right in splits
        ]
        total = sum(part_products)
        if total <= 0:
            raise ValueError("Need to implement a log fall back...")
        for split, product in zip(splits, part_products):
            split_ccps[split] = product / total
        clade_prob = m1_observed_clade_counts.get(clade, 0) / n_trees * total
        ccd0_probabilities[clade] = clade_prob
        return clade_prob

    for clade in m1_observed_clade_counts:
        recursive_prob_computer(clade)
    return transmission_ccd0_map
[docs]
def get_map_tree(transmission_ccd0_map):
    """Select the maximum a posteriori (MAP) tree from a transmission CCD0 map.

    Clades are resolved from smallest to largest. For each clade the split that
    maximises ``P(child1) * P(child2) * CCP(split)`` is kept, where a leaf has
    probability 1 and a larger child uses its own best value. Among the clades
    with the most taxa (the candidate roots, which may differ only in their
    transmission ancestor) the most probable one is chosen, and its best splits
    are followed down to the leaves.

    Args:
        transmission_ccd0_map: mapping clade -> {(child1, child2): CCP}, as
            returned by ``get_transmission_ccd0``.

    Returns:
        dict mapping each internal clade of the MAP tree to its chosen split and
        each leaf clade to itself.

    Raises:
        ValueError: if no clade could be resolved (e.g. an empty map).
        NotImplementedError: if several root clades tie for the best probability.
    """
    seen_resolved_clades = {}
    for cur_clade in sorted(transmission_ccd0_map.keys(), key=len):
        split_ccps = transmission_ccd0_map[cur_clade]
        if len(cur_clade) == 2:
            # Cherries are the smallest non-leaf clades. Both children are leaves
            # (probability 1), so the best split is just the one with the highest
            # CCP; max() keeps the first of several tied splits.
            # TODO: record ties so MAP uncertainty can be reported.
            ta_clade_split, prob = max(split_ccps.items(), key=lambda item: item[1])
            seen_resolved_clades[cur_clade] = CladeSplitInfo(ta_clade_split, prob)
            continue
        for cur_clade_split, cur_clade_split_prob in split_ccps.items():
            child1, child2 = cur_clade_split
            # children are smaller than cur_clade, so they were resolved already
            child1_prob = 1.0 if len(child1) == 1 else seen_resolved_clades[child1].prob
            child2_prob = 1.0 if len(child2) == 1 else seen_resolved_clades[child2].prob
            split_map_probability = child1_prob * child2_prob * cur_clade_split_prob
            best = seen_resolved_clades.get(cur_clade)
            # `<=`: on ties the split iterated last wins (TODO: report ties)
            if best is None or best.prob <= split_map_probability:
                seen_resolved_clades[cur_clade] = CladeSplitInfo(
                    cur_clade_split, split_map_probability
                )

    if not seen_resolved_clades:
        raise ValueError("transmission_ccd0_map contains no resolvable clade")

    # Candidate roots: every clade with the largest number of taxa.
    root_size = max(len(clade) for clade in seen_resolved_clades)
    all_root_clades = [clade for clade in seen_resolved_clades if len(clade) == root_size]
    max_prob = max(seen_resolved_clades[root].prob for root in all_root_clades)
    best_roots = [root for root in all_root_clades if seen_resolved_clades[root].prob == max_prob]
    if len(best_roots) > 1:
        raise NotImplementedError("There is more than one best root, this needs to be implemented!")

    # Walk down from the chosen root along the best splits.
    output = {}
    working_list = [best_roots[0]]
    while working_list:
        cur_parent = working_list.pop()
        split, _ = seen_resolved_clades[cur_parent]
        output[cur_parent] = split
        for split_child in split:
            if len(split_child) > 1:
                working_list.append(split_child)
            else:
                output[split_child] = split_child
    return output
[docs]
def get_tree_from_dict_of_splits(splits):
    """Build a ``Tree`` from a clade -> split mapping (as returned by ``get_map_tree``).

    The root is the clade with the most taxa. It becomes a node named "root"
    with ``support=0`` and ``dist=0``. Internal clades become nodes named
    ``child_internal_<k>``, numbered from 1 in depth-first order (first child
    before second child). Leaf clades become nodes named after their single taxon
    label. Every non-root node gets ``dist=1``, and every node carries a
    ``transm_ancest`` feature copied from its clade.
    """
    # Pick the root by size rather than by the clade type's own ordering,
    # which this function cannot rely on.
    root_clade = max(splits, key=len)
    output_tree = Tree(support=0, dist=0, name="root")
    output_tree.add_feature("transm_ancest", root_clade.transm_ancest)
    icount = 1

    def attach(parent_node, clade):
        """Add ``clade`` (and, recursively, its best split) below ``parent_node``."""
        nonlocal icount
        if len(clade) == 1:
            label = next(iter(clade.clade))
            leaf = parent_node.add_child(name=str(label), dist=1)
            leaf.add_feature("transm_ancest", clade.transm_ancest)
            return
        internal_node = parent_node.add_child(name=f"child_internal_{icount}", dist=1)
        internal_node.add_feature("transm_ancest", clade.transm_ancest)
        icount += 1
        first, second = splits[clade]
        attach(internal_node, first)
        attach(internal_node, second)

    first, second = splits[root_clade]
    attach(output_tree, first)
    attach(output_tree, second)
    return output_tree
if __name__ == '__main__':
    # Demo: build the transmission CCD0 of a bundled example posterior and print
    # its MAP tree (format 5 Newick with transm_ancest annotations, root included).
    from pathlib import Path
    from brokilon.core import read_nexus_trees

    # Other inputs in the same directory: breath32sim.trees, roetzer40.trees.
    # To drop burn-in, slice the result, e.g. trees[int(len(trees) * 0.95):].
    repo_root = Path(__file__).parent.absolute().parent.parent
    trees = read_nexus_trees(f"{repo_root}/examples/data/breath32simShort.trees")
    print(f"Parsed {len(trees)} tree after burnin...")

    tccd0_map = get_transmission_ccd0(trees)
    best_splits = get_map_tree(tccd0_map)
    map_newick_tree = get_tree_from_dict_of_splits(best_splits)
    newick = map_newick_tree.write(format=5, features=["transm_ancest"], format_root_node=True)
    print(newick)