Skip to content

Commit f7ad2ed

Browse files
added new type hinting api from batchglm
1 parent 4a5640f commit f7ad2ed

File tree

4 files changed

+59
-65
lines changed

4 files changed

+59
-65
lines changed

diffxpy/fit/fit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
from anndata.base import Raw
44
except ImportError:
55
from anndata import Raw
6+
import batchglm.api as glm
67
import logging
78
import numpy as np
89
import pandas as pd
910
import patsy
1011
import scipy.sparse
1112
from typing import Union, List, Dict, Callable, Tuple
1213

13-
from batchglm.models.base import _InputDataBase
1414
from .external import _fit
1515
from .external import parse_gene_names, parse_sample_description, parse_size_factors, parse_grouping, \
1616
constraint_system_from_star
1717

1818

1919
def model(
20-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
20+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
2121
formula_loc: Union[None, str] = None,
2222
formula_scale: Union[None, str] = "~1",
2323
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -226,7 +226,7 @@ def model(
226226

227227

228228
def residuals(
229-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
229+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
230230
formula_loc: Union[None, str] = None,
231231
formula_scale: Union[None, str] = "~1",
232232
as_numeric: Union[List[str], Tuple[str], str] = (),
@@ -400,7 +400,7 @@ def residuals(
400400

401401

402402
def partition(
403-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
403+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
404404
parts: Union[str, np.ndarray, list],
405405
gene_names: Union[np.ndarray, list] = None,
406406
sample_description: pd.DataFrame = None,
@@ -454,7 +454,7 @@ class _Partition:
454454

455455
def __init__(
456456
self,
457-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
457+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
458458
parts: Union[str, np.ndarray, list],
459459
gene_names: Union[np.ndarray, list] = None,
460460
sample_description: pd.DataFrame = None,
@@ -481,7 +481,7 @@ def __init__(
481481
same order as in data or string-type column identifier of size-factor containing
482482
column in sample description.
483483
"""
484-
if isinstance(data, _InputDataBase):
484+
if isinstance(data, glm.typing.InputDataBaseTyping):
485485
self.x = data.x
486486
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
487487
self.x = data.X

diffxpy/testing/det.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,18 @@
11
import abc
2+
try:
3+
import anndata
4+
except ImportError:
5+
anndata = None
6+
import batchglm.api as glm
27
import logging
3-
from typing import Union, Dict, Tuple, List, Set
8+
import numpy as np
9+
import patsy
410
import pandas as pd
511
from random import sample
612
import scipy.sparse
7-
8-
import numpy as np
9-
import patsy
13+
from typing import Union, Dict, Tuple, List, Set
1014

1115
from .utils import split_x, dmat_unique
12-
13-
try:
14-
import anndata
15-
except ImportError:
16-
anndata = None
17-
18-
from batchglm.models.base import _EstimatorBase, _InputDataBase
19-
2016
from ..stats import stats
2117
from . import correction
2218
from diffxpy import pkg_constants
@@ -468,17 +464,17 @@ class DifferentialExpressionTestLRT(_DifferentialExpressionTestSingle):
468464

469465
sample_description: pd.DataFrame
470466
full_design_loc_info: patsy.design_info
471-
full_estim: _EstimatorBase
467+
full_estim: glm.typing.EstimatorBaseTyping
472468
reduced_design_loc_info: patsy.design_info
473-
reduced_estim: _EstimatorBase
469+
reduced_estim: glm.typing.EstimatorBaseTyping
474470

475471
def __init__(
476472
self,
477473
sample_description: pd.DataFrame,
478474
full_design_loc_info: patsy.design_info,
479-
full_estim: _EstimatorBase,
475+
full_estim: glm.typing.EstimatorBaseTyping,
480476
reduced_design_loc_info: patsy.design_info,
481-
reduced_estim: _EstimatorBase
477+
reduced_estim: glm.typing.EstimatorBaseTyping
482478
):
483479
super().__init__()
484480
self.sample_description = sample_description
@@ -689,7 +685,7 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
689685
Single wald test per gene.
690686
"""
691687

692-
model_estim: _EstimatorBase
688+
model_estim: glm.typing.EstimatorBaseTyping
693689
sample_description: pd.DataFrame
694690
coef_loc_totest: np.ndarray
695691
theta_mle: np.ndarray
@@ -699,7 +695,7 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
699695

700696
def __init__(
701697
self,
702-
model_estim: _EstimatorBase,
698+
model_estim: glm.typing.EstimatorBaseTyping,
703699
col_indices: np.ndarray,
704700
noise_model: str,
705701
sample_description: pd.DataFrame
@@ -1548,7 +1544,7 @@ def __init__(
15481544
super().__init__()
15491545
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
15501546
data = data.X
1551-
elif isinstance(data, _InputDataBase):
1547+
elif isinstance(data, glm.typing.InputDataBaseTyping):
15521548
data = data.x
15531549
self._x = data
15541550
self.sample_description = sample_description
@@ -1673,7 +1669,7 @@ def __init__(
16731669
super().__init__()
16741670
if isinstance(data, anndata.AnnData) or isinstance(data, anndata.Raw):
16751671
data = data.X
1676-
elif isinstance(data, _InputDataBase):
1672+
elif isinstance(data, glm.typing.InputDataBaseTyping):
16771673
data = data.x
16781674
self._x = data
16791675
self.sample_description = sample_description
@@ -2090,13 +2086,13 @@ class DifferentialExpressionTestZTest(_DifferentialExpressionTestMulti):
20902086
Pairwise unit_test between more than 2 groups per gene.
20912087
"""
20922088

2093-
model_estim: _EstimatorBase
2089+
model_estim: glm.typing.EstimatorBaseTyping
20942090
theta_mle: np.ndarray
20952091
theta_sd: np.ndarray
20962092

20972093
def __init__(
20982094
self,
2099-
model_estim: _EstimatorBase,
2095+
model_estim: glm.typing.EstimatorBaseTyping,
21002096
grouping,
21012097
groups,
21022098
correction_type: str
@@ -2293,13 +2289,13 @@ class DifferentialExpressionTestZTestLazy(_DifferentialExpressionTestMulti):
22932289
memory.
22942290
"""
22952291

2296-
model_estim: _EstimatorBase
2292+
model_estim: glm.typing.EstimatorBaseTyping
22972293
_theta_mle: np.ndarray
22982294
_theta_sd: np.ndarray
22992295

23002296
def __init__(
23012297
self,
2302-
model_estim: _EstimatorBase,
2298+
model_estim: glm.typing.EstimatorBaseTyping,
23032299
grouping, groups,
23042300
correction_type="global"
23052301
):
@@ -2856,15 +2852,15 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
28562852

28572853
class _DifferentialExpressionTestCont(_DifferentialExpressionTestSingle):
28582854
_de_test: _DifferentialExpressionTestSingle
2859-
_model_estim: _EstimatorBase
2855+
_model_estim: glm.typing.EstimatorBaseTyping
28602856
_size_factors: np.ndarray
28612857
_continuous_coords: np.ndarray
28622858
_spline_coefs: list
28632859

28642860
def __init__(
28652861
self,
28662862
de_test: _DifferentialExpressionTestSingle,
2867-
model_estim: _EstimatorBase,
2863+
model_estim: glm.typing.EstimatorBaseTyping,
28682864
size_factors: np.ndarray,
28692865
continuous_coords: np.ndarray,
28702866
spline_coefs: list,

diffxpy/testing/tests.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
from anndata.base import Raw
44
except ImportError:
55
from anndata import Raw
6+
import batchglm.api as glm
67
import logging
78
import numpy as np
89
import pandas as pd
910
import patsy
1011
import scipy.sparse
1112
from typing import Union, List, Dict, Callable, Tuple
1213

13-
from batchglm import data as data_utils
14-
from batchglm.models.base import _EstimatorBase, _InputDataBase
1514
from diffxpy import pkg_constants
1615
from diffxpy.models.batch_bfgs.optim import Estim_BFGS
1716
from .det import DifferentialExpressionTestLRT, DifferentialExpressionTestWald, \
@@ -40,7 +39,7 @@ def _fit(
4039
quick_scale: bool = None,
4140
close_session=True,
4241
dtype="float64"
43-
) -> _EstimatorBase:
42+
) -> glm.typing.InputDataBaseTyping:
4443
"""
4544
:param noise_model: str, noise model to use in model-based unit_test. Possible options:
4645
@@ -187,7 +186,7 @@ def _fit(
187186

188187

189188
def lrt(
190-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
189+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
191190
full_formula_loc: str,
192191
reduced_formula_loc: str,
193192
full_formula_scale: str = "~1",
@@ -298,25 +297,25 @@ def lrt(
298297
sample_description=sample_description
299298
)
300299

301-
full_design_loc = data_utils.design_matrix(
300+
full_design_loc = glm.data.design_matrix(
302301
sample_description=sample_description,
303302
formula=full_formula_loc,
304303
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
305304
return_type="patsy"
306305
)
307-
reduced_design_loc = data_utils.design_matrix(
306+
reduced_design_loc = glm.data.design_matrix(
308307
sample_description=sample_description,
309308
formula=reduced_formula_loc,
310309
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
311310
return_type="patsy"
312311
)
313-
full_design_scale = data_utils.design_matrix(
312+
full_design_scale = glm.data.design_matrix(
314313
sample_description=sample_description,
315314
formula=full_formula_scale,
316315
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
317316
return_type="patsy"
318317
)
319-
reduced_design_scale = data_utils.design_matrix(
318+
reduced_design_scale = glm.data.design_matrix(
320319
sample_description=sample_description,
321320
formula=reduced_formula_scale,
322321
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values],
@@ -371,7 +370,7 @@ def lrt(
371370

372371

373372
def wald(
374-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
373+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
375374
factor_loc_totest: Union[str, List[str]] = None,
376375
coef_to_test: Union[str, List[str]] = None,
377376
formula_loc: Union[None, str] = None,
@@ -597,7 +596,7 @@ def wald(
597596
elif coef_to_test is not None:
598597
# Directly select coefficients to test from design matrix (xarray):
599598
# Check that coefficients to test are not dependent parameters if constraints are given:
600-
coef_loc_names = data_utils.view_coef_names(design_loc).tolist()
599+
coef_loc_names = glm.data.view_coef_names(design_loc).tolist()
601600
if not np.all([x in coef_loc_names for x in coef_to_test]):
602601
raise ValueError(
603602
"the requested test coefficients %s were found in model coefficients %s" %
@@ -645,7 +644,7 @@ def wald(
645644

646645

647646
def t_test(
648-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
647+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
649648
grouping,
650649
gene_names: Union[np.ndarray, list] = None,
651650
sample_description: pd.DataFrame = None,
@@ -687,7 +686,7 @@ def t_test(
687686

688687

689688
def rank_test(
690-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
689+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
691690
grouping: Union[str, np.ndarray, list],
692691
gene_names: Union[np.ndarray, list] = None,
693692
sample_description: pd.DataFrame = None,
@@ -729,7 +728,7 @@ def rank_test(
729728

730729

731730
def two_sample(
732-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
731+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
733732
grouping: Union[str, np.ndarray, list],
734733
as_numeric: Union[List[str], Tuple[str], str] = (),
735734
test: str = "t-test",
@@ -902,7 +901,7 @@ def two_sample(
902901

903902

904903
def pairwise(
905-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
904+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
906905
grouping: Union[str, np.ndarray, list],
907906
as_numeric: Union[List[str], Tuple[str], str] = (),
908907
test: str = 'z-test',
@@ -1026,7 +1025,7 @@ def pairwise(
10261025

10271026
if test.lower() == 'z-test' or test.lower() == 'z_test' or test.lower() == 'ztest':
10281027
# -1 in formula removes intercept
1029-
dmat = data_utils.design_matrix(
1028+
dmat = glm.data.design_matrix(
10301029
sample_description,
10311030
formula="~ 1 - 1 + grouping"
10321031
)
@@ -1113,7 +1112,7 @@ def pairwise(
11131112

11141113

11151114
def versus_rest(
1116-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1115+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
11171116
grouping: Union[str, np.ndarray, list],
11181117
as_numeric: Union[List[str], Tuple[str], str] = (),
11191118
test: str = 'wald',
@@ -1275,7 +1274,7 @@ def versus_rest(
12751274

12761275

12771276
def partition(
1278-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1277+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
12791278
parts: Union[str, np.ndarray, list],
12801279
gene_names: Union[np.ndarray, list] = None,
12811280
sample_description: pd.DataFrame = None
@@ -1318,7 +1317,7 @@ class _Partition:
13181317

13191318
def __init__(
13201319
self,
1321-
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, _InputDataBase],
1320+
data: Union[anndata.AnnData, Raw, np.ndarray, scipy.sparse.csr_matrix, glm.typing.InputDataBaseTyping],
13221321
parts: Union[str, np.ndarray, list],
13231322
gene_names: Union[np.ndarray, list] = None,
13241323
sample_description: pd.DataFrame = None
@@ -1333,7 +1332,7 @@ def __init__(
13331332
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
13341333
:param sample_description: optional pandas.DataFrame containing sample annotations
13351334
"""
1336-
if isinstance(data, _InputDataBase):
1335+
if isinstance(data, glm.typing.InputDataBaseTyping):
13371336
self.x = data.x
13381337
elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
13391338
self.x = data.X

0 commit comments

Comments
 (0)