Skip to content

Commit 8d2ddf1

Browse files
Merge branch 'dev'
2 parents 7e47156 + 6eac8e6 commit 8d2ddf1

19 files changed

+5052
-4456
lines changed

diffxpy/api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from . import enrich
66
from . import stats
77
from . import utils
8+
from .. import pkg_constants

diffxpy/api/stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from diffxpy.stats.stats import wald_test
44
from diffxpy.stats.stats import wald_test_chisq
55
from diffxpy.stats.stats import two_coef_z_test
6-
from diffxpy.stats.stats import wilcoxon_test
6+
from diffxpy.stats.stats import mann_whitney_u_test
77
from diffxpy.stats.stats import t_test_moments
88
from diffxpy.stats.stats import t_test_raw
99

diffxpy/api/test.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
from diffxpy.testing.base import two_sample
2-
from diffxpy.testing.base import lrt
3-
from diffxpy.testing.base import wald
4-
from diffxpy.testing.base import t_test
5-
from diffxpy.testing.base import wilcoxon
6-
from diffxpy.testing.base import partition
7-
from diffxpy.testing.base import pairwise
8-
from diffxpy.testing.base import versus_rest
9-
from diffxpy.testing.base import continuous_1d
10-
from diffxpy.testing.base import design_matrix
11-
from diffxpy.testing.base import coef_names
1+
from diffxpy.testing import lrt, wald, t_test, rank_test, two_sample, pairwise, \
2+
versus_rest, partition, continuous_1d
3+
from diffxpy.testing import design_matrix, coef_names

diffxpy/enrichment/enrich.py

Lines changed: 76 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import logging
12
import numpy as np
23
import pandas as pd
4+
from typing import Union
35

46
from ..stats import stats
57
from ..testing import correction
6-
from ..testing.base import _DifferentialExpressionTest
8+
from ..testing.det import _DifferentialExpressionTest
79

10+
logger = logging.getLogger(__name__)
811

9-
class RefSets():
12+
class RefSets:
1013
"""
1114
Class for a list of gene sets.
1215
@@ -18,7 +21,7 @@ class RefSets():
1821
.gmt files can be downloaded from http://software.broadinstitute.org/gsea/msigdb/collections.jsp for example.
1922
"""
2023

21-
class _Set():
24+
class _Set:
2225
"""
2326
Class for a single gene set.
2427
"""
@@ -182,18 +185,19 @@ def overlap(self, enq_set: set, set_id=None):
182185
for x in self.sets:
183186
x.intersect = x.genes.intersection(enq_set)
184187
else:
185-
x.intersect = self.get_set(id).genes.intersection(enq_set)
188+
x.intersect = self.get_set(id).genes.intersection(enq_set) # bug
186189

187190

188191
def test(
189-
RefSets: RefSets,
190-
DETest: _DifferentialExpressionTest = None,
191-
pval: np.array = None,
192-
gene_ids: list = None,
192+
ref: RefSets,
193+
det: Union[_DifferentialExpressionTest, None] = None,
194+
pval: Union[np.array, None] = None,
195+
gene_ids: Union[list, None] = None,
193196
de_threshold=0.05,
197+
incl_all_zero=False,
194198
all_ids=None,
195-
clean_ref=True,
196-
upper=False
199+
clean_ref=False,
200+
capital=True
197201
):
198202
""" Perform gene set enrichment.
199203
@@ -214,51 +218,61 @@ def test(
214218
which can be matched against the identifiers in the sets in RefSets.
215219
:param de_threshold:
216220
Significance threshold at which a differential test (a multiple-testing
217-
corrected p-value) is called siginficant. This
221+
corrected p-value) is called significant.
222+
:param incl_all_zero:
223+
Whether to include genes in the gene universe which were all zero.
218224
:param all_ids:
219225
Set of all gene identifiers, this is used as the background set in the
220226
hypergeometric test. Only supply this if not all genes were tested
221227
and are supplied above in DETest or gene_ids.
222228
:param clean_ref:
223229
Whether or not to only retain gene identifiers in RefSets that occur in
224230
the background set of identifiers supplied here through all_ids.
225-
:param upper:
231+
:param capital:
226232
Make all gene IDs capital.
227233
"""
228234
return Enrich(
229-
RefSets=RefSets,
230-
DETest=DETest,
235+
ref=ref,
236+
det=det,
231237
pval=pval,
232238
gene_ids=gene_ids,
233239
de_threshold=de_threshold,
240+
incl_all_zero=incl_all_zero,
234241
all_ids=all_ids,
235242
clean_ref=clean_ref,
236-
upper=upper)
243+
capital=capital
244+
)
237245

238246

239-
class Enrich():
247+
class Enrich:
240248
"""
241249
"""
242250

243251
def __init__(
244252
self,
245-
RefSets: RefSets,
246-
DETest: _DifferentialExpressionTest = None,
247-
pval: np.array = None,
248-
gene_ids: list = None,
249-
de_threshold=0.05,
250-
all_ids=None,
251-
clean_ref=True,
252-
upper=False
253+
ref: RefSets,
254+
det: Union[_DifferentialExpressionTest, None],
255+
pval: Union[np.array, None],
256+
gene_ids: Union[list, None],
257+
de_threshold,
258+
incl_all_zero,
259+
all_ids,
260+
clean_ref,
261+
capital
253262
):
254263
self._n_overlaps = None
255264
self._pval_enrich = None
256265
self._qval_enrich = None
257266
# Load multiple-testing-corrected differential expression
258267
# p-values from differential expression output.
259-
if DETest is not None:
260-
self._qval_de = DETest.qval
261-
self._gene_ids = DETest.gene_ids
268+
if det is not None:
269+
if incl_all_zero:
270+
self._qval_de = det.qval
271+
self._gene_ids = det.gene_ids
272+
else:
273+
idx_not_all_zero = np.where(np.logical_not(det.summary()["zero_mean"].values))[0]
274+
self._qval_de = det.qval[idx_not_all_zero]
275+
self._gene_ids = det.gene_ids[idx_not_all_zero]
262276
elif pval is not None and gene_ids is not None:
263277
self._qval_de = np.asarray(pval)
264278
self._gene_ids = gene_ids
@@ -268,8 +282,11 @@ def __init__(
268282
# Select significant genes based on user defined threshold.
269283
if any([x is np.nan for x in self._gene_ids]):
270284
idx_notnan = np.where([x is not np.nan for x in self._gene_ids])[0]
271-
print('Discarded ' + str(len(self._gene_ids) - len(idx_notnan)) + ' nan gene ids, leaving ' +
272-
str(len(idx_notnan)) + ' genes.')
285+
logger.info(
286+
" Discarded %i nan gene ids, leaving %i genes.",
287+
len(self._gene_ids) - len(idx_notnan),
288+
len(idx_notnan)
289+
)
273290
self._qval_de = self._qval_de[idx_notnan]
274291
self._gene_ids = self._gene_ids[idx_notnan]
275292

@@ -280,25 +297,31 @@ def __init__(
280297
else:
281298
self._all_ids = set(self._gene_ids)
282299

283-
if upper == True:
300+
if capital:
284301
self._gene_ids = [x.upper() for x in self._gene_ids]
285302
self._all_ids = set([x.upper() for x in self._all_ids])
303+
self._significant_ids = set([x.upper() for x in self._significant_ids])
286304

287305
# Generate diagnostic statistic of number of possible overlaps in total.
288-
print(str(len(set(self._all_ids).intersection(set(RefSets._genes)))) +
289-
' overlaps found between refset (' + str(len(RefSets._genes)) +
290-
') and provided gene list (' + str(len(self._all_ids)) + ').')
291-
self.missing_genes = list(set(RefSets._genes).difference(set(self._all_ids)))
306+
logger.info(
307+
" %i overlaps found between refset (%i) and provided gene list (%i).",
308+
len(set(self._all_ids).intersection(set(ref._genes))),
309+
len(ref._genes),
310+
len(self._all_ids)
311+
)
312+
self.missing_genes = list(set(ref._genes).difference(set(self._all_ids)))
292313
# Clean reference set to only contains ids that were observed in
293314
# current study if required.
294-
self.RefSets = RefSets
295-
if clean_ref == True:
315+
self.RefSets = ref
316+
if clean_ref:
296317
self.RefSets.clean(self._all_ids)
297318
# Print if there are empty sets.
298319
idx_nonempty = np.where([len(x.genes) > 0 for x in self.RefSets.sets])[0]
299320
if len(self.RefSets.sets) - len(idx_nonempty) > 0:
300-
print('Found ' + str(len(self.RefSets.sets) - len(idx_nonempty)) +
301-
' empty sets, removing those.')
321+
logger.info(
322+
" Found %i empty sets, removing those.",
323+
len(self.RefSets.sets) - len(idx_nonempty)
324+
)
302325
self.RefSets = self.RefSets.subset(idx=idx_nonempty)
303326
elif len(idx_nonempty) == 0:
304327
raise ValueError('all RefSets were empty')
@@ -356,7 +379,7 @@ def grepv_sets(self, x):
356379
"""
357380
return self.RefSets.grepv_sets(x)
358381

359-
def set(id):
382+
def set(self, id):
360383
"""
361384
Return the set with a given set identifier.
362385
"""
@@ -374,9 +397,11 @@ def significant_set_ids(self, threshold=0.05) -> np.array:
374397
"""
375398
return [self.RefSets._ids[i] for i in np.where(self.qval <= threshold)[0]]
376399

377-
def summary(self) -> pd.DataFrame:
400+
def summary(self, sort=True) -> pd.DataFrame:
378401
"""
379402
Summarize gene set enrichment analysis as an output table.
403+
404+
:param sort: Whether to sort table by p-value.
380405
"""
381406
res = pd.DataFrame({
382407
"set": self.RefSets._ids,
@@ -388,5 +413,15 @@ def summary(self) -> pd.DataFrame:
388413
"background": len(self._all_ids)
389414
})
390415
# Sort by p-value
391-
res = res.iloc[np.argsort(res['pval'].values), :]
416+
if sort:
417+
res = res.iloc[np.argsort(res['pval'].values), :]
392418
return res
419+
420+
def set_summary(self, id: str):
421+
"""
422+
Summarize gene set enrichment analysis for a given set.
423+
:param id: Gene set to enquire.
424+
425+
:return: Slice of summary table.
426+
"""
427+
return self.summary(sort=False).iloc[self.RefSets._ids.index(id), :]

diffxpy/pkg_constants.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
BATCHGLM_OPTIM_GD = True
2-
BATCHGLM_OPTIM_ADAM = True
1+
DE_TREAT_ZEROVAR_TT_AS_SIG = True
2+
3+
BATCHGLM_OPTIM_GD = False
4+
BATCHGLM_OPTIM_ADAM = False
35
BATCHGLM_OPTIM_ADAGRAD = False
46
BATCHGLM_OPTIM_RMSPROP = False
5-
BATCHGLM_OPTIM_NEWTON = True
6-
BATCHGLM_OPTIM_IRLS = True
7-
BATCHGLM_TERMINATION_TYPE = "by_feature"
7+
BATCHGLM_OPTIM_NEWTON = False
8+
BATCHGLM_OPTIM_NEWTON_TR = False
9+
BATCHGLM_OPTIM_IRLS = False
10+
BATCHGLM_OPTIM_IRLS_GD = False
11+
BATCHGLM_OPTIM_IRLS_TR = True
12+
BATCHGLM_OPTIM_IRLS_GD_TR = True
13+
14+
BATCHGLM_PROVIDE_BATCHED = True
15+
BATCHGLM_PROVIDE_FIM = True
16+
BATCHGLM_PROVIDE_HESSIAN = False

diffxpy/stats/stats.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from typing import Union
2+
13
import numpy as np
24
import numpy.linalg
35
import scipy.stats
4-
from typing import Union
6+
import xarray as xr
57

68

79
def likelihood_ratio_test(
@@ -37,7 +39,7 @@ def likelihood_ratio_test(
3739
return pvals
3840

3941

40-
def wilcoxon_test(
42+
def mann_whitney_u_test(
4143
x0: np.ndarray,
4244
x1: np.ndarray,
4345
):
@@ -68,7 +70,8 @@ def wilcoxon_test(
6870
scipy.stats.mannwhitneyu(
6971
x=x0[:, i].flatten(),
7072
y=x1[:, i].flatten(),
71-
alternative='two-sided'
73+
use_continuity=True,
74+
alternative="two-sided"
7275
).pvalue for i in range(x0.shape[1])
7376
])
7477
return pvals
@@ -152,7 +155,7 @@ def t_test_moments(
152155
out=s_delta
153156
)
154157

155-
t_statistic = np.abs((mu0 - mu1) / s_delta)
158+
t_statistic = np.abs(mu0 - mu1) / s_delta
156159

157160
divisor = (
158161
(np.square(var0 / n0) / (n0 - 1)) +
@@ -165,16 +168,15 @@ def t_test_moments(
165168
out=divisor
166169
)
167170

168-
with np.errstate(over='ignore'):
169-
df = np.square((var0 / n0) + (var1 / n1)) / divisor
171+
df = np.square((var0 / n0) + (var1 / n1)) / divisor
170172
np.clip(
171173
df,
172174
a_min=np.nextafter(0, np.inf, dtype=df.dtype),
173175
a_max=np.nextafter(np.inf, 0, dtype=df.dtype),
174176
out=df
175177
)
176178

177-
pval = 2 * (1 - scipy.stats.t(df).cdf(t_statistic))
179+
pval = 2 * scipy.stats.t.sf(t_statistic, df)
178180
return pval
179181

180182

@@ -261,6 +263,9 @@ def wald_test_chisq(
261263
raise ValueError('stats.wald_test(): theta_mle and theta0 have to contain the same number of entries')
262264

263265
theta_diff = theta_mle - theta0
266+
# Convert to nd.array to avoid gufunc error.
267+
if isinstance(theta_diff, xr.DataArray):
268+
theta_diff = theta_diff.values
264269
wald_statistic = np.array([
265270
np.matmul(
266271
np.matmul(
@@ -337,6 +342,9 @@ def hypergeom_test(
337342
:param background: int
338343
Size of background set.
339344
"""
340-
pvals = np.array([1 - scipy.stats.hypergeom(M=background, n=references[i], N=enquiry).cdf(x - 1) for i, x in
341-
enumerate(intersections)])
342-
return (pvals)
345+
pvals = np.array([1 - scipy.stats.hypergeom(
346+
M=background,
347+
n=references[i],
348+
N=enquiry
349+
).cdf(x - 1) for i, x in enumerate(intersections)])
350+
return pvals

diffxpy/testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .tests import lrt, wald, t_test, rank_test, two_sample, pairwise, \
2+
versus_rest, partition, continuous_1d
3+
from .utils import design_matrix, coef_names

0 commit comments

Comments
 (0)