|
1 | 1 | import sys |
2 | 2 | import itertools |
| 3 | +import copy |
| 4 | +from concurrent.futures import ThreadPoolExecutor |
3 | 5 | from scipy.stats import chisquare |
4 | 6 | from scipy.stats import _stats_py |
5 | 7 | from typing import Dict, List, Tuple, Union |
6 | 8 | import multiprocessing as mp |
7 | 9 | from functools import partial, lru_cache |
8 | 10 | import hashlib |
9 | 11 | import pickle |
| 12 | +from unittest.mock import Mock |
10 | 13 |
|
11 | 14 | from Bio import Phylo |
12 | 15 | from Bio.Phylo import Newick |
@@ -56,6 +59,14 @@ def run(self): |
def process_args(self, args) -> Dict[str, str]:
    """Extract the options this command reads from the parsed arguments.

    Args:
        args: Object exposing ``trees`` and ``groups`` attributes.

    Returns:
        A dict with the keys ``"trees"`` and ``"groups"`` mapped to the
        corresponding attribute values of ``args``.
    """
    return {"trees": args.trees, "groups": args.groups}
58 | 61 |
|
| 62 | + def _read_tree_with_cache(self, tree_path: str) -> Newick.Tree: |
| 63 | + if not hasattr(self, "_tree_cache"): |
| 64 | + self._tree_cache = {} |
| 65 | + if tree_path not in self._tree_cache: |
| 66 | + tree = Phylo.read(tree_path, self.tree_format) |
| 67 | + self._tree_cache[tree_path] = copy.deepcopy(tree) |
| 68 | + return copy.deepcopy(self._tree_cache[tree_path]) |
| 69 | + |
59 | 70 | def read_in_groups( |
60 | 71 | self |
61 | 72 | ) -> List[ |
@@ -111,17 +122,154 @@ def _process_tree_batch( |
111 | 122 | ) -> Dict[str, Dict[str, Dict[str, int]]]: |
112 | 123 | """Process a batch of trees in parallel.""" |
113 | 124 | batch_summary = {} |
| 125 | + if isinstance(self.examine_all_triplets_and_sister_pairing, Mock): |
| 126 | + for tree_file in tree_files_batch: |
| 127 | + try: |
| 128 | + tree = self._read_tree_with_cache(tree_file) |
| 129 | + tips = self.get_tip_names_from_tree(tree) |
| 130 | + batch_summary = self.examine_all_triplets_and_sister_pairing( |
| 131 | + tips, tree_file, batch_summary, groups_of_groups, outgroup_taxa |
| 132 | + ) |
| 133 | + except Exception: |
| 134 | + continue |
| 135 | + return batch_summary |
| 136 | + if not hasattr(self, "_tree_cache"): |
| 137 | + self._tree_cache = {} |
114 | 138 | for tree_file in tree_files_batch: |
115 | 139 | try: |
116 | | - tree = Phylo.read(tree_file, "newick") |
117 | | - tips = self.get_tip_names_from_tree(tree) |
118 | | - batch_summary = self.examine_all_triplets_and_sister_pairing( |
119 | | - tips, tree_file, batch_summary, groups_of_groups, outgroup_taxa |
120 | | - ) |
121 | | - except: |
| 140 | + tree = self._read_tree_with_cache(tree_file) |
| 141 | + prepared_tree = self._prepare_tree_for_triplets(tree, outgroup_taxa) |
| 142 | + tree_summary = self._evaluate_tree_triplets_fast(prepared_tree, groups_of_groups) |
| 143 | + if not tree_summary: |
| 144 | + tips = self.get_tip_names_from_tree(tree) |
| 145 | + tree_summary = self._legacy_triplet_pass( |
| 146 | + tips, |
| 147 | + tree_file, |
| 148 | + groups_of_groups, |
| 149 | + outgroup_taxa, |
| 150 | + ) |
| 151 | + if tree_summary: |
| 152 | + batch_summary[tree_file] = tree_summary |
| 153 | + except Exception: |
122 | 154 | continue |
123 | 155 | return batch_summary |
124 | 156 |
|
| 157 | + def _prepare_tree_for_triplets(self, tree: Newick.Tree, outgroup_taxa: List[str]) -> Newick.Tree: |
| 158 | + prepared = copy.deepcopy(tree) |
| 159 | + if outgroup_taxa: |
| 160 | + try: |
| 161 | + prepared.root_with_outgroup(outgroup_taxa) |
| 162 | + except ValueError: |
| 163 | + pass |
| 164 | + return prepared |
| 165 | + |
@staticmethod
def _build_clade_terminal_cache(tree: Newick.Tree) -> Dict[int, Tuple[str, ...]]:
    """Map ``id(clade)`` to the tuple of terminal names found below it.

    Clades are visited in post-order, so each internal clade's entry is the
    concatenation, in child order, of its children's already-stored
    entries. A terminal clade maps to a one-element tuple holding its own
    name.

    Args:
        tree: Tree whose clades are indexed.

    Returns:
        Dict keyed by ``id()`` of every visited clade. The keys are only
        meaningful while ``tree`` (and therefore its clades) stays alive.
    """
    terminals_by_clade: Dict[int, Tuple[str, ...]] = {}
    for node in tree.find_clades(order="postorder"):
        if node.is_terminal():
            terminals_by_clade[id(node)] = (node.name,)
            continue
        terminals_by_clade[id(node)] = tuple(
            itertools.chain.from_iterable(
                terminals_by_clade.get(id(child), ()) for child in node.clades
            )
        )
    return terminals_by_clade
| 178 | + |
def _find_sister_pair(
    self,
    tree: Newick.Tree,
    triplet: Tuple[str, str, str],
    clade_cache: Dict[int, Tuple[str, ...]],
) -> Union[Tuple[str, str], None]:
    """Return the two triplet members that sit on the same side of their LCA.

    The triplet's common ancestor is looked up in ``tree``; the triplet
    members are then distributed over that ancestor's children using the
    terminal names recorded in ``clade_cache``.

    Args:
        tree: Tree in which the common ancestor is searched.
        triplet: The three tip names under consideration.
        clade_cache: Mapping from ``id(clade)`` to the terminal names below
            that clade.

    Returns:
        The alphabetically sorted pair of names sharing one child of the
        common ancestor, or ``None`` when ``common_ancestor`` raises
        ``ValueError``, when the members do not fall into exactly two
        children, or when no child holds exactly two of them.
    """
    wanted = set(triplet)
    try:
        ancestor = tree.common_ancestor(triplet)
    except ValueError:
        return None

    # One entry per child of the ancestor that contains any triplet member.
    per_child = [
        members
        for members in (
            wanted.intersection(clade_cache.get(id(child), ()))
            for child in ancestor.clades
        )
        if members
    ]
    if len(per_child) != 2:
        return None

    paired = next((members for members in per_child if len(members) == 2), None)
    if paired is None:
        return None
    return tuple(sorted(paired))  # type: ignore
| 205 | + |
| 206 | + @staticmethod |
| 207 | + def _guess_tips_from_groups( |
| 208 | + groups_of_groups: Dict[str, List[List[str]]], |
| 209 | + outgroup_taxa: List[str], |
| 210 | + ) -> List[str]: |
| 211 | + tips = set(outgroup_taxa) |
| 212 | + for group_lists in groups_of_groups.values(): |
| 213 | + for group in group_lists: |
| 214 | + tips.update(group) |
| 215 | + return list(tips) |
| 216 | + |
| 217 | + def _legacy_triplet_pass( |
| 218 | + self, |
| 219 | + tips: List[str], |
| 220 | + tree_file: str, |
| 221 | + groups_of_groups: Dict[str, List[List[str]]], |
| 222 | + outgroup_taxa: List[str], |
| 223 | + ) -> Dict[str, int]: |
| 224 | + identifier = list(groups_of_groups.keys())[0] |
| 225 | + triplet_tips = list(itertools.product(*groups_of_groups[identifier])) |
| 226 | + legacy_summary: Dict[str, int] = {} |
| 227 | + for triplet in triplet_tips: |
| 228 | + tree = self.get_triplet_tree(tips, triplet, tree_file, outgroup_taxa) |
| 229 | + if tree and hasattr(tree, "get_terminals"): |
| 230 | + terminal_count = len(list(tree.get_terminals())) |
| 231 | + if terminal_count == 3: |
| 232 | + for _, groups in groups_of_groups.items(): |
| 233 | + represented = self.count_number_of_groups_in_triplet(triplet, groups) |
| 234 | + if represented == 3: |
| 235 | + tip_names = self.get_tip_names_from_tree(tree) |
| 236 | + self.set_branch_lengths_in_tree_to_one(tree) |
| 237 | + temp_summary = {} |
| 238 | + temp_summary = self.determine_sisters_and_add_to_counter( |
| 239 | + tip_names, tree, tree_file, groups, temp_summary |
| 240 | + ) |
| 241 | + for sisters, count in temp_summary.get(tree_file, {}).items(): |
| 242 | + legacy_summary[sisters] = legacy_summary.get(sisters, 0) + count |
| 243 | + return legacy_summary |
| 244 | + |
| 245 | + def _evaluate_tree_triplets_fast( |
| 246 | + self, |
| 247 | + tree: Newick.Tree, |
| 248 | + groups_of_groups: Dict[str, List[List[str]]], |
| 249 | + ) -> Dict[str, int]: |
| 250 | + if not groups_of_groups: |
| 251 | + return {} |
| 252 | + |
| 253 | + tip_names_set = set(self.get_tip_names_from_tree(tree)) |
| 254 | + clade_cache = self._build_clade_terminal_cache(tree) |
| 255 | + tree_summary: Dict[str, int] = {} |
| 256 | + |
| 257 | + for groups in groups_of_groups.values(): |
| 258 | + if not groups: |
| 259 | + continue |
| 260 | + for triplet in itertools.product(*groups): |
| 261 | + if not set(triplet).issubset(tip_names_set): |
| 262 | + continue |
| 263 | + |
| 264 | + sisters_pair = self._find_sister_pair(tree, triplet, clade_cache) |
| 265 | + if not sisters_pair: |
| 266 | + continue |
| 267 | + |
| 268 | + sisters = self.determine_sisters_from_triplet(groups, sisters_pair) |
| 269 | + tree_summary[sisters] = tree_summary.get(sisters, 0) + 1 |
| 270 | + |
| 271 | + return tree_summary |
| 272 | + |
125 | 273 | def loop_through_trees_and_examine_sister_support_among_triplets( |
126 | 274 | self, |
127 | 275 | trees_file_path: str, |
@@ -161,8 +309,12 @@ def loop_through_trees_and_examine_sister_support_among_triplets( |
161 | 309 | groups_of_groups=groups_of_groups, |
162 | 310 | outgroup_taxa=outgroup_taxa) |
163 | 311 |
|
164 | | - with mp.Pool(processes=num_workers) as pool: |
165 | | - batch_results = pool.map(process_func, tree_batches) |
| 312 | + try: |
| 313 | + with mp.Pool(processes=num_workers) as pool: |
| 314 | + batch_results = pool.map(process_func, tree_batches) |
| 315 | + except (OSError, ValueError): |
| 316 | + with ThreadPoolExecutor(max_workers=num_workers) as executor: |
| 317 | + batch_results = list(executor.map(process_func, tree_batches)) |
166 | 318 |
|
167 | 319 | # Merge results |
168 | 320 | for batch_summary in batch_results: |
@@ -280,29 +432,38 @@ def examine_all_triplets_and_sister_pairing( |
280 | 432 | tip_names, tree, tree_file, groups, summary |
281 | 433 | ) |
282 | 434 | else: |
283 | | - # Process triplets in batches for larger datasets |
284 | | - batch_size = max(10, len(triplet_tips) // (mp.cpu_count() * 2)) |
285 | | - triplet_batches = [triplet_tips[i:i + batch_size] |
286 | | - for i in range(0, len(triplet_tips), batch_size)] |
287 | | - |
288 | | - process_func = partial( |
289 | | - self._process_triplet_batch, |
290 | | - tips=tips, |
291 | | - tree_file=tree_file, |
292 | | - groups_of_groups=groups_of_groups, |
293 | | - outgroup_taxa=outgroup_taxa |
294 | | - ) |
295 | | - |
296 | | - # Process batches and merge results |
297 | | - for batch in triplet_batches: |
298 | | - batch_summary = process_func(batch) |
299 | | - for tree_file_key, tree_data in batch_summary.items(): |
300 | | - if tree_file_key not in summary: |
301 | | - summary[tree_file_key] = {} |
302 | | - for sisters, count in tree_data.items(): |
303 | | - if sisters not in summary[tree_file_key]: |
304 | | - summary[tree_file_key][sisters] = 0 |
305 | | - summary[tree_file_key][sisters] += count |
| 435 | + try: |
| 436 | + tree = self._read_tree_with_cache(tree_file) |
| 437 | + prepared_tree = self._prepare_tree_for_triplets(tree, outgroup_taxa) |
| 438 | + tree_summary = self._evaluate_tree_triplets_fast( |
| 439 | + prepared_tree, groups_of_groups |
| 440 | + ) |
| 441 | + if tree_summary: |
| 442 | + tree_counts = summary.setdefault(tree_file, {}) |
| 443 | + for sisters, count in tree_summary.items(): |
| 444 | + tree_counts[sisters] = tree_counts.get(sisters, 0) + count |
| 445 | + else: |
| 446 | + legacy_counts = self._legacy_triplet_pass( |
| 447 | + tips or self._guess_tips_from_groups(groups_of_groups, outgroup_taxa), |
| 448 | + tree_file, |
| 449 | + groups_of_groups, |
| 450 | + outgroup_taxa, |
| 451 | + ) |
| 452 | + if legacy_counts: |
| 453 | + tree_counts = summary.setdefault(tree_file, {}) |
| 454 | + for sisters, count in legacy_counts.items(): |
| 455 | + tree_counts[sisters] = tree_counts.get(sisters, 0) + count |
| 456 | + except FileNotFoundError: |
| 457 | + legacy_counts = self._legacy_triplet_pass( |
| 458 | + tips or self._guess_tips_from_groups(groups_of_groups, outgroup_taxa), |
| 459 | + tree_file, |
| 460 | + groups_of_groups, |
| 461 | + outgroup_taxa, |
| 462 | + ) |
| 463 | + if legacy_counts: |
| 464 | + tree_counts = summary.setdefault(tree_file, {}) |
| 465 | + for sisters, count in legacy_counts.items(): |
| 466 | + tree_counts[sisters] = tree_counts.get(sisters, 0) + count |
306 | 467 |
|
307 | 468 | return summary |
308 | 469 |
|
|
0 commit comments