Skip to content

Commit 8f1a0a0

Browse files
got coefficient comparison to OLS model to work
1 parent 779e5db commit 8f1a0a0

File tree

2 files changed

+188
-28
lines changed

2 files changed

+188
-28
lines changed

diffxpy/testing/det.py

Lines changed: 162 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ def plot_volcano(
292292
highlight_col: str = "red",
293293
show: bool = True,
294294
save: Union[str, None] = None,
295-
suffix: str = "_volcano.png"
295+
suffix: str = "_volcano.png",
296+
return_axs: bool = False
296297
):
297298
"""
298299
Returns a volcano plot of p-value vs. log fold change
@@ -314,6 +315,7 @@ def plot_volcano(
314315
:param save: Path+file name stem to save plots to.
315316
File will be save+suffix. Does not save if save is None.
316317
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
318+
:param return_axs: Whether to return axis objects.
317319
318320
:return: Matplotlib axis object if return_axs is True, otherwise None.
319321
"""
@@ -379,7 +381,10 @@ def plot_volcano(
379381

380382
plt.close(fig)
381383

382-
return ax
384+
if return_axs:
385+
return ax
386+
else:
387+
return
383388

384389
def plot_ma(
385390
self,
@@ -393,7 +398,8 @@ def plot_ma(
393398
highlight_col: str = "red",
394399
show: bool = True,
395400
save: Union[str, None] = None,
396-
suffix: str = "_my_plot.png"
401+
suffix: str = "_ma_plot.png",
402+
return_axs: bool = False
397403
):
398404
"""
399405
Returns an MA plot of mean expression vs. log fold change with significance
@@ -416,7 +422,7 @@ def plot_ma(
416422
:param save: Path+file name stem to save plots to.
417423
File will be save+suffix. Does not save if save is None.
418424
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
419-
425+
:param return_axs: Whether to return axis objects.
420426
421427
:return: Matplotlib axis object if return_axs is True, otherwise None.
422428
"""
@@ -487,7 +493,10 @@ def plot_ma(
487493
plt.close(fig)
488494
plt.ion()
489495

490-
return ax
496+
if return_axs:
497+
return ax
498+
else:
499+
return
491500

492501

493502
class _DifferentialExpressionTestSingle(_DifferentialExpressionTest, metaclass=abc.ABCMeta):
@@ -756,7 +765,8 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
756765
def __init__(
757766
self,
758767
model_estim: _Estimation,
759-
col_indices: np.ndarray
768+
col_indices: np.ndarray,
769+
noise_model: str,
760770
):
761771
"""
762772
:param model_estim:
@@ -766,6 +776,7 @@ def __init__(
766776

767777
self.model_estim = model_estim
768778
self.coef_loc_totest = col_indices
779+
self.noise_model = noise_model
769780

770781
try:
771782
if model_estim._error_codes is not None:
@@ -889,7 +900,18 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
889900

890901
return res
891902

892-
def plot_vs_ttest(self, log10=False):
903+
def plot_vs_ttest(
904+
self,
905+
log10=False,
906+
return_axs: bool = False
907+
):
908+
"""
909+
910+
:param log10:
911+
:param return_axs: Whether to return axis objects.
912+
913+
:return: Matplotlib axis object if return_axs is True, otherwise None.
914+
"""
893915
import matplotlib.pyplot as plt
894916
import seaborn as sns
895917
from .tests import t_test
@@ -913,7 +935,139 @@ def plot_vs_ttest(self, log10=False):
913935

914936
ax.set(xlabel="t-test", ylabel='wald test')
915937

916-
return fig, ax
938+
if return_axs:
939+
return ax
940+
else:
941+
return
942+
943+
def plot_comparison_ols(
944+
self,
945+
size=20,
946+
show: bool = True,
947+
save: Union[str, None] = None,
948+
suffix: str = "_ols_comparison.png",
949+
ncols=5,
950+
row_gap=0.3,
951+
col_gap=0.25,
952+
return_axs: bool = False
953+
):
954+
"""
955+
Plot location model coefficients of inferred model against those obtained from an OLS model.
956+
957+
Red line shown is the identity line.
958+
959+
:param size: Size of points.
960+
:param show: Whether to display the plot (independent of whether it is saved via save).
961+
:param save: Path+file name stem to save plots to.
962+
File will be save+suffix. Does not save if save is None.
963+
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
964+
:param ncols: Number of columns in plot grid (one panel is drawn per location model coefficient).
965+
:param row_gap: Vertical gap between panel rows relative to panel height.
966+
:param col_gap: Horizontal gap between panel columns relative to panel width.
967+
:param return_axs: Whether to return axis objects.
968+
969+
:return: List of matplotlib axis objects (one per coefficient panel) if return_axs is True, otherwise None.
970+
"""
971+
import seaborn as sns
972+
import matplotlib.pyplot as plt
973+
from matplotlib import gridspec
974+
from matplotlib import rcParams
975+
from batchglm.api.models.glm_norm import Estimator, InputData
976+
977+
# Run OLS model fit to have comparison coefficients.
978+
input_data_ols = InputData.new(
979+
data=self.model_estim.input_data.data,
980+
design_loc=self.model_estim.input_data.design_loc,
981+
design_scale=self.model_estim.input_data.design_scale[:, [0]],
982+
constraints_loc=self.model_estim.input_data.constraints_loc,
983+
constraints_scale=self.model_estim.input_data.constraints_scale[[0], [0]],
984+
size_factors=self.model_estim.input_data.size_factors,
985+
feature_names=self.model_estim.input_data.features,
986+
)
987+
estim_ols = Estimator(
988+
input_data=input_data_ols,
989+
init_model=None,
990+
init_a="standard",
991+
init_b="standard",
992+
dtype=self.model_estim.a_var.dtype
993+
)
994+
estim_ols.initialize()
995+
store_ols = estim_ols.finalize()
996+
997+
# Prepare parameter summary of both model fits.
998+
par_loc = input_data_ols.data.coords["design_loc_params"].values
999+
if self.noise_model == "nb":
1000+
# Translate coefficients from OLS fit to be multiplicative in identity space.
1001+
a_var_ols = store_ols.a_var.values
1002+
a_var_ols[1:, :] = (a_var_ols[1:, :] + a_var_ols[[0], :]) / a_var_ols[[0], :]
1003+
elif self.noise_model == "norm":
1004+
a_var_ols = store_ols.a_var
1005+
else:
1006+
raise ValueError("noise model %s not yet supported for plot_comparison_ols" % self.noise_model)
1007+
1008+
summaries_fits = [
1009+
pd.DataFrame({
1010+
"user": self.model_estim.inverse_link_loc(self.model_estim.a_var[i, :]),
1011+
"ols": a_var_ols[i, :],
1012+
"coef": par_loc[i]
1013+
}) for i in range(self.model_estim.a_var.shape[0])
1014+
]
1015+
1016+
plt.ioff()
1017+
nrows = len(par_loc) // ncols + int((len(par_loc) % ncols) > 0)
1018+
1019+
gs = gridspec.GridSpec(
1020+
nrows=nrows,
1021+
ncols=ncols,
1022+
hspace=row_gap,
1023+
wspace=col_gap
1024+
)
1025+
fig = plt.figure(
1026+
figsize=(
1027+
ncols * rcParams['figure.figsize'][0], # width in inches
1028+
nrows * rcParams['figure.figsize'][1] * (1 + row_gap) # height in inches
1029+
)
1030+
)
1031+
1032+
axs = []
1033+
for i, par_i in enumerate(par_loc):
1034+
ax = plt.subplot(gs[i])
1035+
axs.append(ax)
1036+
1037+
x = summaries_fits[i]["user"].values
1038+
y = summaries_fits[i]["ols"].values
1039+
1040+
sns.scatterplot(
1041+
x=x,
1042+
y=y,
1043+
ax=ax,
1044+
s=size
1045+
)
1046+
sns.lineplot(
1047+
x=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
1048+
y=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
1049+
ax=ax,
1050+
color="red",
1051+
legend=False
1052+
)
1053+
ax.set(xlabel="user supplied model", ylabel="OLS model")
1054+
title_i = par_loc[i] + " (R=" + str(np.round(np.corrcoef(x, y)[0, 1], 3)) + ")"
1055+
ax.set_title(title_i)
1056+
1057+
# Save, show and return figure.
1058+
if save is not None:
1059+
plt.savefig(save + suffix)
1060+
1061+
if show:
1062+
plt.show()
1063+
1064+
plt.close(fig)
1065+
plt.ion()
1066+
1067+
if return_axs:
1068+
return axs
1069+
else:
1070+
return
9171071

9181072

9191073
class DifferentialExpressionTestTT(_DifferentialExpressionTestSingle):

diffxpy/testing/tests.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def lrt(
294294
"""
295295
# TODO test nestedness
296296
if len(kwargs) != 0:
297-
logger.info("additional kwargs: %s", str(kwargs))
297+
logging.getLogger("diffxpy").info("additional kwargs: %s", str(kwargs))
298298

299299
if isinstance(as_numeric, str):
300300
as_numeric = [as_numeric]
@@ -486,7 +486,7 @@ def wald(
486486
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
487487
"""
488488
if len(kwargs) != 0:
489-
logger.debug("additional kwargs: %s", str(kwargs))
489+
logging.getLogger("diffxpy").debug("additional kwargs: %s", str(kwargs))
490490

491491
if dmat_loc is None and formula_loc is None:
492492
raise ValueError("Supply either dmat_loc or formula_loc or formula.")
@@ -519,9 +519,11 @@ def wald(
519519
if noise_model.lower() not in ["normal", "norm"]:
520520
if init_a == "closed_form":
521521
init_a = "standard"
522-
logger.warning("Setting init_a to standard as numeric predictors were supplied.")
523-
logger.warning("Closed-form initialisation is not possible" +
524-
" for noise model %s with numeric predictors." % noise_model)
522+
logging.getLogger("diffxpy").warning(
523+
"Setting init_a to standard as numeric predictors were supplied.")
524+
logging.getLogger("diffxpy").warning(
525+
"Closed-form initialisation is not possible" +
526+
" for noise model %s with numeric predictors." % noise_model)
525527
elif init_a == "AUTO":
526528
init_a = "standard"
527529
else:
@@ -538,9 +540,11 @@ def wald(
538540
if np.any([True if x in as_numeric else False for x in sample_description.columns.values]):
539541
if init_b == "closed_form":
540542
init_b = "standard"
541-
logger.warning("Setting init_b to standard as numeric predictors were supplied.")
542-
logger.warning("Closed-form initialisation is not possible" +
543-
" for noise model %s with numeric predictors." % noise_model)
543+
logging.getLogger("diffxpy").warning(
544+
"Setting init_b to standard as numeric predictors were supplied.")
545+
logging.getLogger("diffxpy").warning(
546+
"Closed-form initialisation is not possible" +
547+
" for noise model %s with numeric predictors." % noise_model)
544548
elif init_b == "AUTO":
545549
init_b = "standard"
546550
else:
@@ -607,7 +611,8 @@ def wald(
607611

608612
de_test = DifferentialExpressionTestWald(
609613
model_estim=model,
610-
col_indices=col_indices
614+
col_indices=col_indices,
615+
noise_model=noise_model
611616
)
612617

613618
return de_test
@@ -968,7 +973,7 @@ def pairwise(
968973
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
969974
"""
970975
if len(kwargs) != 0:
971-
logger.info("additional kwargs: %s", str(kwargs))
976+
logging.getLogger("diffxpy").info("additional kwargs: %s", str(kwargs))
972977

973978
if lazy and not (test.lower() == 'z-test' or test.lower() == 'z_test' or test.lower() == 'ztest'):
974979
raise ValueError("lazy evaluation of pairwise tests only possible if test is z-test")
@@ -1172,7 +1177,7 @@ def versus_rest(
11721177
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
11731178
"""
11741179
if len(kwargs) != 0:
1175-
logger.info("additional kwargs: %s", str(kwargs))
1180+
logging.getLogger("diffxpy").info("additional kwargs: %s", str(kwargs))
11761181

11771182
# Do not store all models but only p-value and q-value matrix:
11781183
# genes x groups
@@ -1811,10 +1816,10 @@ def continuous_1d(
18111816
else:
18121817
factor_loc_totest_new = factor_loc_totest
18131818

1814-
logger.debug("model formulas assembled in de.test.continuos():")
1815-
logger.debug("factor_loc_totest_new: " + ",".join(factor_loc_totest_new))
1816-
logger.debug("formula_loc_new: " + formula_loc_new)
1817-
logger.debug("formula_scale_new: " + formula_scale_new)
1819+
logging.getLogger("diffxpy").debug("model formulas assembled in de.test.continuos():")
1820+
logging.getLogger("diffxpy").debug("factor_loc_totest_new: " + ",".join(factor_loc_totest_new))
1821+
logging.getLogger("diffxpy").debug("formula_loc_new: " + formula_loc_new)
1822+
logging.getLogger("diffxpy").debug("formula_scale_new: " + formula_scale_new)
18181823

18191824
de_test = wald(
18201825
data=X,
@@ -1837,6 +1842,7 @@ def continuous_1d(
18371842
)
18381843
de_test = DifferentialExpressionTestWaldCont(
18391844
de_test=de_test,
1845+
noise_model=noise_model,
18401846
size_factors=size_factors,
18411847
continuous_coords=sample_description[continuous].values,
18421848
spline_coefs=new_coefs
@@ -1860,11 +1866,11 @@ def continuous_1d(
18601866
full_formula_scale = formula_scale_new
18611867
reduced_formula_scale = formula_scale_new
18621868

1863-
logger.debug("model formulas assembled in de.test.continuous():")
1864-
logger.debug("full_formula_loc: " + full_formula_loc)
1865-
logger.debug("reduced_formula_loc: " + reduced_formula_loc)
1866-
logger.debug("full_formula_scale: " + full_formula_scale)
1867-
logger.debug("reduced_formula_scale: " + reduced_formula_scale)
1869+
logging.getLogger("diffxpy").debug("model formulas assembled in de.test.continuous():")
1870+
logging.getLogger("diffxpy").debug("full_formula_loc: " + full_formula_loc)
1871+
logging.getLogger("diffxpy").debug("reduced_formula_loc: " + reduced_formula_loc)
1872+
logging.getLogger("diffxpy").debug("full_formula_scale: " + full_formula_scale)
1873+
logging.getLogger("diffxpy").debug("reduced_formula_scale: " + reduced_formula_scale)
18681874

18691875
de_test = lrt(
18701876
data=X,

0 commit comments

Comments
 (0)