Skip to content

Commit f6dbf64

Browse files
improved model fit result evaluation via comparison to OLS
1 parent 82d6486 commit f6dbf64

File tree

1 file changed

+207
-26
lines changed

1 file changed

+207
-26
lines changed

diffxpy/testing/det.py

Lines changed: 207 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from typing import Union, Dict, Tuple, List, Set
44
import pandas as pd
5+
from random import sample
56

67
import numpy as np
78
import xarray as xr
@@ -14,6 +15,7 @@
1415
except ImportError:
1516
anndata = None
1617

18+
from batchglm.xarray_sparse.base import SparseXArrayDataArray
1719
from batchglm.models.glm_nb import Model as GeneralizedLinearModel
1820

1921
from ..stats import stats
@@ -777,6 +779,7 @@ def __init__(
777779
self.model_estim = model_estim
778780
self.coef_loc_totest = col_indices
779781
self.noise_model = noise_model
782+
self._store_ols = None
780783

781784
try:
782785
if model_estim._error_codes is not None:
@@ -940,12 +943,12 @@ def plot_vs_ttest(
940943
else:
941944
return
942945

943-
def plot_comparison_ols(
946+
def plot_comparison_ols_coef(
944947
self,
945948
size=20,
946949
show: bool = True,
947950
save: Union[str, None] = None,
948-
suffix: str = "_ols_comparison.png",
951+
suffix: str = "_ols_comparison_coef.png",
949952
ncols=3,
950953
row_gap=0.3,
951954
col_gap=0.25,
@@ -955,6 +958,8 @@ def plot_comparison_ols(
955958
Plot location model coefficients of inferred model against those obtained from an OLS model.
956959
957960
Red line shown is the identity line.
961+
Note that this comparison only seems to be useful if the covariates are zero centred. This is
962+
especially important for continuous covariates.
958963
959964
:param size: Size of points.
960965
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
@@ -975,39 +980,47 @@ def plot_comparison_ols(
975980
from batchglm.api.models.glm_norm import Estimator, InputData
976981

977982
# Run OLS model fit to have comparison coefficients.
978-
input_data_ols = InputData.new(
979-
data=self.model_estim.input_data.data,
980-
design_loc=self.model_estim.input_data.design_loc,
981-
design_scale=self.model_estim.input_data.design_scale[:, [0]],
982-
constraints_loc=self.model_estim.input_data.constraints_loc,
983-
constraints_scale=self.model_estim.input_data.constraints_scale[[0], [0]],
984-
size_factors=self.model_estim.input_data.size_factors,
985-
feature_names=self.model_estim.input_data.features,
986-
)
987-
estim_ols = Estimator(
988-
input_data=input_data_ols,
989-
init_model=None,
990-
init_a="standard",
991-
init_b="standard",
992-
dtype=self.model_estim.a_var.dtype
993-
)
994-
estim_ols.initialize()
995-
store_ols = estim_ols.finalize()
983+
if self._store_ols is None:
984+
input_data_ols = InputData.new(
985+
data=self.model_estim.input_data.data,
986+
design_loc=self.model_estim.input_data.design_loc,
987+
design_scale=self.model_estim.input_data.design_scale[:, [0]],
988+
constraints_loc=self.model_estim.input_data.constraints_loc,
989+
constraints_scale=self.model_estim.input_data.constraints_scale[[0], [0]],
990+
size_factors=self.model_estim.input_data.size_factors,
991+
feature_names=self.model_estim.input_data.features,
992+
)
993+
estim_ols = Estimator(
994+
input_data=input_data_ols,
995+
init_model=None,
996+
init_a="standard",
997+
init_b="standard",
998+
dtype=self.model_estim.a_var.dtype
999+
)
1000+
estim_ols.initialize()
1001+
store_ols = estim_ols.finalize()
1002+
self._store_ols = store_ols
1003+
else:
1004+
store_ols = self._store_ols
9961005

9971006
# Prepare parameter summary of both model fits.
998-
par_loc = input_data_ols.data.coords["design_loc_params"].values
1007+
par_loc = self.model_estim.input_data.data.coords["design_loc_params"].values
1008+
1009+
a_var_ols = store_ols.a_var.values
1010+
a_var_ols[1:, :] = (a_var_ols[1:, :] + a_var_ols[[0], :]) / a_var_ols[[0], :]
1011+
1012+
a_var_user = self.model_estim.a_var.values
1013+
# Translate coefficients from both fits to be multiplicative in identity space.
9991014
if self.noise_model == "nb":
1000-
# Translate coefficients from OLS fit to be multiplicative in identity space.
1001-
a_var_ols = store_ols.a_var.values
1002-
a_var_ols[1:, :] = (a_var_ols[1:, :] + a_var_ols[[0], :]) / a_var_ols[[0], :]
1015+
a_var_user = np.exp(a_var_user) # self.model_estim.inverse_link_loc(a_var_user)
10031016
elif self.noise_model == "norm":
1004-
a_var_ols = store_ols.a_var
1017+
a_var_user[1:, :] = (a_var_user[1:, :] + a_var_user[[0], :]) / a_var_user[[0], :]
10051018
else:
10061019
raise ValueError("noise model %s not yet supported for plot_comparison_ols" % self.noise_model)
10071020

10081021
summaries_fits = [
10091022
pd.DataFrame({
1010-
"user": self.model_estim.inverse_link_loc(self.model_estim.a_var[i, :]),
1023+
"user": a_var_user[i, :],
10111024
"ols": a_var_ols[i, :],
10121025
"coef": par_loc[i]
10131026
}) for i in range(self.model_estim.a_var.shape[0])
@@ -1069,6 +1082,174 @@ def plot_comparison_ols(
10691082
else:
10701083
return
10711084

1085+
def plot_comparison_ols_pred(
1086+
self,
1087+
size=20,
1088+
log1p_transform: bool = True,
1089+
show: bool = True,
1090+
save: Union[str, None] = None,
1091+
suffix: str = "_ols_comparison_pred.png",
1092+
row_gap=0.3,
1093+
col_gap=0.25,
1094+
return_axs: bool = False
1095+
):
1096+
"""
1097+
Compare location model prediction of inferred model with one obtained from an OLS model.
1098+
1099+
Red line shown is the identity line.
1100+
1101+
:param size: Size of points.
1102+
:param log1p_transform: Whether to log1p transform the data.
1103+
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
1104+
:param save: Path+file name stem to save plots to.
1105+
File will be save+suffix. Does not save if save is None.
1106+
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
1107+
:param row_gap: Vertical gap between panel rows relative to panel height.
1108+
:param col_gap: Horizontal gap between panel columns relative to panel width.
1109+
:param return_axs: Whether to return axis objects.
1110+
1111+
:return: Matplotlib axis objects.
1112+
"""
1113+
import seaborn as sns
1114+
import matplotlib.pyplot as plt
1115+
from matplotlib import gridspec
1116+
from matplotlib import rcParams
1117+
from batchglm.api.models.glm_norm import Estimator, InputData
1118+
1119+
# Run OLS model fit to have comparison coefficients.
1120+
if self._store_ols is None:
1121+
input_data_ols = InputData.new(
1122+
data=self.model_estim.input_data.data,
1123+
design_loc=self.model_estim.input_data.design_loc,
1124+
design_scale=self.model_estim.input_data.design_scale[:, [0]],
1125+
constraints_loc=self.model_estim.input_data.constraints_loc,
1126+
constraints_scale=self.model_estim.input_data.constraints_scale[[0], [0]],
1127+
size_factors=self.model_estim.input_data.size_factors,
1128+
feature_names=self.model_estim.input_data.features,
1129+
)
1130+
estim_ols = Estimator(
1131+
input_data=input_data_ols,
1132+
init_model=None,
1133+
init_a="standard",
1134+
init_b="standard",
1135+
dtype=self.model_estim.a_var.dtype
1136+
)
1137+
estim_ols.initialize()
1138+
store_ols = estim_ols.finalize()
1139+
self._store_ols = store_ols
1140+
else:
1141+
store_ols = self._store_ols
1142+
1143+
# Prepare parameter summary of both model fits.
1144+
plt.ioff()
1145+
nrows = 1
1146+
ncols = 2
1147+
1148+
axs = []
1149+
gs = gridspec.GridSpec(
1150+
nrows=nrows,
1151+
ncols=ncols,
1152+
hspace=row_gap,
1153+
wspace=col_gap
1154+
)
1155+
fig = plt.figure(
1156+
figsize=(
1157+
ncols * rcParams['figure.figsize'][0], # width in inches
1158+
nrows * rcParams['figure.figsize'][1] * (1 + row_gap) # height in inches
1159+
)
1160+
)
1161+
1162+
pred_n_cells = sample(
1163+
population=list(np.arange(0, self.model_estim.X.shape[0])),
1164+
k=np.min([20, self.model_estim.design_loc.shape[0]])
1165+
)
1166+
1167+
if isinstance(self.model_estim.X, SparseXArrayDataArray):
1168+
x = np.asarray(self.model_estim.X.X[pred_n_cells, :].todense()).flatten()
1169+
else:
1170+
x = np.asarray(self.model_estim.X[pred_n_cells, :]).flatten()
1171+
1172+
y_user = self.model_estim.inverse_link_loc(
1173+
np.matmul(self.model_estim.design_loc[pred_n_cells, :].values, self.model_estim.a_var.values).flatten()
1174+
)
1175+
y_ols = store_ols.inverse_link_loc(
1176+
np.matmul(store_ols.design_loc[pred_n_cells, :].values, store_ols.a_var.values).flatten()
1177+
)
1178+
if log1p_transform:
1179+
x = np.log(x+1)
1180+
y_user = np.log(y_user + 1)
1181+
y_ols = np.log(y_ols + 1)
1182+
1183+
y = np.concatenate([y_user, y_ols])
1184+
1185+
summary0_fit = pd.concat([
1186+
pd.DataFrame({
1187+
"observed": y_user,
1188+
"predicted": x,
1189+
"model": ["user" for i in x]
1190+
}),
1191+
pd.DataFrame({
1192+
"observed": y_ols,
1193+
"predicted": x,
1194+
"model": ["OLS" for i in x]
1195+
})
1196+
])
1197+
1198+
ax0 = plt.subplot(gs[0])
1199+
axs.append(ax0)
1200+
sns.scatterplot(
1201+
x="observed",
1202+
y="predicted",
1203+
hue="model",
1204+
data=summary0_fit,
1205+
ax=ax0,
1206+
s=size
1207+
)
1208+
sns.lineplot(
1209+
x=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
1210+
y=np.array([np.min([np.min(x), np.min(y)]), np.max([np.max(x), np.max(y)])]),
1211+
ax=ax0,
1212+
color="red",
1213+
legend=False
1214+
)
1215+
ax0.set(xlabel="observed value", ylabel="model")
1216+
1217+
summary1_fit = pd.concat([
1218+
pd.DataFrame({
1219+
"dev": y_user-x,
1220+
"model": ["user" for i in x]
1221+
}),
1222+
pd.DataFrame({
1223+
"dev": y_ols-x,
1224+
"model": ["OLS" for i in x]
1225+
})
1226+
])
1227+
1228+
ax1 = plt.subplot(gs[1])
1229+
axs.append(ax0)
1230+
sns.boxplot(
1231+
x="model",
1232+
y="dev",
1233+
data=summary1_fit,
1234+
ax=ax1
1235+
)
1236+
ax1.set(xlabel="model", ylabel="deviation from observations")
1237+
1238+
# Save, show and return figure.
1239+
if save is not None:
1240+
plt.savefig(save + suffix)
1241+
1242+
if show:
1243+
plt.show()
1244+
1245+
plt.close(fig)
1246+
plt.ion()
1247+
1248+
if return_axs:
1249+
return axs
1250+
else:
1251+
return
1252+
10721253
def plot_gene_fits(
10731254
self,
10741255
gene_names: Tuple,

0 commit comments

Comments
 (0)