@@ -292,7 +292,8 @@ def plot_volcano(
292
292
highlight_col : str = "red" ,
293
293
show : bool = True ,
294
294
save : Union [str , None ] = None ,
295
- suffix : str = "_volcano.png"
295
+ suffix : str = "_volcano.png" ,
296
+ return_axs : bool = False
296
297
):
297
298
"""
298
299
Returns a volcano plot of p-value vs. log fold change
@@ -314,6 +315,7 @@ def plot_volcano(
314
315
:param save: Path+file name stem to save plots to.
315
316
File will be save+suffix. Does not save if save is None.
316
317
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
318
+ :param return_axs: Whether to return axis objects.
317
319
318
320
:return: Tuple of matplotlib (figure, axis)
319
321
"""
@@ -379,7 +381,10 @@ def plot_volcano(
379
381
380
382
plt .close (fig )
381
383
382
- return ax
384
+ if return_axs :
385
+ return ax
386
+ else :
387
+ return
383
388
384
389
def plot_ma (
385
390
self ,
@@ -393,7 +398,8 @@ def plot_ma(
393
398
highlight_col : str = "red" ,
394
399
show : bool = True ,
395
400
save : Union [str , None ] = None ,
396
- suffix : str = "_my_plot.png"
401
+ suffix : str = "_ma_plot.png" ,
402
+ return_axs : bool = False
397
403
):
398
404
"""
399
405
Returns an MA plot of mean expression vs. log fold change with significance
@@ -416,7 +422,7 @@ def plot_ma(
416
422
:param save: Path+file name stem to save plots to.
417
423
File will be save+suffix. Does not save if save is None.
418
424
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
419
-
425
+ :param return_axs: Whether to return axis objects.
420
426
421
427
:return: Tuple of matplotlib (figure, axis)
422
428
"""
@@ -487,7 +493,10 @@ def plot_ma(
487
493
plt .close (fig )
488
494
plt .ion ()
489
495
490
- return ax
496
+ if return_axs :
497
+ return ax
498
+ else :
499
+ return
491
500
492
501
493
502
class _DifferentialExpressionTestSingle (_DifferentialExpressionTest , metaclass = abc .ABCMeta ):
@@ -756,7 +765,8 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
756
765
def __init__ (
757
766
self ,
758
767
model_estim : _Estimation ,
759
- col_indices : np .ndarray
768
+ col_indices : np .ndarray ,
769
+ noise_model : str ,
760
770
):
761
771
"""
762
772
:param model_estim:
@@ -766,6 +776,7 @@ def __init__(
766
776
767
777
self .model_estim = model_estim
768
778
self .coef_loc_totest = col_indices
779
+ self .noise_model = noise_model
769
780
770
781
try :
771
782
if model_estim ._error_codes is not None :
@@ -889,7 +900,18 @@ def summary(self, qval_thres=None, fc_upper_thres=None,
889
900
890
901
return res
891
902
892
- def plot_vs_ttest (self , log10 = False ):
903
+ def plot_vs_ttest (
904
+ self ,
905
+ log10 = False ,
906
+ return_axs : bool = False
907
+ ):
908
+ """
909
+
910
+ :param log10:
911
+ :param return_axs: Whether to return axis objects.
912
+
913
+ :return:
914
+ """
893
915
import matplotlib .pyplot as plt
894
916
import seaborn as sns
895
917
from .tests import t_test
@@ -913,7 +935,139 @@ def plot_vs_ttest(self, log10=False):
913
935
914
936
ax .set (xlabel = "t-test" , ylabel = 'wald test' )
915
937
916
- return fig , ax
938
+ if return_axs :
939
+ return ax
940
+ else :
941
+ return
942
+
943
+ def plot_comparison_ols (
944
+ self ,
945
+ size = 20 ,
946
+ show : bool = True ,
947
+ save : Union [str , None ] = None ,
948
+ suffix : str = "_ols_comparison.png" ,
949
+ ncols = 5 ,
950
+ row_gap = 0.3 ,
951
+ col_gap = 0.25 ,
952
+ return_axs : bool = False
953
+ ):
954
+ """
955
+ Plot location model coefficients of inferred model against those obtained from an OLS model.
956
+
957
+ Red line shown is the identity line.
958
+
959
+ :param size: Size of points.
960
+ :param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
961
+ :param save: Path+file name stem to save plots to.
962
+ File will be save+suffix. Does not save if save is None.
963
+ :param suffix: Suffix for file name to save plot to. Also use this to set the file type.
964
+ :param ncols: Number of columns in plot grid if multiple genes are plotted.
965
+ :param row_gap: Vertical gap between panel rows relative to panel height.
966
+ :param col_gap: Horizontal gap between panel columns relative to panel width.
967
+ :param return_axs: Whether to return axis objects.
968
+
969
+ :return: Matplotlib axis objects.
970
+ """
971
+ import seaborn as sns
972
+ import matplotlib .pyplot as plt
973
+ from matplotlib import gridspec
974
+ from matplotlib import rcParams
975
+ from batchglm .api .models .glm_norm import Estimator , InputData
976
+
977
+ # Run OLS model fit to have comparison coefficients.
978
+ input_data_ols = InputData .new (
979
+ data = self .model_estim .input_data .data ,
980
+ design_loc = self .model_estim .input_data .design_loc ,
981
+ design_scale = self .model_estim .input_data .design_scale [:, [0 ]],
982
+ constraints_loc = self .model_estim .input_data .constraints_loc ,
983
+ constraints_scale = self .model_estim .input_data .constraints_scale [[0 ], [0 ]],
984
+ size_factors = self .model_estim .input_data .size_factors ,
985
+ feature_names = self .model_estim .input_data .features ,
986
+ )
987
+ estim_ols = Estimator (
988
+ input_data = input_data_ols ,
989
+ init_model = None ,
990
+ init_a = "standard" ,
991
+ init_b = "standard" ,
992
+ dtype = self .model_estim .a_var .dtype
993
+ )
994
+ estim_ols .initialize ()
995
+ store_ols = estim_ols .finalize ()
996
+
997
+ # Prepare parameter summary of both model fits.
998
+ par_loc = input_data_ols .data .coords ["design_loc_params" ].values
999
+ if self .noise_model == "nb" :
1000
+ # Translate coefficients from OLS fit to be multiplicative in identity space.
1001
+ a_var_ols = store_ols .a_var .values
1002
+ a_var_ols [1 :, :] = (a_var_ols [1 :, :] + a_var_ols [[0 ], :]) / a_var_ols [[0 ], :]
1003
+ elif self .noise_model == "norm" :
1004
+ a_var_ols = store_ols .a_var
1005
+ else :
1006
+ raise ValueError ("noise model %s not yet supported for plot_comparison_ols" % self .noise_model )
1007
+
1008
+ summaries_fits = [
1009
+ pd .DataFrame ({
1010
+ "user" : self .model_estim .inverse_link_loc (self .model_estim .a_var [i , :]),
1011
+ "ols" : a_var_ols [i , :],
1012
+ "coef" : par_loc [i ]
1013
+ }) for i in range (self .model_estim .a_var .shape [0 ])
1014
+ ]
1015
+
1016
+ plt .ioff ()
1017
+ nrows = len (par_loc ) // ncols + int ((len (par_loc ) % ncols ) > 0 )
1018
+
1019
+ gs = gridspec .GridSpec (
1020
+ nrows = nrows ,
1021
+ ncols = ncols ,
1022
+ hspace = row_gap ,
1023
+ wspace = col_gap
1024
+ )
1025
+ fig = plt .figure (
1026
+ figsize = (
1027
+ ncols * rcParams ['figure.figsize' ][0 ], # width in inches
1028
+ nrows * rcParams ['figure.figsize' ][1 ] * (1 + row_gap ) # height in inches
1029
+ )
1030
+ )
1031
+
1032
+ axs = []
1033
+ for i , par_i in enumerate (par_loc ):
1034
+ ax = plt .subplot (gs [i ])
1035
+ axs .append (ax )
1036
+
1037
+ x = summaries_fits [i ]["user" ].values
1038
+ y = summaries_fits [i ]["ols" ].values
1039
+
1040
+ sns .scatterplot (
1041
+ x = x ,
1042
+ y = y ,
1043
+ ax = ax ,
1044
+ s = size
1045
+ )
1046
+ sns .lineplot (
1047
+ x = np .array ([np .min ([np .min (x ), np .min (y )]), np .max ([np .max (x ), np .max (y )])]),
1048
+ y = np .array ([np .min ([np .min (x ), np .min (y )]), np .max ([np .max (x ), np .max (y )])]),
1049
+ ax = ax ,
1050
+ color = "red" ,
1051
+ legend = False
1052
+ )
1053
+ ax .set (xlabel = "user supplied model" , ylabel = "OLS model" )
1054
+ title_i = par_loc [i ] + " (R=" + str (np .round (np .corrcoef (x , y )[0 , 1 ], 3 )) + ")"
1055
+ ax .set_title (title_i )
1056
+
1057
+ # Save, show and return figure.
1058
+ if save is not None :
1059
+ plt .savefig (save + suffix )
1060
+
1061
+ if show :
1062
+ plt .show ()
1063
+
1064
+ plt .close (fig )
1065
+ plt .ion ()
1066
+
1067
+ if return_axs :
1068
+ return axs
1069
+ else :
1070
+ return
917
1071
918
1072
919
1073
class DifferentialExpressionTestTT (_DifferentialExpressionTestSingle ):
0 commit comments