@@ -946,7 +946,7 @@ def plot_comparison_ols(
946
946
show : bool = True ,
947
947
save : Union [str , None ] = None ,
948
948
suffix : str = "_ols_comparison.png" ,
949
- ncols = 5 ,
949
+ ncols = 3 ,
950
950
row_gap = 0.3 ,
951
951
col_gap = 0.25 ,
952
952
return_axs : bool = False
@@ -1069,6 +1069,126 @@ def plot_comparison_ols(
1069
1069
else :
1070
1070
return
1071
1071
1072
+ def plot_gene_fits (
1073
+ self ,
1074
+ gene_names : Tuple ,
1075
+ covariate_x : str = None ,
1076
+ covariate_hue : str = None ,
1077
+ size_factor_normalize : bool = False ,
1078
+ log_transform : bool = False ,
1079
+ show : bool = True ,
1080
+ save : Union [str , None ] = None ,
1081
+ suffix : str = "_genes.png" ,
1082
+ ncols = 3 ,
1083
+ row_gap = 0.3 ,
1084
+ col_gap = 0.25 ,
1085
+ return_axs : bool = False
1086
+ ):
1087
+ """
1088
+ Plot gene-wise model fits and observed distribution by covariates.
1089
+
1090
+ Use this to inspect fitting performance on individual genes.
1091
+
1092
+ :param gene_names: Genes to generate plots for.
1093
+ :param covariate_x: Covariate in location model to partition x-axis by.
1094
+ :param covariate_hue: Covariate in location model to stack boxplots by.
1095
+ :param size_factor_normalize: Whether to size_factor normalize observations
1096
+ before estimating the distribution with boxplot.
1097
+ :param log_transform: Whether to log transform observations
1098
+ before estimating the distribution with boxplot. Model estimates are adjusted accordingly.
1099
+ :param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
1100
+ :param save: Path+file name stem to save plots to.
1101
+ File will be save+suffix. Does not save if save is None.
1102
+ :param suffix: Suffix for file name to save plot to. Also use this to set the file type.
1103
+ :param ncols: Number of columns in plot grid if multiple genes are plotted.
1104
+ :param row_gap: Vertical gap between panel rows relative to panel height.
1105
+ :param col_gap: Horizontal gap between panel columns relative to panel width.
1106
+ :param return_axs: Whether to return axis objects.
1107
+
1108
+ :return: Matplotlib axis objects.
1109
+ """
1110
+ import seaborn as sns
1111
+ import matplotlib .pyplot as plt
1112
+ from matplotlib import gridspec
1113
+ from matplotlib import rcParams
1114
+
1115
+ plt .ioff ()
1116
+ nrows = len (par_loc ) // ncols + int ((len (par_loc ) % ncols ) > 0 )
1117
+
1118
+ gs = gridspec .GridSpec (
1119
+ nrows = nrows ,
1120
+ ncols = ncols ,
1121
+ hspace = row_gap ,
1122
+ wspace = col_gap
1123
+ )
1124
+ fig = plt .figure (
1125
+ figsize = (
1126
+ ncols * rcParams ['figure.figsize' ][0 ], # width in inches
1127
+ nrows * rcParams ['figure.figsize' ][1 ] * (1 + row_gap ) # height in inches
1128
+ )
1129
+ )
1130
+
1131
+ axs = []
1132
+ for i , g in enumerate (gene_names ):
1133
+ # Prepare data boxplot.
1134
+ y = self .model_estim .X [:, g ]
1135
+ if size_factor_normalize :
1136
+ pass
1137
+ if log_transform :
1138
+ pass
1139
+
1140
+ summary_fit = pd .DataFrame ({
1141
+ "y" : y
1142
+ })
1143
+ if covariate_x is not None :
1144
+ if covariate_x in self .model_estim .design_loc .columns :
1145
+ summary_fit ["x" ] = self .model_estim .design_loc . [covariate_x ]
1146
+ else
1147
+ raise ValueError ("covariate_x=%s not found in location model" % covariate_x )
1148
+
1149
+ if covariate_hue is not None :
1150
+ if covariate_hue in self .model_estim .design_loc .columns :
1151
+ summary_fit ["hue" ] = self .model_estim .design_loc . [covariate_hue ]
1152
+ else
1153
+ raise ValueError ("covariate_x=%s not found in location model" % covariate_x )
1154
+
1155
+ # Prepare model fit plot.
1156
+ if self .noise_model == "nb" :
1157
+ pass
1158
+ elif self .noise_model == "norm" :
1159
+ pass
1160
+ else :
1161
+ raise ValueError ("noise model %s not yet supported for plot_gene_fits" % self .noise_model )
1162
+
1163
+ ax = plt .subplot (gs [i ])
1164
+ axs .append (ax )
1165
+
1166
+ sns .boxplot (
1167
+ x = "x" ,
1168
+ y = "y" ,
1169
+ hue = "hue" ,
1170
+ data = summary_fit ,
1171
+ ax = ax
1172
+ )
1173
+
1174
+ ax .set (xlabel = "covariate" , ylabel = "expression" )
1175
+ ax .set_title (g )
1176
+
1177
+ # Save, show and return figure.
1178
+ if save is not None :
1179
+ plt .savefig (save + suffix )
1180
+
1181
+ if show :
1182
+ plt .show ()
1183
+
1184
+ plt .close (fig )
1185
+ plt .ion ()
1186
+
1187
+ if return_axs :
1188
+ return axs
1189
+ else :
1190
+ return
1191
+
1072
1192
1073
1193
class DifferentialExpressionTestTT (_DifferentialExpressionTestSingle ):
1074
1194
"""
0 commit comments