3
3
from anndata .base import Raw
4
4
except ImportError :
5
5
from anndata import Raw
6
+ import batchglm .api as glm
6
7
import logging
7
8
import numpy as np
8
9
import pandas as pd
9
10
import patsy
10
11
import scipy .sparse
11
12
from typing import Union , List , Dict , Callable , Tuple
12
13
13
- from batchglm import data as data_utils
14
- from batchglm .models .base import _EstimatorBase , _InputDataBase
15
14
from diffxpy import pkg_constants
16
15
from diffxpy .models .batch_bfgs .optim import Estim_BFGS
17
16
from .det import DifferentialExpressionTestLRT , DifferentialExpressionTestWald , \
@@ -40,7 +39,7 @@ def _fit(
40
39
quick_scale : bool = None ,
41
40
close_session = True ,
42
41
dtype = "float64"
43
- ) -> _EstimatorBase :
42
+ ) -> glm . typing . InputDataBaseTyping :
44
43
"""
45
44
:param noise_model: str, noise model to use in model-based unit_test. Possible options:
46
45
@@ -187,7 +186,7 @@ def _fit(
187
186
188
187
189
188
def lrt (
190
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
189
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
191
190
full_formula_loc : str ,
192
191
reduced_formula_loc : str ,
193
192
full_formula_scale : str = "~1" ,
@@ -298,25 +297,25 @@ def lrt(
298
297
sample_description = sample_description
299
298
)
300
299
301
- full_design_loc = data_utils .design_matrix (
300
+ full_design_loc = glm . data .design_matrix (
302
301
sample_description = sample_description ,
303
302
formula = full_formula_loc ,
304
303
as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
305
304
return_type = "patsy"
306
305
)
307
- reduced_design_loc = data_utils .design_matrix (
306
+ reduced_design_loc = glm . data .design_matrix (
308
307
sample_description = sample_description ,
309
308
formula = reduced_formula_loc ,
310
309
as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
311
310
return_type = "patsy"
312
311
)
313
- full_design_scale = data_utils .design_matrix (
312
+ full_design_scale = glm . data .design_matrix (
314
313
sample_description = sample_description ,
315
314
formula = full_formula_scale ,
316
315
as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
317
316
return_type = "patsy"
318
317
)
319
- reduced_design_scale = data_utils .design_matrix (
318
+ reduced_design_scale = glm . data .design_matrix (
320
319
sample_description = sample_description ,
321
320
formula = reduced_formula_scale ,
322
321
as_categorical = [False if x in as_numeric else True for x in sample_description .columns .values ],
@@ -371,7 +370,7 @@ def lrt(
371
370
372
371
373
372
def wald (
374
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
373
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
375
374
factor_loc_totest : Union [str , List [str ]] = None ,
376
375
coef_to_test : Union [str , List [str ]] = None ,
377
376
formula_loc : Union [None , str ] = None ,
@@ -597,7 +596,7 @@ def wald(
597
596
elif coef_to_test is not None :
598
597
# Directly select coefficients to test from design matrix (xarray):
599
598
# Check that coefficients to test are not dependent parameters if constraints are given:
600
- coef_loc_names = data_utils .view_coef_names (design_loc ).tolist ()
599
+ coef_loc_names = glm . data .view_coef_names (design_loc ).tolist ()
601
600
if not np .all ([x in coef_loc_names for x in coef_to_test ]):
602
601
raise ValueError (
603
602
"the requested test coefficients %s were found in model coefficients %s" %
@@ -645,7 +644,7 @@ def wald(
645
644
646
645
647
646
def t_test (
648
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
647
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
649
648
grouping ,
650
649
gene_names : Union [np .ndarray , list ] = None ,
651
650
sample_description : pd .DataFrame = None ,
@@ -687,7 +686,7 @@ def t_test(
687
686
688
687
689
688
def rank_test (
690
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
689
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
691
690
grouping : Union [str , np .ndarray , list ],
692
691
gene_names : Union [np .ndarray , list ] = None ,
693
692
sample_description : pd .DataFrame = None ,
@@ -729,7 +728,7 @@ def rank_test(
729
728
730
729
731
730
def two_sample (
732
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
731
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
733
732
grouping : Union [str , np .ndarray , list ],
734
733
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
735
734
test : str = "t-test" ,
@@ -902,7 +901,7 @@ def two_sample(
902
901
903
902
904
903
def pairwise (
905
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
904
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
906
905
grouping : Union [str , np .ndarray , list ],
907
906
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
908
907
test : str = 'z-test' ,
@@ -1026,7 +1025,7 @@ def pairwise(
1026
1025
1027
1026
if test .lower () == 'z-test' or test .lower () == 'z_test' or test .lower () == 'ztest' :
1028
1027
# -1 in formula removes intercept
1029
- dmat = data_utils .design_matrix (
1028
+ dmat = glm . data .design_matrix (
1030
1029
sample_description ,
1031
1030
formula = "~ 1 - 1 + grouping"
1032
1031
)
@@ -1113,7 +1112,7 @@ def pairwise(
1113
1112
1114
1113
1115
1114
def versus_rest (
1116
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1115
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
1117
1116
grouping : Union [str , np .ndarray , list ],
1118
1117
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
1119
1118
test : str = 'wald' ,
@@ -1275,7 +1274,7 @@ def versus_rest(
1275
1274
1276
1275
1277
1276
def partition (
1278
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1277
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
1279
1278
parts : Union [str , np .ndarray , list ],
1280
1279
gene_names : Union [np .ndarray , list ] = None ,
1281
1280
sample_description : pd .DataFrame = None
@@ -1318,7 +1317,7 @@ class _Partition:
1318
1317
1319
1318
def __init__ (
1320
1319
self ,
1321
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , _InputDataBase ],
1320
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm . typing . InputDataBaseTyping ],
1322
1321
parts : Union [str , np .ndarray , list ],
1323
1322
gene_names : Union [np .ndarray , list ] = None ,
1324
1323
sample_description : pd .DataFrame = None
@@ -1333,7 +1332,7 @@ def __init__(
1333
1332
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
1334
1333
:param sample_description: optional pandas.DataFrame containing sample annotations
1335
1334
"""
1336
- if isinstance (data , _InputDataBase ):
1335
+ if isinstance (data , glm . typing . InputDataBaseTyping ):
1337
1336
self .x = data .x
1338
1337
elif isinstance (data , anndata .AnnData ) or isinstance (data , Raw ):
1339
1338
self .x = data .X
0 commit comments