Skip to content

Commit 8391338

Browse files
fixed new single tests by noise model
1 parent b851187 commit 8391338

File tree

3 files changed

+129
-40
lines changed

3 files changed

+129
-40
lines changed

diffxpy/testing/tests.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
DifferentialExpressionTestWaldCont, DifferentialExpressionTestLRTCont
2020
from .utils import parse_gene_names, parse_data, parse_sample_description, parse_size_factors, parse_grouping
2121

22-
logger = logging.getLogger("diffxpy")
23-
2422
# Use this to suppress matrix subclass PendingDeprecationWarnings from numpy:
2523
np.warnings.filterwarnings("ignore")
2624

@@ -145,8 +143,8 @@ def _fit(
145143
else:
146144
raise ValueError('base.test(): `noise_model="%s"` not recognized.' % noise_model)
147145

148-
logger.info("Fitting model...")
149-
logger.debug(" * Assembling input data...")
146+
logging.getLogger("diffxpy").info("Fitting model...")
147+
logging.getLogger("diffxpy").debug(" * Assembling input data...")
150148
input_data = InputData.new(
151149
data=data,
152150
design_loc=design_loc,
@@ -157,7 +155,7 @@ def _fit(
157155
feature_names=gene_names,
158156
)
159157

160-
logger.debug(" * Set up Estimator...")
158+
logging.getLogger("diffxpy").debug(" * Set up Estimator...")
161159
constructor_args = {}
162160
if batch_size is not None:
163161
constructor_args["batch_size"] = batch_size
@@ -176,10 +174,10 @@ def _fit(
176174
**constructor_args
177175
)
178176

179-
logger.debug(" * Initializing Estimator...")
177+
logging.getLogger("diffxpy").debug(" * Initializing Estimator...")
180178
estim.initialize()
181179

182-
logger.debug(" * Run estimation...")
180+
logging.getLogger("diffxpy").debug(" * Run estimation...")
183181
# training:
184182
if callable(training_strategy):
185183
# call training_strategy if it is a function
@@ -188,11 +186,11 @@ def _fit(
188186
estim.train_sequence(training_strategy=training_strategy)
189187

190188
if close_session:
191-
logger.debug(" * Finalize estimation...")
189+
logging.getLogger("diffxpy").debug(" * Finalize estimation...")
192190
model = estim.finalize()
193191
else:
194192
model = estim
195-
logger.debug(" * Model fitting done.")
193+
logging.getLogger("diffxpy").debug(" * Model fitting done.")
196194

197195
return model
198196

diffxpy/unit_test/test_single_de.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,30 @@
44
import pandas as pd
55
import scipy.stats as stats
66

7-
from batchglm.api.models.glm_nb import Simulator
87
import diffxpy.api as de
98

109

1110
class _TestSingleDE:
1211

13-
def _prepare_data(self, n_cells: int = 2000, n_genes: int = 100):
12+
def _prepare_data(
13+
self,
14+
n_cells: int,
15+
n_genes: int,
16+
noise_model: str
17+
):
1418
"""
1519
1620
:param n_cells: Number of cells to simulate (number of observations per test).
1721
:param n_genes: Number of genes to simulate (number of tests).
22+
:param noise_model: Noise model that selects the simulator used to generate the data ("nb" or "norm").
1823
"""
24+
if noise_model == "nb":
25+
from batchglm.api.models.glm_nb import Simulator
26+
elif noise_model == "norm":
27+
from batchglm.api.models.glm_norm import Simulator
28+
else:
29+
raise ValueError("noise model %s not recognized" % noise_model)
30+
1931
num_non_de = n_genes // 2
2032
sim = Simulator(num_observations=n_cells, num_features=n_genes)
2133
sim.generate_sample_description(num_batches=0, num_conditions=2)
@@ -50,7 +62,12 @@ def _eval(self, sim, test):
5062

5163
return sim
5264

53-
def _test_rank_de(self, n_cells: int = 2000, n_genes: int = 100):
65+
def _test_rank_de(
66+
self,
67+
n_cells: int,
68+
n_genes: int,
69+
noise_model: str
70+
):
5471
"""
5572
:param n_cells: Number of cells to simulate (number of observations per test).
5673
:param n_genes: Number of genes to simulate (number of tests).
@@ -59,7 +76,11 @@ def _test_rank_de(self, n_cells: int = 2000, n_genes: int = 100):
5976
logging.getLogger("batchglm").setLevel(logging.WARNING)
6077
logging.getLogger("diffxpy").setLevel(logging.WARNING)
6178

62-
sim = self._prepare_data(n_cells=n_cells, n_genes=n_genes)
79+
sim = self._prepare_data(
80+
n_cells=n_cells,
81+
n_genes=n_genes,
82+
noise_model=noise_model
83+
)
6384

6485
test = de.test.rank_test(
6586
data=sim.X,
@@ -72,7 +93,11 @@ def _test_rank_de(self, n_cells: int = 2000, n_genes: int = 100):
7293

7394
return True
7495

75-
def _test_t_test_de(self, n_cells: int = 2000, n_genes: int = 100):
96+
def _test_t_test_de(
97+
self,
98+
n_cells: int,
99+
n_genes: int
100+
):
76101
"""
77102
:param n_cells: Number of cells to simulate (number of observations per test).
78103
:param n_genes: Number of genes to simulate (number of tests).
@@ -81,7 +106,11 @@ def _test_t_test_de(self, n_cells: int = 2000, n_genes: int = 100):
81106
logging.getLogger("batchglm").setLevel(logging.WARNING)
82107
logging.getLogger("diffxpy").setLevel(logging.WARNING)
83108

84-
sim = self._prepare_data(n_cells=n_cells, n_genes=n_genes)
109+
sim = self._prepare_data(
110+
n_cells=n_cells,
111+
n_genes=n_genes,
112+
noise_model="norm"
113+
)
85114

86115
test = de.test.t_test(
87116
data=sim.X,
@@ -109,7 +138,11 @@ def _test_wald_de(
109138
logging.getLogger("batchglm").setLevel(logging.WARNING)
110139
logging.getLogger("diffxpy").setLevel(logging.WARNING)
111140

112-
sim = self._prepare_data(n_cells=n_cells, n_genes=n_genes)
141+
sim = self._prepare_data(
142+
n_cells=n_cells,
143+
n_genes=n_genes,
144+
noise_model=noise_model
145+
)
113146

114147
test = de.test.wald(
115148
data=sim.X,
@@ -140,7 +173,11 @@ def _test_lrt_de(
140173
logging.getLogger("batchglm").setLevel(logging.WARNING)
141174
logging.getLogger("diffxpy").setLevel(logging.WARNING)
142175

143-
sim = self._prepare_data(n_cells=n_cells, n_genes=n_genes)
176+
sim = self._prepare_data(
177+
n_cells=n_cells,
178+
n_genes=n_genes,
179+
noise_model=noise_model
180+
)
144181

145182
test = de.test.lrt(
146183
data=sim.X,
@@ -264,5 +301,6 @@ def test_lrt_de_norm(
264301
noise_model="norm"
265302
)
266303

304+
267305
if __name__ == '__main__':
268306
unittest.main()

0 commit comments

Comments
 (0)