Commit d7c36a8

Merge pull request #1065 from opencobra/update/multiprocessing
refactor: multiprocessing.Pool usage
2 parents: 49c374a + 87be67b

3 files changed: +16 -24 lines

src/cobra/flux_analysis/deletion.py

Lines changed: 6 additions & 9 deletions
@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-

 import logging
-import multiprocessing
 from builtins import dict, map
 from functools import partial
 from itertools import product
+from multiprocessing import Pool
 from typing import List, Set, Union

 import pandas as pd

@@ -149,14 +149,11 @@ def extract_knockout_results(result_iter):
             gene=_gene_deletion_worker, reaction=_reaction_deletion_worker
         )[entity]
         chunk_size = len(args) // processes
-        pool = multiprocessing.Pool(
-            processes, initializer=_init_worker, initargs=(model,)
-        )
-        results = extract_knockout_results(
-            pool.imap_unordered(worker, args, chunksize=chunk_size)
-        )
-        pool.close()
-        pool.join()
+
+        with Pool(processes, initializer=_init_worker, initargs=(model,)) as pool:
+            results = extract_knockout_results(
+                pool.imap_unordered(worker, args, chunksize=chunk_size)
+            )
     else:
         worker = dict(gene=_gene_deletion, reaction=_reaction_deletion)[entity]
         results = extract_knockout_results(map(partial(worker, model), args))
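
Across all three files the change is the same: `multiprocessing.Pool` is used as a context manager instead of pairing each pool with explicit `close()`/`join()` calls. One nuance of the context-manager form is that `Pool.__exit__` calls `terminate()` rather than `close()` plus `join()`, so results should be fully consumed inside the `with` block, as they are at these call sites. A minimal, self-contained sketch of the pattern, with a toy worker and initializer rather than the cobrapy internals:

```python
from multiprocessing import Pool

_offset = None  # per-worker state installed by the initializer


def _init_worker(offset):
    # Runs once in each worker process; stashing shared state globally
    # avoids pickling it again with every submitted task.
    global _offset
    _offset = offset


def _square(x):
    return x * x + _offset


if __name__ == "__main__":
    args = list(range(100))
    processes = 4
    chunk_size = len(args) // processes
    # __exit__ terminates the workers, so consume results inside the block.
    with Pool(processes, initializer=_init_worker, initargs=(0,)) as pool:
        results = list(pool.imap_unordered(_square, args, chunksize=chunk_size))
    print(sorted(results)[:5])  # [0, 1, 4, 9, 16]
```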

src/cobra/flux_analysis/variability.py

Lines changed: 8 additions & 9 deletions
@@ -3,8 +3,8 @@
 from __future__ import absolute_import

 import logging
-import multiprocessing
 from builtins import map
+from multiprocessing import Pool
 from warnings import warn

 from numpy import zeros

@@ -208,17 +208,16 @@ def flux_variability_analysis(
             # objective direction for all reactions. This creates a
             # slight overhead but seems the most clean.
             chunk_size = len(reaction_ids) // processes
-            pool = multiprocessing.Pool(
+
+            with Pool(
                 processes,
                 initializer=_init_worker,
                 initargs=(model, loopless, what[:3]),
-            )
-            for rxn_id, value in pool.imap_unordered(
-                _fva_step, reaction_ids, chunksize=chunk_size
-            ):
-                fva_result.at[rxn_id, what] = value
-            pool.close()
-            pool.join()
+            ) as pool:
+                for rxn_id, value in pool.imap_unordered(
+                    _fva_step, reaction_ids, chunksize=chunk_size
+                ):
+                    fva_result.at[rxn_id, what] = value
         else:
             _init_worker(model, loopless, what[:3])
             for rxn_id, value in map(_fva_step, reaction_ids):
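
Because `imap_unordered` yields results in completion order rather than submission order, the FVA worker returns each value keyed by its reaction id, and the parent writes it into the result frame with `DataFrame.at`. A toy sketch of that consume-as-they-arrive pattern; `_step` and the numeric state are placeholders, not the real `_fva_step`:

```python
from multiprocessing import Pool

import pandas as pd

_scale = None  # stand-in for the model state shared via the initializer


def _init_worker(scale):
    global _scale
    _scale = scale


def _step(rxn_id):
    # Return a (key, value) pair so the parent can place the result
    # correctly no matter which worker finishes first.
    return rxn_id, len(rxn_id) * _scale


if __name__ == "__main__":
    reaction_ids = ["PFK", "PGI", "PYK"]
    fva_result = pd.DataFrame(index=reaction_ids, columns=["minimum"], dtype=float)
    with Pool(2, initializer=_init_worker, initargs=(2.0,)) as pool:
        for rxn_id, value in pool.imap_unordered(_step, reaction_ids, chunksize=1):
            fva_result.at[rxn_id, "minimum"] = value
    print(fva_result)
```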

src/cobra/sampling/optgp.py

Lines changed: 2 additions & 6 deletions
@@ -214,12 +214,8 @@ def sample(self, n: int, fluxes: bool = True) -> pd.DataFrame:
         # limit errors, something weird going on with multiprocessing
         args = list(zip([n_process] * self.processes, range(self.processes)))

-        # No with statement or starmap here since Python 2.x
-        # does not support it :(
-        mp = Pool(self.processes, initializer=mp_init, initargs=(self,))
-        results = mp.map(_sample_chain, args, chunksize=1)
-        mp.close()
-        mp.join()
+        with Pool(self.processes, initializer=mp_init, initargs=(self,)) as pool:
+            results = pool.map(_sample_chain, args, chunksize=1)

         chains = np.vstack([r[1] for r in results])
         self.retries += sum(r[0] for r in results)
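
The deleted comment explains the old shape of this code: Python 2's `Pool` supported neither the `with` statement nor `starmap`, both added in Python 3.3. The commit adopts the `with` statement but keeps `map` over packed `(n, idx)` tuples; on Python 3, `starmap`, which unpacks each tuple into positional arguments, would be an equivalent alternative. A small sketch contrasting the two, with toy workers standing in for `_sample_chain`:

```python
from multiprocessing import Pool


def _packed(arg):
    # map-style worker: receives the whole (n, idx) tuple
    n, idx = arg
    return idx, n


def _unpacked(n, idx):
    # starmap-style worker: the tuple is unpacked into parameters
    return idx, n


if __name__ == "__main__":
    args = list(zip([10] * 3, range(3)))
    with Pool(3) as pool:
        packed = pool.map(_packed, args, chunksize=1)
        unpacked = pool.starmap(_unpacked, args, chunksize=1)
    assert packed == unpacked == [(0, 10), (1, 10), (2, 10)]
```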
