Skip to content

Commit 3fd95a4

Browse files
fixed bug in initialisation with numeric inputs
1 parent ae73643 commit 3fd95a4

File tree

2 files changed

+83
-21
lines changed

2 files changed

+83
-21
lines changed

diffxpy/testing/tests.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def lrt(
223223
use wald() for constraints.
224224
225225
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
226-
Input data
226+
Input data matrix (observations x features) or (cells x genes).
227227
:param reduced_formula_loc: formula
228228
Reduced model formula for location and scale parameter models.
229229
If not specified, `reduced_formula` will be used instead.
@@ -403,7 +403,7 @@ def wald(
403403
Perform Wald test for differential expression for each gene.
404404
405405
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
406-
Input data
406+
Input data matrix (observations x features) or (cells x genes).
407407
:param factor_loc_totest: str, list of strings
408408
List of factors of formula to test with Wald test.
409409
E.g. "condition" or ["batch", "condition"] if formula_loc would be "~ 1 + batch + condition"
@@ -530,6 +530,17 @@ def wald(
530530
formula=formula_loc,
531531
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values]
532532
)
533+
# Check that closed-form is not used if numeric predictors are used and model is not "norm".
534+
if isinstance(init_a, str):
535+
if np.any([True if x in as_numeric else False for x in sample_description.columns.values]):
536+
if noise_model.lower() not in ["normal", "norm"]:
537+
if init_a == "closed_form":
538+
init_a = "standard"
539+
logger.warning("Setting init_a to standard as numeric predictors were supplied.")
540+
logger.warning("Closed-form initialisation is not possible" +
541+
" for noise model %s with numeric predictors." % noise_model)
542+
elif init_a == "AUTO":
543+
init_a = "standard"
533544
else:
534545
design_loc = dmat_loc
535546

@@ -539,6 +550,16 @@ def wald(
539550
formula=formula_scale,
540551
as_categorical=[False if x in as_numeric else True for x in sample_description.columns.values]
541552
)
553+
# Check that closed-form is not used for the scale model if numeric predictors are used (applies to every noise model).
554+
if isinstance(init_b, str):
555+
if np.any([True if x in as_numeric else False for x in sample_description.columns.values]):
556+
if init_b == "closed_form":
557+
init_b = "standard"
558+
logger.warning("Setting init_b to standard as numeric predictors were supplied.")
559+
logger.warning("Closed-form initialisation is not possible" +
560+
" for noise model %s with numeric predictors." % noise_model)
561+
elif init_b == "AUTO":
562+
init_b = "standard"
542563
else:
543564
design_scale = dmat_scale
544565

@@ -622,7 +643,7 @@ def t_test(
622643
between two groups on adata object for each gene.
623644
624645
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
625-
Input data
646+
Input data matrix (observations x features) or (cells x genes).
626647
:param grouping: str, array
627648
628649
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -662,7 +683,7 @@ def rank_test(
662683
between two groups on adata object for each gene.
663684
664685
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
665-
Input data
686+
Input data matrix (observations x features) or (cells x genes).
666687
:param grouping: str, array
667688
668689
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -736,7 +757,7 @@ def two_sample(
736757
Wilcoxon rank sum (Mann-Whitney U) test between both observation groups.
737758
738759
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
739-
Input data
760+
Input data matrix (observations x features) or (cells x genes).
740761
:param grouping: str, array
741762
742763
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -925,7 +946,7 @@ def pairwise(
925946
Wilcoxon rank sum (Mann-Whitney U) test between both observation groups.
926947
927948
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
928-
Input data
949+
Input data matrix (observations x features) or (cells x genes).
929950
:param grouping: str, array
930951
931952
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -1143,7 +1164,7 @@ def versus_rest(
11431164
Wilcoxon rank sum (Mann-Whitney U) test between both observation groups.
11441165
11451166
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
1146-
Input data
1167+
Input data matrix (observations x features) or (cells x genes).
11471168
:param grouping: str, array
11481169
11491170
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -1271,7 +1292,7 @@ def partition(
12711292
Wraps _Partition so that doc strings are nice.
12721293
12731294
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
1274-
Input data
1295+
Input data matrix (observations x features) or (cells x genes).
12751296
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
12761297
:param sample_description: optional pandas.DataFrame containing sample annotations
12771298
"""
@@ -1300,7 +1321,7 @@ def __init__(
13001321
sample_description: pd.DataFrame = None):
13011322
"""
13021323
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
1303-
Input data
1324+
Input data matrix (observations x features) or (cells x genes).
13041325
:param partition: str, array
13051326
13061327
- column in data.obs/sample_description which contains the split of observations into the two groups.
@@ -1615,9 +1636,8 @@ def continuous_1d(
16151636
continuous: str,
16161637
df: int = 5,
16171638
factor_loc_totest: Union[str, List[str]] = None,
1618-
formula: str = None,
16191639
formula_loc: str = None,
1620-
formula_scale: str = None,
1640+
formula_scale: str = "~1",
16211641
as_numeric: Union[List[str], Tuple[str], str] = (),
16221642
test: str = 'wald',
16231643
init_a: Union[np.ndarray, str] = "standard",
@@ -1647,7 +1667,7 @@ def continuous_1d(
16471667
dmat directly to one of the test routines wald() or lrt().
16481668
16491669
:param data: Array-like, xr.DataArray, xr.Dataset or anndata.Anndata object containing observations.
1650-
Input data
1670+
Input data matrix (observations x features) or (cells x genes).
16511671
:param continuous: str
16521672
16531673
- column in data.obs/sample_description which contains the continuous covariate.
@@ -1690,15 +1710,13 @@ def continuous_1d(
16901710
16911711
- str:
16921712
* "auto": automatically choose best initialization
1693-
* "random": initialize with random values
16941713
* "standard": initialize intercept with observed mean
16951714
- np.ndarray: direct initialization of 'a'
16961715
:param init_b: (Optional) Low-level initial values for b
16971716
Can be:
16981717
16991718
- str:
17001719
* "auto": automatically choose best initialization
1701-
* "random": initialize with random values
17021720
* "standard": initialize with zeros
17031721
- np.ndarray: direct initialization of 'b'
17041722
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
@@ -1735,13 +1753,8 @@ def continuous_1d(
17351753
Should be "float32" for single precision or "float64" for double precision.
17361754
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
17371755
"""
1738-
if formula is None and (formula_loc is None or formula_scale is None):
1739-
raise ValueError("supply either formula or fomula_loc and formula_scale")
1740-
if formula is not None and (formula_loc is not None or formula_scale is not None):
1741-
raise ValueError("supply either formula or fomula_loc and formula_scale")
1742-
# Check that continuous factor is contained in model formulas:
1743-
if formula is not None:
1744-
pass
1756+
if formula_loc is None:
1757+
raise ValueError("supply fomula_loc")
17451758
# Set testing default to continuous covariate if not supplied:
17461759
if factor_loc_totest is None:
17471760
factor_loc_totest = [continuous]

diffxpy/unit_test/test_numeric.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
3+
import numpy as np
4+
import pandas as pd
5+
import scipy.stats as stats
6+
import logging
7+
8+
from batchglm.api.models.glm_nb import Simulator
9+
import diffxpy.api as de
10+
11+
12+
class TestNumeric(unittest.TestCase):
13+
14+
def test(self):
15+
"""
16+
Check that factors that are numeric receive the correct number of coefficients.
17+
18+
:return:
19+
"""
20+
logging.getLogger("tensorflow").setLevel(logging.ERROR)
21+
logging.getLogger("batchglm").setLevel(logging.WARNING)
22+
logging.getLogger("diffxpy").setLevel(logging.WARNING)
23+
24+
sim = Simulator(num_observations=2000, num_features=2)
25+
sim.generate_sample_description(num_batches=0, num_conditions=2)
26+
sim.generate_params()
27+
sim.generate_data()
28+
29+
sample_description = sim.sample_description
30+
sample_description["numeric1"] = np.random.random(size=sim.num_observations)
31+
sample_description["numeric2"] = np.random.random(size=sim.num_observations)
32+
33+
test = de.test.wald(
34+
data=sim.X,
35+
sample_description=sample_description,
36+
formula_loc="~ 1 + condition + numeric1 + numeric2",
37+
formula_scale="~ 1",
38+
factor_loc_totest="condition",
39+
as_numeric=["numeric1", "numeric2"],
40+
training_strategy="DEFAULT"
41+
)
42+
# Check that number of coefficients is correct.
43+
assert test.model_estim.a_var.shape[0] == 4
44+
45+
return True
46+
47+
48+
if __name__ == '__main__':
49+
unittest.main()

0 commit comments

Comments
 (0)