2
2
import logging
3
3
from typing import Union , Dict , Tuple , List , Set
4
4
import pandas as pd
5
+ from random import sample
5
6
6
7
import numpy as np
7
8
import xarray as xr
14
15
except ImportError :
15
16
anndata = None
16
17
18
+ from batchglm .xarray_sparse .base import SparseXArrayDataArray
17
19
from batchglm .models .glm_nb import Model as GeneralizedLinearModel
18
20
19
21
from ..stats import stats
@@ -777,6 +779,7 @@ def __init__(
777
779
self .model_estim = model_estim
778
780
self .coef_loc_totest = col_indices
779
781
self .noise_model = noise_model
782
+ self ._store_ols = None
780
783
781
784
try :
782
785
if model_estim ._error_codes is not None :
@@ -940,12 +943,12 @@ def plot_vs_ttest(
940
943
else :
941
944
return
942
945
943
- def plot_comparison_ols (
946
+ def plot_comparison_ols_coef (
944
947
self ,
945
948
size = 20 ,
946
949
show : bool = True ,
947
950
save : Union [str , None ] = None ,
948
- suffix : str = "_ols_comparison .png" ,
951
+ suffix : str = "_ols_comparison_coef .png" ,
949
952
ncols = 3 ,
950
953
row_gap = 0.3 ,
951
954
col_gap = 0.25 ,
@@ -955,6 +958,8 @@ def plot_comparison_ols(
955
958
Plot location model coefficients of inferred model against those obtained from an OLS model.
956
959
957
960
Red line shown is the identity line.
961
+ Note that this comparison only seems to be useful if the covariates are zero centred. This is
962
+ especially important for continuous covariates.
958
963
959
964
:param size: Size of points.
960
965
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
@@ -975,39 +980,47 @@ def plot_comparison_ols(
975
980
from batchglm .api .models .glm_norm import Estimator , InputData
976
981
977
982
# Run OLS model fit to have comparison coefficients.
978
- input_data_ols = InputData .new (
979
- data = self .model_estim .input_data .data ,
980
- design_loc = self .model_estim .input_data .design_loc ,
981
- design_scale = self .model_estim .input_data .design_scale [:, [0 ]],
982
- constraints_loc = self .model_estim .input_data .constraints_loc ,
983
- constraints_scale = self .model_estim .input_data .constraints_scale [[0 ], [0 ]],
984
- size_factors = self .model_estim .input_data .size_factors ,
985
- feature_names = self .model_estim .input_data .features ,
986
- )
987
- estim_ols = Estimator (
988
- input_data = input_data_ols ,
989
- init_model = None ,
990
- init_a = "standard" ,
991
- init_b = "standard" ,
992
- dtype = self .model_estim .a_var .dtype
993
- )
994
- estim_ols .initialize ()
995
- store_ols = estim_ols .finalize ()
983
+ if self ._store_ols is None :
984
+ input_data_ols = InputData .new (
985
+ data = self .model_estim .input_data .data ,
986
+ design_loc = self .model_estim .input_data .design_loc ,
987
+ design_scale = self .model_estim .input_data .design_scale [:, [0 ]],
988
+ constraints_loc = self .model_estim .input_data .constraints_loc ,
989
+ constraints_scale = self .model_estim .input_data .constraints_scale [[0 ], [0 ]],
990
+ size_factors = self .model_estim .input_data .size_factors ,
991
+ feature_names = self .model_estim .input_data .features ,
992
+ )
993
+ estim_ols = Estimator (
994
+ input_data = input_data_ols ,
995
+ init_model = None ,
996
+ init_a = "standard" ,
997
+ init_b = "standard" ,
998
+ dtype = self .model_estim .a_var .dtype
999
+ )
1000
+ estim_ols .initialize ()
1001
+ store_ols = estim_ols .finalize ()
1002
+ self ._store_ols = store_ols
1003
+ else :
1004
+ store_ols = self ._store_ols
996
1005
997
1006
# Prepare parameter summary of both model fits.
998
- par_loc = input_data_ols .data .coords ["design_loc_params" ].values
1007
+ par_loc = self .model_estim .input_data .data .coords ["design_loc_params" ].values
1008
+
1009
+ a_var_ols = store_ols .a_var .values
1010
+ a_var_ols [1 :, :] = (a_var_ols [1 :, :] + a_var_ols [[0 ], :]) / a_var_ols [[0 ], :]
1011
+
1012
+ a_var_user = self .model_estim .a_var .values
1013
+ # Translate coefficients from both fits to be multiplicative in identity space.
999
1014
if self .noise_model == "nb" :
1000
- # Translate coefficients from OLS fit to be multiplicative in identity space.
1001
- a_var_ols = store_ols .a_var .values
1002
- a_var_ols [1 :, :] = (a_var_ols [1 :, :] + a_var_ols [[0 ], :]) / a_var_ols [[0 ], :]
1015
+ a_var_user = np .exp (a_var_user ) # self.model_estim.inverse_link_loc(a_var_user)
1003
1016
elif self .noise_model == "norm" :
1004
- a_var_ols = store_ols . a_var
1017
+ a_var_user [ 1 :, :] = ( a_var_user [ 1 :, :] + a_var_user [[ 0 ], :]) / a_var_user [[ 0 ], :]
1005
1018
else :
1006
1019
raise ValueError ("noise model %s not yet supported for plot_comparison_ols" % self .noise_model )
1007
1020
1008
1021
summaries_fits = [
1009
1022
pd .DataFrame ({
1010
- "user" : self . model_estim . inverse_link_loc ( self . model_estim . a_var [i , :]) ,
1023
+ "user" : a_var_user [i , :],
1011
1024
"ols" : a_var_ols [i , :],
1012
1025
"coef" : par_loc [i ]
1013
1026
}) for i in range (self .model_estim .a_var .shape [0 ])
@@ -1069,6 +1082,174 @@ def plot_comparison_ols(
1069
1082
else :
1070
1083
return
1071
1084
1085
+ def plot_comparison_ols_pred (
1086
+ self ,
1087
+ size = 20 ,
1088
+ log1p_transform : bool = True ,
1089
+ show : bool = True ,
1090
+ save : Union [str , None ] = None ,
1091
+ suffix : str = "_ols_comparison_pred.png" ,
1092
+ row_gap = 0.3 ,
1093
+ col_gap = 0.25 ,
1094
+ return_axs : bool = False
1095
+ ):
1096
+ """
1097
+ Compare location model prediction of inferred model with one obtained from an OLS model.
1098
+
1099
+ Red line shown is the identity line.
1100
+
1101
+ :param size: Size of points.
1102
+ :param log1p_transform: Whether to log1p transform the data.
1103
+ :param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
1104
+ :param save: Path+file name stem to save plots to.
1105
+ File will be save+suffix. Does not save if save is None.
1106
+ :param suffix: Suffix for file name to save plot to. Also use this to set the file type.
1107
+ :param row_gap: Vertical gap between panel rows relative to panel height.
1108
+ :param col_gap: Horizontal gap between panel columns relative to panel width.
1109
+ :param return_axs: Whether to return axis objects.
1110
+
1111
+ :return: Matplotlib axis objects.
1112
+ """
1113
+ import seaborn as sns
1114
+ import matplotlib .pyplot as plt
1115
+ from matplotlib import gridspec
1116
+ from matplotlib import rcParams
1117
+ from batchglm .api .models .glm_norm import Estimator , InputData
1118
+
1119
+ # Run OLS model fit to have comparison coefficients.
1120
+ if self ._store_ols is None :
1121
+ input_data_ols = InputData .new (
1122
+ data = self .model_estim .input_data .data ,
1123
+ design_loc = self .model_estim .input_data .design_loc ,
1124
+ design_scale = self .model_estim .input_data .design_scale [:, [0 ]],
1125
+ constraints_loc = self .model_estim .input_data .constraints_loc ,
1126
+ constraints_scale = self .model_estim .input_data .constraints_scale [[0 ], [0 ]],
1127
+ size_factors = self .model_estim .input_data .size_factors ,
1128
+ feature_names = self .model_estim .input_data .features ,
1129
+ )
1130
+ estim_ols = Estimator (
1131
+ input_data = input_data_ols ,
1132
+ init_model = None ,
1133
+ init_a = "standard" ,
1134
+ init_b = "standard" ,
1135
+ dtype = self .model_estim .a_var .dtype
1136
+ )
1137
+ estim_ols .initialize ()
1138
+ store_ols = estim_ols .finalize ()
1139
+ self ._store_ols = store_ols
1140
+ else :
1141
+ store_ols = self ._store_ols
1142
+
1143
+ # Prepare parameter summary of both model fits.
1144
+ plt .ioff ()
1145
+ nrows = 1
1146
+ ncols = 2
1147
+
1148
+ axs = []
1149
+ gs = gridspec .GridSpec (
1150
+ nrows = nrows ,
1151
+ ncols = ncols ,
1152
+ hspace = row_gap ,
1153
+ wspace = col_gap
1154
+ )
1155
+ fig = plt .figure (
1156
+ figsize = (
1157
+ ncols * rcParams ['figure.figsize' ][0 ], # width in inches
1158
+ nrows * rcParams ['figure.figsize' ][1 ] * (1 + row_gap ) # height in inches
1159
+ )
1160
+ )
1161
+
1162
+ pred_n_cells = sample (
1163
+ population = list (np .arange (0 , self .model_estim .X .shape [0 ])),
1164
+ k = np .min ([20 , self .model_estim .design_loc .shape [0 ]])
1165
+ )
1166
+
1167
+ if isinstance (self .model_estim .X , SparseXArrayDataArray ):
1168
+ x = np .asarray (self .model_estim .X .X [pred_n_cells , :].todense ()).flatten ()
1169
+ else :
1170
+ x = np .asarray (self .model_estim .X [pred_n_cells , :]).flatten ()
1171
+
1172
+ y_user = self .model_estim .inverse_link_loc (
1173
+ np .matmul (self .model_estim .design_loc [pred_n_cells , :].values , self .model_estim .a_var .values ).flatten ()
1174
+ )
1175
+ y_ols = store_ols .inverse_link_loc (
1176
+ np .matmul (store_ols .design_loc [pred_n_cells , :].values , store_ols .a_var .values ).flatten ()
1177
+ )
1178
+ if log1p_transform :
1179
+ x = np .log (x + 1 )
1180
+ y_user = np .log (y_user + 1 )
1181
+ y_ols = np .log (y_ols + 1 )
1182
+
1183
+ y = np .concatenate ([y_user , y_ols ])
1184
+
1185
+ summary0_fit = pd .concat ([
1186
+ pd .DataFrame ({
1187
+ "observed" : y_user ,
1188
+ "predicted" : x ,
1189
+ "model" : ["user" for i in x ]
1190
+ }),
1191
+ pd .DataFrame ({
1192
+ "observed" : y_ols ,
1193
+ "predicted" : x ,
1194
+ "model" : ["OLS" for i in x ]
1195
+ })
1196
+ ])
1197
+
1198
+ ax0 = plt .subplot (gs [0 ])
1199
+ axs .append (ax0 )
1200
+ sns .scatterplot (
1201
+ x = "observed" ,
1202
+ y = "predicted" ,
1203
+ hue = "model" ,
1204
+ data = summary0_fit ,
1205
+ ax = ax0 ,
1206
+ s = size
1207
+ )
1208
+ sns .lineplot (
1209
+ x = np .array ([np .min ([np .min (x ), np .min (y )]), np .max ([np .max (x ), np .max (y )])]),
1210
+ y = np .array ([np .min ([np .min (x ), np .min (y )]), np .max ([np .max (x ), np .max (y )])]),
1211
+ ax = ax0 ,
1212
+ color = "red" ,
1213
+ legend = False
1214
+ )
1215
+ ax0 .set (xlabel = "observed value" , ylabel = "model" )
1216
+
1217
+ summary1_fit = pd .concat ([
1218
+ pd .DataFrame ({
1219
+ "dev" : y_user - x ,
1220
+ "model" : ["user" for i in x ]
1221
+ }),
1222
+ pd .DataFrame ({
1223
+ "dev" : y_ols - x ,
1224
+ "model" : ["OLS" for i in x ]
1225
+ })
1226
+ ])
1227
+
1228
+ ax1 = plt .subplot (gs [1 ])
1229
+ axs .append (ax0 )
1230
+ sns .boxplot (
1231
+ x = "model" ,
1232
+ y = "dev" ,
1233
+ data = summary1_fit ,
1234
+ ax = ax1
1235
+ )
1236
+ ax1 .set (xlabel = "model" , ylabel = "deviation from observations" )
1237
+
1238
+ # Save, show and return figure.
1239
+ if save is not None :
1240
+ plt .savefig (save + suffix )
1241
+
1242
+ if show :
1243
+ plt .show ()
1244
+
1245
+ plt .close (fig )
1246
+ plt .ion ()
1247
+
1248
+ if return_axs :
1249
+ return axs
1250
+ else :
1251
+ return
1252
+
1072
1253
def plot_gene_fits (
1073
1254
self ,
1074
1255
gene_names : Tuple ,
0 commit comments