Skip to content

Commit 02218ca

Browse files
fixed plotting calls
1 parent dba866a commit 02218ca

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

diffxpy/testing/det.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def plot_volcano(
305305
plt.show()
306306

307307
plt.close(fig)
308+
plt.ion()
308309

309310
if return_axs:
310311
return ax
@@ -853,12 +854,19 @@ def summary(
853854
def plot_vs_ttest(
854855
self,
855856
log10=False,
857+
show: bool = True,
858+
save: Union[str, None] = None,
859+
suffix: str = "_plot_vs_ttest.png",
856860
return_axs: bool = False
857861
):
858862
"""
859863
Normalizes data by size factors if any were used in model.
860864
861865
:param log10:
866+
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
867+
:param save: Path+file name stem to save plots to.
868+
File will be save+suffix. Does not save if save is None.
869+
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
862870
:param return_axs: Whether to return axis objects.
863871
864872
:return:
@@ -867,12 +875,17 @@ def plot_vs_ttest(
867875
import seaborn as sns
868876
from .tests import t_test
869877

878+
plt.ioff()
879+
870880
grouping = np.asarray(self.model_estim.input_data.design_loc[:, self.coef_loc_totest])
871881
# Normalize by size factors that were used in regression.
872-
sf = np.broadcast_to(np.expand_dims(self.model_estim.input_data.size_factors, axis=1),
873-
shape=self.model_estim.x.shape)
882+
if self.model_estim.input_data.size_factors is not None:
883+
sf = np.broadcast_to(np.expand_dims(self.model_estim.input_data.size_factors, axis=1),
884+
shape=self.model_estim.x.shape)
885+
else:
886+
sf = np.ones(shape=(self.model_estim.x.shape[0], 1))
874887
ttest = t_test(
875-
data=self.model_estim.X.multiply(1 / sf, copy=True),
888+
data=self.model_estim.x / sf,
876889
grouping=grouping,
877890
gene_names=self.gene_ids,
878891
)
@@ -889,6 +902,16 @@ def plot_vs_ttest(
889902

890903
ax.set(xlabel="t-test", ylabel='wald test')
891904

905+
# Save, show and return figure.
906+
if save is not None:
907+
plt.savefig(save + suffix)
908+
909+
if show:
910+
plt.show()
911+
912+
plt.close(fig)
913+
plt.ion()
914+
892915
if return_axs:
893916
return ax
894917
else:
@@ -930,6 +953,8 @@ def plot_comparison_ols_coef(
930953
from matplotlib import rcParams
931954
from batchglm.api.models.glm_norm import Estimator, InputDataGLM
932955

956+
plt.ioff()
957+
933958
# Run OLS model fit to have comparison coefficients.
934959
if self._store_ols is None:
935960
input_data_ols = InputDataGLM(
@@ -1067,6 +1092,8 @@ def plot_comparison_ols_pred(
10671092
from matplotlib import rcParams
10681093
from batchglm.api.models.glm_norm import Estimator, InputDataGLM
10691094

1095+
plt.ioff()
1096+
10701097
# Run OLS model fit to have comparison coefficients.
10711098
if self._store_ols is None:
10721099
input_data_ols = InputDataGLM(

0 commit comments

Comments
 (0)