Skip to content

Commit 0f5ab4e

Browse files
authored
Merge pull request #1059 from opencobra/refactor/sampling-optgp
refactor: cobra.sampling.optgp
2 parents 0fd39ab + d4904a2 commit 0f5ab4e

File tree

3 files changed

+89
-90
lines changed

3 files changed

+89
-90
lines changed

src/cobra/sampling/optgp.py

Lines changed: 75 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
# -*- coding: utf-8 -*-
2-
3-
"""Provide OptGP sampler."""
4-
5-
from __future__ import absolute_import, division
1+
"""Provide the OptGP sampler class and helper functions."""
62

73
from multiprocessing import Pool
4+
from typing import TYPE_CHECKING, Dict, Optional, Tuple
85

96
import numpy as np
10-
import pandas
7+
import pandas as pd
8+
9+
from ..core.configuration import Configuration
10+
from .core import step
11+
from .hr_sampler import HRSampler, shared_np_array
12+
1113

12-
from cobra.core.configuration import Configuration
13-
from cobra.sampling.core import step
14-
from cobra.sampling.hr_sampler import HRSampler, shared_np_array
14+
if TYPE_CHECKING:
15+
from cobra import Model
16+
from cobra.sampling import OptGPSampler
1517

1618

1719
__all__ = ("OptGPSampler",)
@@ -20,18 +22,18 @@
2022
CONFIGURATION = Configuration()
2123

2224

23-
def mp_init(obj):
25+
def mp_init(obj: "OptGPSampler") -> None:
2426
"""Initialize the multiprocessing pool."""
2527
global sampler
2628
sampler = obj
2729

2830

2931
# Unfortunately this has to be outside the class to be usable with
3032
# multiprocessing :()
31-
def _sample_chain(args):
33+
def _sample_chain(args: Tuple[int, int]) -> Tuple[int, "OptGPSampler"]:
3234
"""Sample a single chain for OptGPSampler.
3335
34-
center and n_samples are updated locally and forgotten afterwards.
36+
`center` and `n_samples` are updated locally and forgotten afterwards.
3537
3638
"""
3739
n, idx = args # has to be this way to work in Python 2.7
@@ -67,59 +69,52 @@ def _sample_chain(args):
6769

6870

6971
class OptGPSampler(HRSampler):
70-
"""A parallel optimized sampler.
72+
"""
73+
Improved Artificial Centering Hit-and-Run sampler.
7174
72-
A parallel sampler with fast convergence and parallel execution. See [1]_
73-
for details.
75+
A parallel sampler with fast convergence and parallel execution.
76+
See [1]_ for details.
7477
7578
Parameters
7679
----------
7780
model : cobra.Model
7881
The cobra model from which to generate samples.
79-
processes: int, optional (default Configuration.processes)
80-
The number of processes used during sampling.
82+
processes: int, optional
83+
The number of processes used during sampling
84+
(default cobra.Configuration.processes).
8185
thinning : int, optional
82-
The thinning factor of the generated sampling chain. A thinning of 10
83-
means samples are returned every 10 steps.
86+
The thinning factor of the generated sampling chain. A thinning of
87+
10 means samples are returned every 10 steps (default 100).
8488
nproj : int > 0, optional
85-
How often to reproject the sampling point into the feasibility space.
86-
Avoids numerical issues at the cost of lower sampling. If you observe
87-
many equality constraint violations with `sampler.validate` you should
88-
lower this number.
89+
How often to reproject the sampling point into the feasibility
90+
space. Avoids numerical issues at the cost of lower sampling. If
91+
you observe many equality constraint violations with
92+
`sampler.validate` you should lower this number (default None).
8993
seed : int > 0, optional
90-
Sets the random number seed. Initialized to the current time stamp if
91-
None.
94+
Sets the random number seed. Initialized to the current time stamp
95+
if None (default None).
9296
9397
Attributes
9498
----------
95-
model : cobra.Model
96-
The cobra model from which the samples get generated.
97-
thinning : int
98-
The currently used thinning factor.
9999
n_samples : int
100100
The total number of samples that have been generated by this
101101
sampler instance.
102-
problem : collections.namedtuple
103-
A python object whose attributes define the entire sampling problem in
104-
matrix form. See docstring of `Problem`.
102+
problem : typing.NamedTuple
103+
A NamedTuple whose attributes define the entire sampling problem in
104+
matrix form.
105105
warmup : numpy.matrix
106-
A matrix of with as many columns as reactions in the model and more
107-
than 3 rows containing a warmup sample in each row. None if no warmup
108-
points have been generated yet.
106+
A numpy matrix with as many columns as reactions in the model and
107+
more than 3 rows containing a warmup sample in each row. None if no
108+
warmup points have been generated yet.
109109
retries : int
110110
The overall number of sampling retries the sampler has observed. Larger
111111
values indicate numerical instabilities.
112-
seed : int > 0, optional
113-
Sets the random number seed. Initialized to the current time stamp if
114-
None.
115-
nproj : int
116-
How often to reproject the sampling point into the feasibility space.
117112
fwd_idx : numpy.array
118-
Has one entry for each reaction in the model containing the index of
119-
the respective forward variable.
113+
A numpy array having one entry for each reaction in the model,
114+
containing the index of the respective forward variable.
120115
rev_idx : numpy.array
121-
Has one entry for each reaction in the model containing the index of
122-
the respective reverse variable.
116+
A numpy array having one entry for each reaction in the model,
117+
containing the index of the respective reverse variable.
123118
prev : numpy.array
124119
The current/last flux sample generated.
125120
center : numpy.array
@@ -129,20 +124,20 @@ class OptGPSampler(HRSampler):
129124
Notes
130125
-----
131126
The sampler is very similar to artificial centering where each process
132-
samples its own chain. Initial points are chosen randomly from the warmup
133-
points followed by a linear transformation that pulls the points a little
134-
bit towards the center of the sampling space.
127+
samples its own chain. Initial points are chosen randomly from the
128+
warmup points followed by a linear transformation that pulls the points
129+
a little bit towards the center of the sampling space.
135130
136131
If the number of processes used is larger than the one requested,
137132
number of samples is adjusted to the smallest multiple of the number of
138133
processes larger than the requested sample number. For instance, if you
139-
have 3 processes and request 8 samples you will receive 9.
134+
have 3 processes and request 8 samples, you will receive 9.
140135
141-
Memory usage is roughly in the order of (2 * number reactions)^2
142-
due to the required nullspace matrices and warmup points. So large
143-
models easily take up a few GB of RAM. However, most of the large matrices
144-
are kept in shared memory. So the RAM usage is independent of the number
145-
of processes.
136+
Memory usage is roughly in the order of (2 * number of reactions)^2
137+
due to the required nullspace matrices and warmup points. So, large
138+
models easily take up a few GBs of RAM. However, most of the large
139+
matrices are kept in shared memory. So the RAM usage is independent of
140+
the number of processes.
146141
147142
References
148143
----------
@@ -154,9 +149,17 @@ class OptGPSampler(HRSampler):
154149
155150
"""
156151

157-
def __init__(self, model, processes=None, thinning=100, nproj=None, seed=None):
152+
def __init__(
153+
self,
154+
model: "Model",
155+
thinning: int = 100,
156+
processes: Optional[int] = None,
157+
nproj: Optional[int] = None,
158+
seed: Optional[int] = None,
159+
**kwargs
160+
) -> None:
158161
"""Initialize a new OptGPSampler."""
159-
super(OptGPSampler, self).__init__(model, thinning, seed=seed)
162+
super().__init__(model, thinning, nproj=nproj, seed=seed, **kwargs)
160163
self.generate_fva_warmup()
161164

162165
if processes is None:
@@ -170,37 +173,37 @@ def __init__(self, model, processes=None, thinning=100, nproj=None, seed=None):
170173
(len(self.model.variables),), self.warmup.mean(axis=0)
171174
)
172175

173-
def sample(self, n, fluxes=True):
176+
def sample(self, n: int, fluxes: bool = True) -> pd.DataFrame:
174177
"""Generate a set of samples.
175178
176179
This is the basic sampling function for all hit-and-run samplers.
177180
178181
Parameters
179182
----------
180183
n : int
181-
The minimum number of samples that are generated at once
182-
(see Notes).
183-
fluxes : boolean
184-
Whether to return fluxes or the internal solver variables. If set
185-
to False will return a variable for each forward and backward flux
186-
as well as all additional variables you might have defined in the
187-
model.
184+
The minimum number of samples that are generated at once.
185+
fluxes : bool, optional
186+
Whether to return fluxes or the internal solver variables. If
187+
set to False, will return a variable for each forward and
188+
backward flux as well as all additional variables you might
189+
have defined in the model (default True).
188190
189191
Returns
190192
-------
191-
numpy.matrix
192-
Returns a matrix with `n` rows, each containing a flux sample.
193+
pandas.DataFrame
194+
Returns a pandas DataFrame with `n` rows, each containing a
195+
flux sample.
193196
194197
Notes
195198
-----
196199
Performance of this function linearly depends on the number
197200
of reactions in your model and the thinning factor.
198201
199202
If the number of processes is larger than one, computation is split
200-
across as the CPUs of your machine. This may shorten computation time.
201-
However, there is also overhead in setting up parallel computation so
202-
we recommend to calculate large numbers of samples at once
203-
(`n` > 1000).
203+
across the CPU cores of your machine. This may shorten computation
204+
time. However, there is also overhead in setting up parallel
205+
computation primitives, so we recommend calculating large numbers
206+
of samples at once (`n` > 1000).
204207
205208
"""
206209
if self.processes > 1:
@@ -234,17 +237,17 @@ def sample(self, n, fluxes=True):
234237
if fluxes:
235238
names = [r.id for r in self.model.reactions]
236239

237-
return pandas.DataFrame(
240+
return pd.DataFrame(
238241
chains[:, self.fwd_idx] - chains[:, self.rev_idx],
239242
columns=names,
240243
)
241244
else:
242245
names = [v.name for v in self.model.variables]
243246

244-
return pandas.DataFrame(chains, columns=names)
247+
return pd.DataFrame(chains, columns=names)
245248

246249
# Models can be large so don't pass them around during multiprocessing
247-
def __getstate__(self):
250+
def __getstate__(self) -> Dict:
248251
"""Return the object for serialization."""
249252
d = dict(self.__dict__)
250253
del d["model"]

src/cobra/sampling/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def sample(model, n, method="optgp", thinning=100, processes=1, seed=None):
7373
"""
7474

7575
if method == "optgp":
76-
sampler = OptGPSampler(model, processes, thinning=thinning, seed=seed)
76+
sampler = OptGPSampler(model, processes=processes, thinning=thinning, seed=seed)
7777
elif method == "achr":
7878
sampler = ACHRSampler(model, thinning=thinning, seed=seed)
7979
else:

src/cobra/test/test_sampling/test_optgp.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,61 @@
1-
# -*- coding: utf-8 -*-
2-
31
"""Test functionalities of OptGPSampler."""
42

5-
from __future__ import absolute_import
3+
from typing import TYPE_CHECKING
64

75
import numpy as np
86
import pytest
97

108
from cobra.sampling import OptGPSampler
119

1210

11+
if TYPE_CHECKING:
12+
from cobra import Model
13+
from cobra.sampling import ACHRSampler
14+
15+
1316
@pytest.fixture(scope="function")
14-
def optgp(model):
17+
def optgp(model: "Model") -> OptGPSampler:
1518
"""Return OptGPSampler instance for tests."""
16-
1719
sampler = OptGPSampler(model, processes=1, thinning=1)
1820
assert (sampler.n_warmup > 0) and (sampler.n_warmup <= 2 * len(model.variables))
1921
assert all(sampler.validate(sampler.warmup) == "v")
2022

2123
return sampler
2224

2325

24-
def test_optgp_init_benchmark(model, benchmark):
26+
def test_optgp_init_benchmark(model: "Model", benchmark) -> None:
2527
"""Benchmark inital OptGP sampling."""
26-
2728
benchmark(lambda: OptGPSampler(model, processes=2))
2829

2930

30-
def test_optgp_sample_benchmark(optgp, benchmark):
31+
def test_optgp_sample_benchmark(optgp: OptGPSampler, benchmark) -> None:
3132
"""Benchmark OptGP sampling."""
32-
3333
benchmark(optgp.sample, 1)
3434

3535

36-
def test_sampling(optgp):
36+
def test_sampling(optgp: OptGPSampler) -> None:
3737
"""Test sampling."""
38-
3938
s = optgp.sample(10)
4039
assert all(optgp.validate(s) == "v")
4140

4241

43-
def test_batch_sampling(optgp):
42+
def test_batch_sampling(optgp: OptGPSampler) -> None:
4443
"""Test batch sampling."""
45-
4644
for b in optgp.batch(5, 4):
4745
assert all(optgp.validate(b) == "v")
4846

4947

50-
def test_variables_samples(achr, optgp):
48+
def test_variables_samples(achr: "ACHRSampler", optgp: OptGPSampler) -> None:
5149
"""Test variable samples."""
52-
5350
vnames = np.array([v.name for v in achr.model.variables])
5451
s = optgp.sample(10, fluxes=False)
5552
assert s.shape == (10, optgp.warmup.shape[1])
5653
assert (s.columns == vnames).all()
5754
assert (optgp.validate(s) == "v").all()
5855

5956

60-
def test_reproject(optgp):
57+
def test_reproject(optgp: OptGPSampler) -> None:
6158
"""Test reprojection of sampling."""
62-
6359
s = optgp.sample(10, fluxes=False).values
6460
proj = np.apply_along_axis(optgp._reproject, 1, s)
6561
assert all(optgp.validate(proj) == "v")

0 commit comments

Comments
 (0)