# Source code for brokilon.ccd.domain.topology.ccd0_attempt

from collections import defaultdict
from math import log
from typing import Any

from brokilon.core import Tree
from brokilon.ccd.types import CladeSplitInfo


def expand(observed_clades, observed_clade_splits):
    """
    Expand observed clade splits to all splits into two observed sub-clades (CCD0).

    Every observed split is first copied in canonical orientation: the child
    containing the smallest taxon id comes first. Then, for every observed clade
    with more than two taxa, each pair of observed clades that partitions it is
    added as an additional split.

    :param observed_clades: iterable of observed clades (frozensets of taxon ids);
        may be a one-shot iterator, it is materialised once
    :param observed_clade_splits: mapping clade -> iterable of (child, child) splits
    :return: defaultdict(set) mapping clade -> set of canonically ordered splits
    """
    # TODO: observed_clades could perhaps be replaced by the keys of observed_clade_splits?

    def canonical(first, second):
        # orientation used everywhere in this module: smallest taxon id on the left
        return (first, second) if min(first) < min(second) else (second, first)

    # iterated twice below; a list keeps the caller's iteration order
    observed_clades = list(observed_clades)

    expanded_clade_partitions = defaultdict(set)
    for clade, splits in observed_clade_splits.items():
        for left, right in splits:
            expanded_clade_partitions[clade].add(canonical(left, right))

    clade_buckets = defaultdict(set)  # clade size -> observed clades of that size
    for observed in observed_clades:
        clade_buckets[len(observed)].add(observed)

    for clade in observed_clades:
        n = len(clade)
        if n <= 2:
            continue  # leaves have no split, a cherry only its single observed one
        # only left sizes 1..n//2 are needed; canonical() fixes the orientation
        for left_size in range(1, n // 2 + 1):
            right_bucket = clade_buckets.get(n - left_size)
            if not right_bucket:
                continue  # nothing observed of the complementary size to combine with
            for left in clade_buckets.get(left_size, ()):
                if left.issubset(clade):
                    right = clade - left
                    if right in right_bucket:
                        # adding an already-present split is a no-op on the set
                        expanded_clade_partitions[clade].add(canonical(left, right))
    return expanded_clade_partitions


def get_maps_full(trees: list[Tree]) \
        -> tuple[defaultdict[Any, int], defaultdict[Any, int]]:
    """
    Count how often each clade and each clade split occurs in a list of trees.

    Leaf names are converted with ``int()``, so leaves must be named with integer
    taxon ids. Iterating a node is expected to yield its leaves. Only the first two
    children of a node are used, i.e. the trees are assumed to be binary.

    :param trees: list of input trees
    :return: ``(m1, m2)`` where ``m1`` maps each clade (frozenset of taxon ids) to
        how often it was sampled, and ``m2`` maps
        ``(parent_clade, child0_clade, child1_clade)`` to how often that relation
        was sampled. The children are keyed in the order they appear in the tree,
        so one split may be counted under both orientations.
    """
    m1 = defaultdict(int)  # clade -> number of times it was sampled
    m2 = defaultdict(int)  # (parent, child0, child1) -> number of times sampled
    for tree in trees:
        for node in tree.traverse("levelorder"):
            parent_clade = frozenset(int(leaf.name) for leaf in node)
            m1[parent_clade] += 1
            if node.children:
                # building from sorted ids fixes the frozensets' insertion order, which
                # determines their iteration order further downstream
                child0_clade = frozenset(sorted(int(leaf.name) for leaf in node.children[0]))
                child1_clade = frozenset(sorted(int(leaf.name) for leaf in node.children[1]))
                m2[(parent_clade, child0_clade, child1_clade)] += 1
    return m1, m2


def get_ccd0(trees):
    """
    Compute the CCD0 conditional clade probabilities of a sample of trees.

    Clade and split counts come from :func:`get_maps_full`. The observed splits are
    then expanded with :func:`expand`, so every split of an observed clade into two
    observed sub-clades is considered. Each clade ``C`` receives the weight

        w(C) = 1                                    if |C| == 1
        w(C) = freq(C)                              if |C| == 2
        w(C) = freq(C) * sum_{(A, B)} w(A) * w(B)   otherwise

    where ``freq(C)`` is the clade count divided by ``len(trees)`` and the sum runs
    over the (expanded) splits of ``C``. The conditional clade probability of a
    split ``(A, B)`` of ``C`` is ``w(A) * w(B)`` divided by that sum.

    :param trees: list of input trees (leaves named with integer taxon ids)
    :return: mapping from every observed clade with more than two taxa to a dict
        ``{(left, right): ccp}``; within each split the child containing the
        smallest taxon id comes first
    :raises ValueError: if all splits of a clade have weight zero
    """
    # TODO: add a ccd_type option to compute CCD0 or CCD1 here, then merge into a
    #  single ccd.py file and replace the old code with this cleaned-up version.
    # TODO: make the expansion step optional; it might be too much for BREATH trees.
    n_trees = len(trees)
    clade_counts, clade_split_counts = get_maps_full(trees)

    # group the observed (parent, child1, child2) keys by parent clade
    observed_splits = defaultdict(list)
    for parent, child1, child2 in clade_split_counts:
        observed_splits[parent].append((child1, child2))
    clade_partitions = expand(set(clade_counts), observed_splits)

    clade_weights = {}  # memo: clade -> w(clade)
    partitions_ccp = {}  # result: clade -> {split: ccp}

    def clade_weight(clade):
        """Return w(clade), filling ``partitions_ccp`` for clades with > 2 taxa."""
        if clade in clade_weights:
            return clade_weights[clade]
        if len(clade) == 1:
            clade_weights[clade] = 1.0
            return 1.0
        frequency = clade_counts.get(clade, 0) / n_trees
        if len(clade) == 2:
            assert len(clade_partitions[clade]) == 1, ("Cherry split in more than one way? "
                                                        "impossible!")
            clade_weights[clade] = frequency
            return frequency
        # register before recursing so partitions_ccp lists parents before children
        split_ccps = partitions_ccp[clade] = {}
        splits = list(clade_partitions[clade])
        # recursion depth grows with clade size (at most the number of taxa)
        products = [clade_weight(left) * clade_weight(right) for left, right in splits]
        total = sum(products)
        if not total > 0:
            # not expected: every split child is an observed clade with positive weight
            raise ValueError("Need to implement this fall back...")
        for split, product in zip(splits, products):
            split_ccps[split] = product / total
        weight = frequency * total
        clade_weights[clade] = weight
        return weight

    for clade in clade_counts:
        clade_weight(clade)
    return partitions_ccp


def get_tree_probability(tree, partitions_ccp, use_log=False):
    """
    Return the probability of ``tree`` under a conditional clade distribution.

    The probability is the product, over every node with more than two leaves, of
    the conditional clade probability of that node's split in ``partitions_ccp``.
    Cherries and leaves are not looked up, because they can only be resolved one way.

    :param tree: tree whose leaves are named with integer taxon ids
    :param partitions_ccp: mapping clade -> {(left, right): ccp} with the child
        containing the smallest taxon id first (the layout returned by get_ccd0)
    :param use_log: return the natural-log probability instead
    :return: the probability as a float. It is ``0.0`` (``-inf`` with ``use_log``)
        as soon as a split is missing from ``partitions_ccp``, or a node with more
        than two leaves does not have exactly two children (such a tree is not one
        of the binary trees the distribution is defined on).
    """
    prob = 0.0 if use_log else 1.0
    for node in tree.traverse("levelorder"):
        if len(node) <= 2:
            continue  # leaves and cherries contribute a factor of 1
        if len(node.children) != 2:
            # a multifurcation: looking up only the first two children would
            # query the wrong parent clade
            return float("-inf") if use_log else 0.0
        left_clade = frozenset(int(leaf.name) for leaf in node.children[0])
        right_clade = frozenset(int(leaf.name) for leaf in node.children[1])
        split = (left_clade, right_clade) if min(left_clade) < min(right_clade) \
            else (right_clade, left_clade)
        ccp = partitions_ccp.get(left_clade | right_clade, {}).get(split, 0.0)
        if ccp == 0.0:
            return float("-inf") if use_log else 0.0
        if use_log:
            prob += log(ccp)
        else:
            prob *= ccp
    return float(prob)


def get_map_tree(partitions_ccp):
    """
    Build the maximum a posteriori (MAP) tree of a conditional clade distribution.

    Clades are visited from smallest to largest. Each clade keeps the split that
    maximises ``ccp(split) * best(first child) * best(second child)``, where
    ``best`` of a clade with fewer than three taxa is 1. Starting from the largest
    key (the root clade, which is a superset of all others), the chosen splits are
    then collected and handed to ``get_tree_from_dict_of_splits``.

    Ties: for three-taxon clades the first maximal split in dict order wins. For
    larger clades, a later split with an equal score replaces an earlier one.

    :param partitions_ccp: mapping clade (more than two taxa) -> {split: ccp},
        e.g. the result of get_ccd0
    :return: the MAP tree built from the chosen splits
    """
    # TODO: rename variables; point out to the user when a maximum is not unique.
    best_for_clade = {}
    for clade in sorted(partitions_ccp, key=len):
        assert len(clade) > 2, "Cherries are not relevant for this?"
        split_ccps = partitions_ccp[clade]

        if len(clade) <= 3:
            # a triplet splits into a leaf and a cherry, so just take the max ccp
            split, ccp = max(split_ccps.items(), key=lambda item: item[1])
            assert clade not in best_for_clade, "Triplets should only get into here once"
            best_for_clade[clade] = CladeSplitInfo(split, ccp)
            continue

        for split, ccp in split_ccps.items():
            first, second = split
            if len(first) >= 3:
                assert first in best_for_clade, "This should be impossible"
                first_prob = best_for_clade[first].prob
            else:
                first_prob = 1.0
            if len(second) >= 3:
                assert second in best_for_clade, "This should be impossible"
                second_prob = best_for_clade[second].prob
            else:
                second_prob = 1.0
            score = first_prob * second_prob * ccp
            current = best_for_clade.get(clade)
            if current is None or current.prob <= score:
                best_for_clade[clade] = CladeSplitInfo(split, score)

    # walk down from the root, collecting the chosen split of every non-cherry clade
    chosen_splits = {}
    pending = [max(best_for_clade)]
    while pending:
        clade = pending.pop()
        split, _ = best_for_clade[clade]
        chosen_splits[clade] = split
        pending.extend(child for child in split if len(child) > 2)
    return get_tree_from_dict_of_splits(chosen_splits)


def get_tree_from_dict_of_splits(splits):
    """
    Build a ``Tree`` from a mapping clade -> (first child clade, second child clade).

    The root split is the one stored under ``max(splits)``. With frozenset
    comparison (proper superset) this is the root clade, assuming one key is a
    superset of all others. Children are attached depth-first, the first child
    before the second:

    * a single taxon becomes a leaf named ``str(taxon)``;
    * a two-taxon clade becomes a node ``internal_<k>`` with its two taxa as leaves;
    * a larger clade becomes a node ``child_internal_<k>`` expanded via ``splits``.

    ``k`` is one counter starting at 1 shared by both kinds of internal node. Every
    added node gets ``dist=1``. The root is named ``"root"`` with support and dist 0.

    :param splits: mapping from each clade with more than two taxa to its split
    :return: the constructed tree
    """
    root = Tree(support=0, dist=0, name="root")
    next_id = 1

    def attach_clade(parent, clade):
        nonlocal next_id
        size = len(clade)
        if size == 1:
            (taxon,) = clade
            parent.add_child(name=str(taxon), dist=1)
        elif size == 2:
            first_taxon, second_taxon = clade
            cherry = parent.add_child(name=f"internal_{next_id}", dist=1)
            next_id += 1
            cherry.add_child(name=str(first_taxon), dist=1)
            cherry.add_child(name=str(second_taxon), dist=1)
        else:
            subtree = parent.add_child(name=f"child_internal_{next_id}", dist=1)
            next_id += 1
            attach_split(subtree, splits[clade])

    def attach_split(parent, split):
        first_clade, second_clade = split
        attach_clade(parent, first_clade)
        attach_clade(parent, second_clade)

    attach_split(root, splits[max(splits)])
    return root


if __name__ == '__main__':
    # Demo: compute the CCD0 of the bundled 30-taxon example and print its MAP tree.
    # TODO: refactor and make this work for ccd0 and ccd1 to replace the old code.
    # TODO: compute the entropy from this and compare it to the Java implementation.
    # TODO: maybe make the expansion step faster; profiler needed.
    from pathlib import Path

    from brokilon.core.read_nexus import read_nexus_trees

    burn_in_fraction = 0.1  # leading fraction of trees to drop (presumably MCMC burn-in)
    data_dir = Path(__file__).parent.absolute().parent.parent / "examples" / "data"
    tree_file = str(data_dir / "30Taxa.trees")
    trees, _taxon_map = read_nexus_trees(tree_file, parse_taxon_map=True)
    trees = trees[int(len(trees) * burn_in_fraction):]
    partitions_ccp = get_ccd0(trees)
    map_tree = get_map_tree(partitions_ccp)
    print(map_tree.write(format=5))