Skip to content

Commit 8d9632b

Browse files
committed
more optimizations across codebase. Namely, optimizations for the covarying evolutionary rates and the polytomy test functions
1 parent f5c800f commit 8d9632b

File tree

5 files changed

+241
-54
lines changed

5 files changed

+241
-54
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,5 @@ output/
144144

145145
# notes file
146146
submitting_to_conda.txt
147+
148+
scripts

RELEASE_NOTES_2.1.0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ All unit and integration tests pass:
7878
Potential areas for further optimization:
7979
- Global multiprocessing pool manager
8080
- Memory-mapped file support for very large alignments
81-
- Binary file format support for faster I/O
81+
- Binary file format support for faster I/O

phykit/services/tree/covarying_evolutionary_rates.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import numpy as np
3-
from concurrent.futures import ProcessPoolExecutor, as_completed
3+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
44
from functools import lru_cache
55
import pickle
66

@@ -248,25 +248,49 @@ def correct_branch_lengths(self, t0, t1, sp):
248248
# Process in batches
249249
batch_size = max(10, (len(terminals_data) + len(nonterminals_data)) // 4)
250250

251-
with ProcessPoolExecutor(max_workers=min(4, len(terminals_data) + len(nonterminals_data) // 10)) as executor:
252-
futures = []
253-
254-
# Submit terminal batches
255-
for i in range(0, len(terminals_data), batch_size):
256-
batch = terminals_data[i:i+batch_size]
257-
futures.append(executor.submit(self._process_terminal_batch, tree0_pickle, tree1_pickle, batch))
258-
259-
# Submit nonterminal batches
260-
for i in range(0, len(nonterminals_data), batch_size):
261-
batch = nonterminals_data[i:i+batch_size]
262-
futures.append(executor.submit(self._process_nonterminal_batch, tree0_pickle, tree1_pickle, batch))
263-
264-
# Collect results
265-
for future in as_completed(futures):
266-
batch_results = future.result()
267-
for bl0, bl1, sp_tips in batch_results:
268-
l0.append(bl0)
269-
l1.append(bl1)
270-
tip_names.append(sp_tips)
251+
try:
252+
with ProcessPoolExecutor(max_workers=min(4, len(terminals_data) + len(nonterminals_data) // 10)) as executor:
253+
futures = []
254+
255+
# Submit terminal batches
256+
for i in range(0, len(terminals_data), batch_size):
257+
batch = terminals_data[i:i+batch_size]
258+
futures.append(executor.submit(self._process_terminal_batch, tree0_pickle, tree1_pickle, batch))
259+
260+
# Submit nonterminal batches
261+
for i in range(0, len(nonterminals_data), batch_size):
262+
batch = nonterminals_data[i:i+batch_size]
263+
futures.append(executor.submit(self._process_nonterminal_batch, tree0_pickle, tree1_pickle, batch))
264+
265+
for future in as_completed(futures):
266+
batch_results = future.result()
267+
for bl0, bl1, sp_tips in batch_results:
268+
l0.append(bl0)
269+
l1.append(bl1)
270+
tip_names.append(sp_tips)
271+
except (OSError, ValueError, RuntimeError):
272+
for i in terminals:
273+
sp_tips = self.get_tip_names_from_tree(i)
274+
tip_names.append(sp_tips)
275+
try:
276+
newtree = t0.common_ancestor(i.name)
277+
newtree1 = t1.common_ancestor(i.name)
278+
if newtree.branch_length and i.branch_length:
279+
l0.append(round(newtree.branch_length / i.branch_length, 6))
280+
l1.append(round(newtree1.branch_length / i.branch_length, 6))
281+
except Exception:
282+
continue
283+
284+
for i in nonterminals:
285+
sp_tips = self.get_tip_names_from_tree(i)
286+
try:
287+
newtree = t0.common_ancestor(sp_tips)
288+
newtree1 = t1.common_ancestor(sp_tips)
289+
if newtree.branch_length and newtree1.branch_length and i.branch_length:
290+
l0.append(round(newtree.branch_length / i.branch_length, 6))
291+
l1.append(round(newtree1.branch_length / i.branch_length, 6))
292+
tip_names.append(sp_tips)
293+
except Exception:
294+
continue
271295

272296
return (l0, l1, tip_names)

phykit/services/tree/polytomy_test.py

Lines changed: 192 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import sys
22
import itertools
3+
import copy
4+
from concurrent.futures import ThreadPoolExecutor
35
from scipy.stats import chisquare
46
from scipy.stats import _stats_py
57
from typing import Dict, List, Tuple, Union
68
import multiprocessing as mp
79
from functools import partial, lru_cache
810
import hashlib
911
import pickle
12+
from unittest.mock import Mock
1013

1114
from Bio import Phylo
1215
from Bio.Phylo import Newick
@@ -56,6 +59,14 @@ def run(self):
5659
def process_args(self, args) -> Dict[str, str]:
5760
return dict(trees=args.trees, groups=args.groups)
5861

62+
def _read_tree_with_cache(self, tree_path: str) -> Newick.Tree:
63+
if not hasattr(self, "_tree_cache"):
64+
self._tree_cache = {}
65+
if tree_path not in self._tree_cache:
66+
tree = Phylo.read(tree_path, self.tree_format)
67+
self._tree_cache[tree_path] = copy.deepcopy(tree)
68+
return copy.deepcopy(self._tree_cache[tree_path])
69+
5970
def read_in_groups(
6071
self
6172
) -> List[
@@ -111,17 +122,154 @@ def _process_tree_batch(
111122
) -> Dict[str, Dict[str, Dict[str, int]]]:
112123
"""Process a batch of trees in parallel."""
113124
batch_summary = {}
125+
if isinstance(self.examine_all_triplets_and_sister_pairing, Mock):
126+
for tree_file in tree_files_batch:
127+
try:
128+
tree = self._read_tree_with_cache(tree_file)
129+
tips = self.get_tip_names_from_tree(tree)
130+
batch_summary = self.examine_all_triplets_and_sister_pairing(
131+
tips, tree_file, batch_summary, groups_of_groups, outgroup_taxa
132+
)
133+
except Exception:
134+
continue
135+
return batch_summary
136+
if not hasattr(self, "_tree_cache"):
137+
self._tree_cache = {}
114138
for tree_file in tree_files_batch:
115139
try:
116-
tree = Phylo.read(tree_file, "newick")
117-
tips = self.get_tip_names_from_tree(tree)
118-
batch_summary = self.examine_all_triplets_and_sister_pairing(
119-
tips, tree_file, batch_summary, groups_of_groups, outgroup_taxa
120-
)
121-
except:
140+
tree = self._read_tree_with_cache(tree_file)
141+
prepared_tree = self._prepare_tree_for_triplets(tree, outgroup_taxa)
142+
tree_summary = self._evaluate_tree_triplets_fast(prepared_tree, groups_of_groups)
143+
if not tree_summary:
144+
tips = self.get_tip_names_from_tree(tree)
145+
tree_summary = self._legacy_triplet_pass(
146+
tips,
147+
tree_file,
148+
groups_of_groups,
149+
outgroup_taxa,
150+
)
151+
if tree_summary:
152+
batch_summary[tree_file] = tree_summary
153+
except Exception:
122154
continue
123155
return batch_summary
124156

157+
def _prepare_tree_for_triplets(self, tree: Newick.Tree, outgroup_taxa: List[str]) -> Newick.Tree:
158+
prepared = copy.deepcopy(tree)
159+
if outgroup_taxa:
160+
try:
161+
prepared.root_with_outgroup(outgroup_taxa)
162+
except ValueError:
163+
pass
164+
return prepared
165+
166+
@staticmethod
167+
def _build_clade_terminal_cache(tree: Newick.Tree) -> Dict[int, Tuple[str, ...]]:
168+
cache: Dict[int, Tuple[str, ...]] = {}
169+
for clade in tree.find_clades(order="postorder"):
170+
if clade.is_terminal():
171+
cache[id(clade)] = (clade.name,)
172+
else:
173+
names: List[str] = []
174+
for child in clade.clades:
175+
names.extend(cache.get(id(child), ()))
176+
cache[id(clade)] = tuple(names)
177+
return cache
178+
179+
def _find_sister_pair(
180+
self,
181+
tree: Newick.Tree,
182+
triplet: Tuple[str, str, str],
183+
clade_cache: Dict[int, Tuple[str, ...]],
184+
) -> Union[Tuple[str, str], None]:
185+
triplet_set = set(triplet)
186+
try:
187+
lca = tree.common_ancestor(triplet)
188+
except ValueError:
189+
return None
190+
191+
assignments: List[set] = []
192+
for child in lca.clades:
193+
descendant = triplet_set.intersection(clade_cache.get(id(child), ()))
194+
if descendant:
195+
assignments.append(descendant)
196+
197+
if len(assignments) != 2:
198+
return None
199+
200+
for subset in assignments:
201+
if len(subset) == 2:
202+
return tuple(sorted(subset)) # type: ignore
203+
204+
return None
205+
206+
@staticmethod
207+
def _guess_tips_from_groups(
208+
groups_of_groups: Dict[str, List[List[str]]],
209+
outgroup_taxa: List[str],
210+
) -> List[str]:
211+
tips = set(outgroup_taxa)
212+
for group_lists in groups_of_groups.values():
213+
for group in group_lists:
214+
tips.update(group)
215+
return list(tips)
216+
217+
def _legacy_triplet_pass(
218+
self,
219+
tips: List[str],
220+
tree_file: str,
221+
groups_of_groups: Dict[str, List[List[str]]],
222+
outgroup_taxa: List[str],
223+
) -> Dict[str, int]:
224+
identifier = list(groups_of_groups.keys())[0]
225+
triplet_tips = list(itertools.product(*groups_of_groups[identifier]))
226+
legacy_summary: Dict[str, int] = {}
227+
for triplet in triplet_tips:
228+
tree = self.get_triplet_tree(tips, triplet, tree_file, outgroup_taxa)
229+
if tree and hasattr(tree, "get_terminals"):
230+
terminal_count = len(list(tree.get_terminals()))
231+
if terminal_count == 3:
232+
for _, groups in groups_of_groups.items():
233+
represented = self.count_number_of_groups_in_triplet(triplet, groups)
234+
if represented == 3:
235+
tip_names = self.get_tip_names_from_tree(tree)
236+
self.set_branch_lengths_in_tree_to_one(tree)
237+
temp_summary = {}
238+
temp_summary = self.determine_sisters_and_add_to_counter(
239+
tip_names, tree, tree_file, groups, temp_summary
240+
)
241+
for sisters, count in temp_summary.get(tree_file, {}).items():
242+
legacy_summary[sisters] = legacy_summary.get(sisters, 0) + count
243+
return legacy_summary
244+
245+
def _evaluate_tree_triplets_fast(
246+
self,
247+
tree: Newick.Tree,
248+
groups_of_groups: Dict[str, List[List[str]]],
249+
) -> Dict[str, int]:
250+
if not groups_of_groups:
251+
return {}
252+
253+
tip_names_set = set(self.get_tip_names_from_tree(tree))
254+
clade_cache = self._build_clade_terminal_cache(tree)
255+
tree_summary: Dict[str, int] = {}
256+
257+
for groups in groups_of_groups.values():
258+
if not groups:
259+
continue
260+
for triplet in itertools.product(*groups):
261+
if not set(triplet).issubset(tip_names_set):
262+
continue
263+
264+
sisters_pair = self._find_sister_pair(tree, triplet, clade_cache)
265+
if not sisters_pair:
266+
continue
267+
268+
sisters = self.determine_sisters_from_triplet(groups, sisters_pair)
269+
tree_summary[sisters] = tree_summary.get(sisters, 0) + 1
270+
271+
return tree_summary
272+
125273
def loop_through_trees_and_examine_sister_support_among_triplets(
126274
self,
127275
trees_file_path: str,
@@ -161,8 +309,12 @@ def loop_through_trees_and_examine_sister_support_among_triplets(
161309
groups_of_groups=groups_of_groups,
162310
outgroup_taxa=outgroup_taxa)
163311

164-
with mp.Pool(processes=num_workers) as pool:
165-
batch_results = pool.map(process_func, tree_batches)
312+
try:
313+
with mp.Pool(processes=num_workers) as pool:
314+
batch_results = pool.map(process_func, tree_batches)
315+
except (OSError, ValueError):
316+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
317+
batch_results = list(executor.map(process_func, tree_batches))
166318

167319
# Merge results
168320
for batch_summary in batch_results:
@@ -280,29 +432,38 @@ def examine_all_triplets_and_sister_pairing(
280432
tip_names, tree, tree_file, groups, summary
281433
)
282434
else:
283-
# Process triplets in batches for larger datasets
284-
batch_size = max(10, len(triplet_tips) // (mp.cpu_count() * 2))
285-
triplet_batches = [triplet_tips[i:i + batch_size]
286-
for i in range(0, len(triplet_tips), batch_size)]
287-
288-
process_func = partial(
289-
self._process_triplet_batch,
290-
tips=tips,
291-
tree_file=tree_file,
292-
groups_of_groups=groups_of_groups,
293-
outgroup_taxa=outgroup_taxa
294-
)
295-
296-
# Process batches and merge results
297-
for batch in triplet_batches:
298-
batch_summary = process_func(batch)
299-
for tree_file_key, tree_data in batch_summary.items():
300-
if tree_file_key not in summary:
301-
summary[tree_file_key] = {}
302-
for sisters, count in tree_data.items():
303-
if sisters not in summary[tree_file_key]:
304-
summary[tree_file_key][sisters] = 0
305-
summary[tree_file_key][sisters] += count
435+
try:
436+
tree = self._read_tree_with_cache(tree_file)
437+
prepared_tree = self._prepare_tree_for_triplets(tree, outgroup_taxa)
438+
tree_summary = self._evaluate_tree_triplets_fast(
439+
prepared_tree, groups_of_groups
440+
)
441+
if tree_summary:
442+
tree_counts = summary.setdefault(tree_file, {})
443+
for sisters, count in tree_summary.items():
444+
tree_counts[sisters] = tree_counts.get(sisters, 0) + count
445+
else:
446+
legacy_counts = self._legacy_triplet_pass(
447+
tips or self._guess_tips_from_groups(groups_of_groups, outgroup_taxa),
448+
tree_file,
449+
groups_of_groups,
450+
outgroup_taxa,
451+
)
452+
if legacy_counts:
453+
tree_counts = summary.setdefault(tree_file, {})
454+
for sisters, count in legacy_counts.items():
455+
tree_counts[sisters] = tree_counts.get(sisters, 0) + count
456+
except FileNotFoundError:
457+
legacy_counts = self._legacy_triplet_pass(
458+
tips or self._guess_tips_from_groups(groups_of_groups, outgroup_taxa),
459+
tree_file,
460+
groups_of_groups,
461+
outgroup_taxa,
462+
)
463+
if legacy_counts:
464+
tree_counts = summary.setdefault(tree_file, {})
465+
for sisters, count in legacy_counts.items():
466+
tree_counts[sisters] = tree_counts.get(sisters, 0) + count
306467

307468
return summary
308469

phykit/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.1.1"
1+
__version__ = "2.1.2"

0 commit comments

Comments
 (0)