"""
This module reads nexus tree files that are created by BEAST2.
The function read_nexus_trees() has options to make it compatible with trees
that were generated by the BREATH package and read transmission trees.
"""
import re
from collections import defaultdict
from brokilon.core.tree import Tree
[docs]
def count_trees_in_nexus(file) -> int:
    """
    Count the tree statements contained in a nexus file.

    A line is counted when it starts (after an optional tab) with the
    case-insensitive keyword ``tree`` followed by a space and contains at
    least one further space later on the line.

    :param file: Path of the nexus file to scan
    :returns: Number of lines that look like a tree statement
    """
    tree_line = re.compile("\t?tree .*=? (.*$)", flags=re.I | re.MULTILINE)
    with open(file, 'r', encoding="utf-8") as handle:
        # Stream the file line by line so large tree files are never held in memory.
        return sum(1 for row in handle if tree_line.match(row))
[docs]
# Tree statement. Because ``.*`` is greedy, group 1 is everything after the
# LAST space on the line (trailing ``;`` included).
# NOTE(review): a newick string that itself contains a space would be truncated.
_NEXUS_TREE_LINE_RE = re.compile(r"\t?tree .*=? (.*$)", flags=re.I | re.MULTILINE)
# Start / end of the ``translate`` block that maps numeric labels to taxa.
_NEXUS_TRANSLATE_BEGIN_RE = re.compile(r'\t?translate\n', re.I)
_NEXUS_TRANSLATE_END_RE = re.compile(r'\t*?;\n?')
# One ``[...]`` annotation together with the (optional) numeric label in front of it.
_NEXUS_NODE_ANNOTATION_RE = re.compile(r"\d*\[[^\]]*\]")
# ``.<name>:<name>`` fragment directly in front of ``_`` or ``=`` inside a key.
_NEXUS_SEQUENCE_SUFFIX_RE = re.compile(r"(\.[A-Za-z_]+:[A-Za-z]+)[_=]")
# A single ``key=value`` / ``key={v1,v2}`` entry. The key ends at the first
# ``.`` or ``=``; anything between that ``.`` and the ``=`` is discarded.
_NEXUS_METADATA_ENTRY_RE = re.compile(
    r"""
    ,?
    (?P<key>[^.=]+)
    (?:\.[^=]*)?=
    (?:
        {(?P<vals>[^}]+)}
        |
        (?P<val>[^, ]+)
    )
    """,
    re.VERBOSE,
)
# Dotted key names that are rewritten with underscores before parsing.
# Order matters and mirrors the sequence the replacements were always applied in.
_NEXUS_TYPE_KEY_RENAMES = (
    ("type.prob", "type_prob"),
    ("type.set.prob", "type_set_prob"),
    ("type.set", "type_set"),
)
# Keys whose value is always stored as a list of strings.
_NEXUS_SET_VALUED_KEYS = ("type_set", "type_set_prob")


def _parse_nexus_metadata(meta_string: str) -> dict:
    """Turn the inside of one annotation (``&`` and ``]`` removed) into a dict."""
    parsed = {}
    for entry in _NEXUS_METADATA_ENTRY_RE.finditer(meta_string):
        key = entry.group("key")
        braced = entry.group("vals")
        plain = entry.group("val")
        if key in _NEXUS_SET_VALUED_KEYS:
            raw = braced or plain
            if not raw:
                raise ValueError("No metadata found!")
            parsed[key] = raw.strip("{}\'\"").split(",")
        elif plain:
            value = plain.strip().strip('\"\'')
            try:
                value = float(value)
            except ValueError:
                pass  # not numeric: keep the unquoted string
            parsed[key] = value
        elif braced:
            items = braced.split(",")
            try:
                # All-or-nothing: one non-numeric item keeps every item a string.
                items = [float(item) for item in items]
            except ValueError:
                pass
            parsed[key] = tuple(items)
        else:
            raise ValueError("No metadata found!")
    return parsed


def _strip_nexus_annotations(tree_string: str):
    """
    Remove all ``[...]`` annotations from a newick string.

    Every annotation is replaced by the numeric label in front of it, or by a
    generated name ``internal<N>`` (N counting from 1 within this tree) when
    there is no such label.

    :returns: tuple of the annotation-free newick string and a
        ``defaultdict(dict)`` mapping node label to its parsed metadata
    """
    node_meta_data = defaultdict(dict)
    unnamed_nodes = 0

    def _replace(match):
        nonlocal unnamed_nodes
        text = match.group(0)
        # Work around the inconsistent use of "." and "_" in the type keys.
        for old, new in _NEXUS_TYPE_KEY_RENAMES:
            text = text.replace(old, new)
        # Drop the ".<name>:<name>" fragments so they do not end up in the keys.
        text = _NEXUS_SEQUENCE_SUFFIX_RE.sub(
            lambda hit: hit.group(0).replace(hit.group(1), ""), text)
        label, annotation = text.split("[", 1)
        if not label:
            unnamed_nodes += 1
            label = f"internal{unnamed_nodes}"
        parsed = _parse_nexus_metadata(annotation.replace("&", "").replace("]", ""))
        if parsed:
            # Only touch the mapping when there is something to store, so no
            # empty entries are created for annotation-less nodes.
            node_meta_data[label].update(parsed)
        return label

    sanitized = _NEXUS_NODE_ANNOTATION_RE.sub(_replace, tree_string)
    return sanitized, node_meta_data


def read_nexus_trees(file, parse_taxon_map: bool = False, burn_in: float = 0.0) \
        -> "list[Tree] | tuple[list[Tree], dict[int, str]]":
    """
    Read the trees of a nexus file together with their metadata annotations.

    The file is scanned twice: once to count the tree statements (needed to
    turn ``burn_in`` into a number of trees) and once to parse them. For every
    tree that is kept, the ``[...]`` annotations are removed from the newick
    string, unlabeled nodes are named ``internal1``, ``internal2``, ... and
    the parsed annotations are attached via ``_meta_label_nodes``.

    Annotation values are stored as follows: ``type_set`` / ``type_set_prob``
    become lists of strings, other single values become ``float`` when they
    parse as one (otherwise the unquoted string), and other ``{...}`` values
    become tuples of floats (or of strings if any item is not numeric).

    :param file: Input file
    :param parse_taxon_map: If true returns tuple of trees and taxon map
    :param burn_in: Fraction of trees at beginning of file to discard,
        must satisfy ``0 <= burn_in < 1``
    :returns: list of trees with metadata annotation, optionally also the
        taxon map of the translate block as dict (integer label -> taxon name)
    :raises ValueError: if ``burn_in`` is out of range or no tree is left
        after discarding the burn-in (including a file without trees)
    """
    # check that burn_in is proper value
    if not 0.0 <= burn_in < 1.0:
        raise ValueError("burn_in must be between 0 and 1")
    total_trees = count_trees_in_nexus(file)
    burn_in_end = int(burn_in * total_trees)
    if total_trees - burn_in_end == 0:
        raise ValueError("No trees left after burn-in, reduce value of burn_in.")

    inside_taxon_map = False
    taxon_map = {}
    trees = []
    skipped_trees = 0
    with open(file, 'r', encoding="utf-8") as f:
        for line in f:
            if parse_taxon_map:
                if inside_taxon_map:
                    if _NEXUS_TRANSLATE_END_RE.match(line):
                        inside_taxon_map = False
                        continue
                    fields = line.split()
                    if not fields:
                        # Blank line inside the translate block: nothing to record.
                        continue
                    # Entries look like "<number> <name>," - drop the trailing comma.
                    taxon_map[int(fields[0])] = fields[1].removesuffix(",")
                if _NEXUS_TRANSLATE_BEGIN_RE.match(line):
                    inside_taxon_map = True

            tree_match = _NEXUS_TREE_LINE_RE.match(line)
            if tree_match is None:
                continue
            if skipped_trees < burn_in_end:
                skipped_trees += 1
                continue

            # Replace all annotations, name internal nodes and extract meta data
            sanitized_tree_newick, node_meta_data = _strip_nexus_annotations(
                tree_match.group(1))
            tree = Tree(sanitized_tree_newick, format=1)
            # Annotating meta data to nodes
            _meta_label_nodes(tree, node_meta_data)
            trees.append(tree)

    if parse_taxon_map:
        return trees, taxon_map
    return trees
[docs]
def _cast_to_int(value):
"""
Custom function to cast a value to an integer. Fallbacks are either float or the original value.
:param value: Input value to cast
:return: Integer value or float or original value
"""
try:
f = float(value)
if f.is_integer():
return int(f)
return f
except ValueError:
return value