Skip to content

Commit ae73643

Browse files
adapted diffxpy to new hessian and fim computition interface in batchglm
1 parent 357df55 commit ae73643

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

diffxpy/pkg_constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
BATCHGLM_OPTIM_NEWTON_TR = False
99
BATCHGLM_OPTIM_IRLS = False
1010
BATCHGLM_OPTIM_IRLS_GD = False
11-
BATCHGLM_OPTIM_IRLS_TR = False
11+
BATCHGLM_OPTIM_IRLS_TR = True
1212
BATCHGLM_OPTIM_IRLS_GD_TR = True
13+
1314
BATCHGLM_PROVIDE_BATCHED = True
15+
BATCHGLM_PROVIDE_FIM = True
16+
BATCHGLM_PROVIDE_HESSIAN = False

diffxpy/testing/det.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,7 +2302,8 @@ def plot_genes(
23022302
show=True,
23032303
ncols=2,
23042304
row_gap=0.3,
2305-
col_gap=0.25
2305+
col_gap=0.25,
2306+
return_axs=False
23062307
):
23072308
"""
23082309
Plot observed data and spline fits of selected genes.
@@ -2317,6 +2318,7 @@ def plot_genes(
23172318
:param ncols: Number of columns in plot grid if multiple genes are plotted.
23182319
:param row_gap: Vertical gap between panel rows relative to panel height.
23192320
:param col_gap: Horizontal gap between panel columns relative to panel width.
2321+
:param return_axs: Whether to return axis objects of plots.
23202322
:return: Matplotlib axis objects.
23212323
"""
23222324

@@ -2390,8 +2392,10 @@ def plot_genes(
23902392

23912393
plt.close(fig)
23922394

2393-
return axs
2394-
2395+
if return_axs:
2396+
return axs
2397+
else:
2398+
return
23952399

23962400
def plot_heatmap(
23972401
self,
@@ -2402,7 +2406,8 @@ def plot_heatmap(
24022406
nticks=10,
24032407
cmap: str = "YlGnBu",
24042408
width=10,
2405-
height_per_gene=0.5
2409+
height_per_gene=0.5,
2410+
return_axs=False
24062411
):
24072412
"""
24082413
Plot observed data and spline fits of selected genes.
@@ -2416,6 +2421,7 @@ def plot_heatmap(
24162421
:param cmap: matplotlib cmap.
24172422
:param width: Width of heatmap figure.
24182423
:param height_per_gene: Height of each row (gene) in heatmap figure.
2424+
:param return_axs: Whether to return axis objects of plots.
24192425
:return: Matplotlib axis objects.
24202426
"""
24212427
import seaborn as sns
@@ -2480,7 +2486,10 @@ def plot_heatmap(
24802486

24812487
plt.close(fig)
24822488

2483-
return ax
2489+
if return_axs:
2490+
return axs
2491+
else:
2492+
return
24842493

24852494

24862495
class DifferentialExpressionTestWaldCont(_DifferentialExpressionTestCont):

diffxpy/testing/tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def _fit(
170170
init_b=init_b,
171171
provide_optimizers=provide_optimizers,
172172
provide_batched=pkg_constants.BATCHGLM_PROVIDE_BATCHED,
173+
provide_fim=pkg_constants.BATCHGLM_PROVIDE_FIM,
174+
provide_hessian=pkg_constants.BATCHGLM_PROVIDE_HESSIAN,
173175
dtype=dtype,
174176
**constructor_args
175177
)

0 commit comments

Comments
 (0)