Source code for brokilon.ccd.domain.sranges.sranges

from collections import defaultdict
from collections import namedtuple

from brokilon.ccd.clades.sranges import SRangesClade

import inspect
from dataclasses import dataclass, field


def get_callsite():
    """Describe the call site two frames above this function.

    The frame inspected is the caller of whoever called ``get_callsite``.
    When used as a dataclass ``default_factory`` that is the code which
    instantiated the dataclass (the generated ``__init__`` sits in between).

    Returns:
        A string of the form ``"<filename>:<lineno> (<function>)"``, or
        ``"<unknown>"`` when the interpreter offers no frame introspection
        or the stack is shallower than two callers.
    """
    frame = inspect.currentframe()
    try:
        # Step back exactly two frames instead of materialising the whole
        # stack with inspect.stack(), which resolves every frame on it.
        for _ in range(2):
            if frame is None:
                break
            frame = frame.f_back
        if frame is None:
            return "<unknown>"
        # context=0 skips loading source lines; filename, lineno and
        # function are the same values inspect.stack() would report.
        info = inspect.getframeinfo(frame, context=0)
        return f"{info.filename}:{info.lineno} ({info.function})"
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame
@dataclass(frozen=True)
class AncestralSplit:
    """An immutable (ancestor, descendant) pair of clades describing one split.

    Instances are hashable and are used as dictionary keys when counting how
    often a split is observed below a parent clade.
    """

    # Clade on the ancestral side of the split.
    ancestor: SRangesClade
    # Clade on the descendant side of the split.
    descendant: SRangesClade
    # Where the instance was created ("file:line (function)"). Diagnostic
    # only: compare=False keeps it out of __eq__ and __hash__, so the same
    # ancestor/descendant pair built at two different code locations is still
    # counted as one split.
    source: str = field(default_factory=get_callsite, compare=False)
def prelabel_tree(tree, taxon_map):
    """Label the nodes of ``tree`` in place with range / sampled-ancestor info.

    Every leaf is resolved to a taxon name via ``taxon_map[int(leaf.name)]``.
    Zero-length leaves whose name ends in ``_first`` either open a range (when
    the matching ``_last`` taxon is found among the other leaves below the
    leaf's parent) or mark a sampled ancestor. Leaves ending in ``_last``
    close a range.

    Features written on leaves and/or their parents:

    * ``rangetype``: one of ``"range_start"``, ``"range_end"``,
      ``"leaf_range"`` or ``"sampled_ancestor"``.
    * ``range``: on the parent of a range start, the taxon name without its
      ``_first`` suffix.
    * ``sampledancestor``: on the parent of a sampled ancestor, the taxon name
      without its ``_first`` suffix.

    Args:
        tree: Tree whose iteration yields leaves exposing ``name``, ``dist``,
            ``up`` and ``add_feature`` (presumably an ete3-style tree --
            TODO confirm).
        taxon_map: Mapping from integer leaf id to taxon name.

    Raises:
        ValueError: If a zero-length leaf ends in neither ``_first`` nor
            ``_last``, or if the numbers of opened and closed ranges differ.
        AssertionError: If a parent already carries an unexpected label.

    Side effects:
        Mutates ``tree`` and prints a one-line summary to stdout.
    """
    opened_ranges = set({})
    closed_ranges = set({})
    sampled_ancestors = set({})

    # Two passes are needed because a range can be opened and closed at the
    # same distance (in number of nodes/edges) from the root.
    # This could be avoided with a level-order algorithm that also takes
    # branch lengths into account, or by first sorting the leaves by distance
    # to the root and iterating over that list instead.
    for leaf in tree:
        current_taxon = taxon_map[int(leaf.name)]
        if leaf.dist == 0.0:
            # Zero-length leaf: not a regular tip.
            if current_taxon.endswith("_first"):
                other_leaves = [taxon_map[int(t.name)] for t in leaf.up if t is not leaf]
                range_end = current_taxon.replace("_first", "_last")
                if range_end in other_leaves:
                    # We are at the start of a range.
                    opened_ranges.add(current_taxon)
                    if not hasattr(leaf.up, "rangetype"):
                        leaf.up.add_feature("rangetype", "range_start")
                    else:
                        # The parent was already marked as the end of this
                        # range, so start and end share the same parent.
                        assert leaf.up.rangetype == "range_end", "Failure"
                        leaf.up.rangetype = "leaf_range"
                    # Strip the 6-character "_first" suffix to get the range name.
                    leaf.up.add_feature("range", current_taxon[:-6])
                    leaf.add_feature("rangetype", "range_start")
                else:
                    # We are looking at a sampled ancestor (SA).
                    leaf.add_feature("rangetype", "sampled_ancestor")
                    leaf.up.add_feature("rangetype", "sampled_ancestor")
                    assert not hasattr(leaf.up, "range"), "This should be SA!"
                    # Store the SA on the parent for further use.
                    leaf.up.add_feature("sampledancestor", taxon_map[int(leaf.name)][:-6])
                    sampled_ancestors.add(current_taxon)
            elif current_taxon.endswith("_last"):
                # This should always be the end of a range.
                if not hasattr(leaf.up, "rangetype"):
                    leaf.up.add_feature("rangetype", "range_end")
                else:
                    assert leaf.up.rangetype == "range_start", "failue"
                    leaf.up.rangetype = "leaf_range"
                leaf.add_feature("rangetype", "range_end")
                closed_ranges.add(current_taxon)
            else:
                raise ValueError("This should not happen")

    # Second pass: catch all ranges that are closed at the same node where
    # they were opened, and range ends that are not zero-length leaves.
    for leaf in tree:
        current_taxon = taxon_map[int(leaf.name)]
        if current_taxon.endswith("_last"):
            # Closing ranges with leaves that do not have dist 0.
            # current_taxon[:-5] strips the 5-character "_last" suffix.
            if f"{current_taxon[:-5]}_first" in opened_ranges:
                if not hasattr(leaf.up, "rangetype"):
                    if leaf.dist == 0.0:
                        leaf.up.add_feature("rangetype", "range_end")
                elif leaf.up.rangetype == "range_start":
                    # The only case that needs changing to leaf_range:
                    # the parent is already a range_start and would otherwise
                    # be overwritten with range_end.
                    leaf.up.rangetype = "leaf_range"
                leaf.add_feature("rangetype", "range_end")
                closed_ranges.add(current_taxon)
        else:
            # Assuming that this leaf is not the end of a range.
            continue

    if len(opened_ranges) != len(closed_ranges):
        raise ValueError("This is not possible and a bug is found.")
    print(f"Tree has {len(sampled_ancestors)} SAs and {len(opened_ranges)} ranges.")
[docs] def get_sranges_map(trees, taxon_map, ccd_type=1): clade_count_map = defaultdict(int) clade_split_count_map = defaultdict(lambda: defaultdict(int)) reverse_taxon_map = {value: key for key, value in taxon_map.items()} # for node in (node for t in trees for node in t.traverse("levelorder")): for t in trees: prelabel_tree(t, taxon_map) for node in t.traverse("levelorder"): # We can ignore all leafs that are 0.0, i.e. internal leafs that encode a non leaf... if node.is_leaf(): if node.dist == 0.0: continue if hasattr(node, "rangetype"): if node.rangetype == "range_end": if hasattr(node.up, "rangetype"): if node.up.rangetype == "leaf_range": continue assert node.dist != 0.0, "If this happens we need to not allow it?!" assert node.up.range == node.range, "The ranges don't match?" assert (taxon_map[int(node.name)].replace("_last", "_first") == f"{node.range}_first"), \ "More problems" clade_count_map[SRangesClade( frozenset({f"{node.range}_first"}), f"{node.range}_first" )] += 1 else: # ignore range_start and sampled ancestors continue else: # regular leaf assert taxon_map[int(node.name)].endswith("_first"), "Failed leaf case..." cur_range = None if hasattr(node.up, "range"): cur_range = f"{node.up.range}_first" elif hasattr(node.up, "rangetype"): match node.up.rangetype: case "range_end": cur_range = f"{node.up.range}_first" case "sampled_ancestor": cur_range = f"{node.up.sampledancestor}_first" case _: raise AssertionError( "If this happens there might be a missing case..." 
) new_clade = SRangesClade( frozenset({taxon_map[int(node.name)]}), cur_range ) clade_count_map[new_clade] += 1 elif hasattr(node, "rangetype"): match node.rangetype: case "sampled_ancestor": current_range = None if hasattr(node.up, "range"): current_range = f"{node.up.range}_first" elif hasattr(node.up, "sampledancestor"): current_range = f"{node.up.sampledancestor}_first" parent_clade_set = {taxon_map[int(l.name)].replace("_last", "_first") for l in node} child0_clade_set = {taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[0]} child1_clade_set = {taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[1]} assert parent_clade_set == child0_clade_set.union(child1_clade_set) current_parent_clade = SRangesClade( frozenset(parent_clade_set), current_range ) if node.children[0].dist == 0.0: current_split = AncestralSplit( ancestor=SRangesClade( frozenset({}), taxon_map[int(node.children[0].name)] ), descendant=SRangesClade( frozenset(child1_clade_set), taxon_map[int(node.children[0].name)] ) ) elif node.children[1].dist == 0.0: current_split = AncestralSplit( ancestor=SRangesClade( frozenset({}), taxon_map[int(node.children[1].name)] ), descendant=SRangesClade( frozenset(child0_clade_set), taxon_map[int(node.children[1].name)] ) ) clade_count_map[current_parent_clade] += 1 clade_split_count_map[current_parent_clade][current_split] += 1 case "range_start": current_range = None # if hasattr(node.up, "range"): # current_range = f"{node.up.range}_first" # elif hasattr(node.up, "sampledancestor"): # current_range = f"{node.up.sampledancestor}_first" if node.up: if hasattr(node.up, "rangetype"): if node.up.rangetype == "range_end": current_range = f"{node.up.range}_first" elif node.up.rangetype == "sampled_ancestor": assert hasattr(node.up, "sampledancestor"), "This needs fixing" current_range = f"{node.up.sampledancestor}_first" elif node.up.rangetype == "range_start": assert node.up.range == node.range, "something wrong with 
node.up.up case" if hasattr(node.up.up, "range"): current_range = f"{node.up.up.range}_first" elif hasattr(node.up, "range"): current_range = f"{node.up.range}_first" parent_clade = SRangesClade( frozenset( {taxon_map[int(l.name)].replace("_last", "_first") for l in node} ), current_range ) clade_count_map[parent_clade] += 1 case "range_end": # Keeping just the taxon name as range for now... child0_clade_set = {taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[0]} child1_clade_set = {taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[1]} parent_clade_set = { taxon_map[int(l.name)].replace("_last", "_first") for l in node } assert parent_clade_set == child0_clade_set.union(child1_clade_set), \ "Failure in range_end case..." # The following is a range end, hence one of the two clade sets # is the one taxa that is coding the end of the range ..._last # We need to figure out which one, set it as descendant, # and then we remove it from the taxon set to encode the end of range # if node.children[0].orientation == "ancestor": if child1_clade_set == {f"{node.range}_first"}: current_split = AncestralSplit( ancestor=SRangesClade( frozenset(child0_clade_set), f"{node.range}_first" ), descendant=SRangesClade( # frozenset(child1_clade_set), frozenset({}), f"{node.range}_first" ) ) elif child0_clade_set == {f"{node.range}_first"}: current_split = AncestralSplit( ancestor=SRangesClade( frozenset(child1_clade_set), f"{node.range}_first" ), descendant=SRangesClade( # frozenset(child0_clade_set), frozenset({}), f"{node.range}_first" ) ) else: raise AssertionError(f"Failure: range end should have " f"single taxon leaf set which is the range...") parent_range = None if node.up: if hasattr(node.up, "rangetype"): if node.up.rangetype == "range_end": parent_range = f"{node.up.range}_first" elif node.up.rangetype == "sampled_ancestor": assert hasattr(node.up, "sampledancestor"), "This needs fixing" parent_range = 
f"{node.up.sampledancestor}_first" elif node.up.rangetype == "range_start": assert node.up.range == node.range, "something wrong with node.up.up case" if hasattr(node.up.up, "range"): parent_range = f"{node.up.up.range}_first" elif hasattr(node.up.up, "sampledancestor"): parent_range = f"{node.up.up.sampledancestor}_first" elif hasattr(node.up, "range"): parent_range = f"{node.up.range}_first" parent_clade = SRangesClade( frozenset(parent_clade_set), parent_range ) clade_count_map[parent_clade] += 1 clade_split_count_map[parent_clade][current_split] += 1 case "leaf_range": # We are looking at a node that is parent of a range that is a leaf # We don't need to add a split here? # We only count the clade, which is the taxon and itself as a range? # This can surely be done in a nicer way... current_leaf = [l.name for l in node if l.dist == 0.0][0] range_taxon = reverse_taxon_map[f"{node.range}_first"] assert range_taxon == int(current_leaf), (f"Failure: " f"{range_taxon} != {current_leaf}") # Leaf ranges can not have themselves as a range, # but there can be other ranges as ancestors above # Leaf ranges have no ancestral range of themselves because # otherwise this information is assumed to be present above the leaf # and therefore breaks the logic # One way to avoid this and fully encode leaf ranges is # to have _first and _last taxon as part of the clades... 
current_range = None if hasattr(node.up, "range"): current_range = f"{node.up.range}_first" elif hasattr(node.up, "sampledancestor"): current_range = f"{node.up.sampledancestor}_first" clade_count_map[ SRangesClade( frozenset({taxon_map[int(current_leaf)]}), current_range ) ] += 1 case _: raise ValueError("Unsupported value...") else: parent_clade_set = { taxon_map[int(l.name)].replace("_last", "_first") for l in node } child0_clade_set = { taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[0] } child1_clade_set = { taxon_map[int(l.name)].replace("_last", "_first") for l in node.children[1] } assert child1_clade_set.union(child0_clade_set) == parent_clade_set, \ "Something is wrong with the clades..." cur_range = None parent_range = None if hasattr(node, "range"): cur_range = f"{node.range}_first" if node.up: if hasattr(node.up, "rangetype"): if node.up.rangetype == "range_end": parent_range = f"{node.up.range}_first" elif node.up.rangetype == "sampled_ancestor": assert hasattr(node.up, "sampledancestor"), "This needs fixing" parent_range = f"{node.up.sampledancestor}_first" elif node.up.rangetype == "range_start": assert node.up.range == node.range, "something wrong with node.up.up case" if hasattr(node.up.up, "range"): parent_range = f"{node.up.up.range}_first" elif hasattr(node.up.up, "sampledancestor"): parent_range = f"{node.up.up.sampledancestor}_first" elif hasattr(node.up, "range"): parent_range = f"{node.up.range}_first" parent_clade = SRangesClade(frozenset(parent_clade_set), parent_range) if node.children[0].orientation == "ancestor": current_split = AncestralSplit( ancestor=SRangesClade(frozenset(child0_clade_set), cur_range), descendant=SRangesClade(frozenset(child1_clade_set), cur_range), ) else: current_split = AncestralSplit( ancestor=SRangesClade(frozenset(child1_clade_set), cur_range), descendant=SRangesClade(frozenset(child0_clade_set), cur_range), ) clade_count_map[parent_clade] += 1 
clade_split_count_map[parent_clade][current_split] += 1 return (dict(clade_count_map), {clade: dict(splits) for clade, splits in clade_split_count_map.items()})
def get_sranges_map_tree(
        clade_count_map, clade_split_count_map, taxon_map, reverse_taxon_map
):
    """Pick, for every non-leaf clade, the split with the highest probability.

    Clades are processed from smallest to largest, so the best probability of
    each sub-clade is known before any clade that contains it. The probability
    of a split is the product of its ancestor side, its descendant side and
    the split's observed frequency below the parent clade.

    Args:
        clade_count_map: Mapping ``clade -> observation count``. Clades must
            support ``len()`` and expose a ``clade`` attribute (the taxon set).
        clade_split_count_map: Mapping ``clade -> {split: count}``. Splits
            expose ``ancestor`` and ``descendant`` clades.
        taxon_map: Not used by this function.
        reverse_taxon_map: Not used by this function.

    Returns:
        Dict ``clade -> (probability, split, tie_broken)``. ``tie_broken`` is
        ``True`` when the stored split was selected by a random tie-break.

    Raises:
        KeyError: If a clade with more than one taxon has no entry in
            ``clade_split_count_map``, or a split refers to an unseen clade.
        AssertionError: If a single-taxon probability leaves ``[0, 1]``.
    """
    # Imported locally, as in the original, because the module does not
    # import ``random`` at file level. Hoisted out of the inner loop.
    import random

    seen_resolved_clades = {}
    # Number of splits tied at the current best probability, per clade.
    tie_counts = {}
    leaf_clades = []

    def side_probability(side, failure_message):
        """Probability of one side (ancestor or descendant) of a split."""
        if len(side.clade) == 1:
            # Single taxon: share of this (taxon, range) clade among all
            # observations of the same taxon under any range.
            leaf_observations = sum(
                clade_count_map[k] for k in leaf_clades if k.clade == side.clade
            )
            probability = clade_count_map[side] / leaf_observations
            assert 0 <= probability <= 1.0, failure_message
            return probability
        if not side.clade:
            # Empty taxon set: this side splits off a range end or a sampled
            # ancestor and contributes no uncertainty of its own.
            return 1
        return seen_resolved_clades[side][0]

    for current_clade in sorted(clade_count_map, key=len):
        clade_size = len(current_clade)
        if clade_size == 1:
            leaf_clades.append(current_clade)
            continue
        if clade_size == 0:
            # This is splitting off a range or SA.
            continue

        for current_split, split_count in clade_split_count_map[current_clade].items():
            anc_prob = side_probability(current_split.ancestor, "Prob failure1...")
            desc_prob = side_probability(current_split.descendant, "Prob failure2...")
            split_probability = anc_prob * desc_prob * (
                split_count / clade_count_map[current_clade]
            )

            if current_clade not in seen_resolved_clades:
                seen_resolved_clades[current_clade] = (split_probability, current_split, False)
                tie_counts[current_clade] = 1
                continue

            best_probability, best_split = seen_resolved_clades[current_clade][:2]
            if best_probability < split_probability:
                seen_resolved_clades[current_clade] = (split_probability, current_split, False)
                tie_counts[current_clade] = 1
            elif best_probability == split_probability:
                # Random tie-break. Replacing the incumbent with probability
                # 1/k for the k-th tied candidate keeps the choice uniform
                # over all tied splits (a fixed 0.5 favours later ones).
                tie_counts[current_clade] += 1
                if random.random() < 1.0 / tie_counts[current_clade]:
                    chosen_prob, chosen_split = split_probability, current_split
                else:
                    chosen_prob, chosen_split = best_probability, best_split
                # True records that a tie-break decided this entry.
                seen_resolved_clades[current_clade] = (chosen_prob, chosen_split, True)

    return seen_resolved_clades