Skip to content

Commit e0c3095

Browse files
improved multiple fit evaluation funcitons
1 parent 5d92faa commit e0c3095

File tree

2 files changed

+232
-43
lines changed

2 files changed

+232
-43
lines changed

diffxpy/testing/det.py

Lines changed: 228 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ class DifferentialExpressionTestWald(_DifferentialExpressionTestSingle):
758758
"""
759759

760760
model_estim: _Estimation
761+
sample_description: pd.DataFrame
761762
coef_loc_totest: np.ndarray
762763
theta_mle: np.ndarray
763764
theta_sd: np.ndarray
@@ -769,13 +770,15 @@ def __init__(
769770
model_estim: _Estimation,
770771
col_indices: np.ndarray,
771772
noise_model: str,
773+
sample_description: pd.DataFrame
772774
):
773775
"""
774776
:param model_estim:
775777
:param col_indices: indices of coefs to test
776778
"""
777779
super().__init__()
778780

781+
self.sample_description = sample_description
779782
self.model_estim = model_estim
780783
self.coef_loc_totest = col_indices
781784
self.noise_model = noise_model
@@ -909,6 +912,7 @@ def plot_vs_ttest(
909912
return_axs: bool = False
910913
):
911914
"""
915+
Normalizes data by size factors if any were used in model.
912916
913917
:param log10:
914918
:param return_axs: Whether to return axis objects.
@@ -920,8 +924,11 @@ def plot_vs_ttest(
920924
from .tests import t_test
921925

922926
grouping = np.asarray(self.model_estim.design_loc[:, self.coef_loc_totest])
927+
# Normalize by size factors that were used in regression.
928+
sf = np.broadcast_to(np.expand_dims(self.model_estim.size_factors, axis=1),
929+
shape=self.model_estim.X.shape)
923930
ttest = t_test(
924-
data=self.model_estim.X,
931+
data=self.model_estim.X.multiply(1 / sf, copy=True),
925932
grouping=grouping,
926933
gene_names=self.gene_ids,
927934
)
@@ -1250,19 +1257,97 @@ def plot_comparison_ols_pred(
12501257
else:
12511258
return
12521259

1253-
def plot_gene_fits(
1260+
def _assemble_gene_fits(
1261+
self,
1262+
gene_names: Tuple,
1263+
covariate_x: str,
1264+
covariate_hue: Union[None, str],
1265+
log1p_transform: bool,
1266+
):
1267+
"""
1268+
Prepare data for gene-wise model plots.
1269+
1270+
:param gene_names: Genes to generate plots for.
1271+
:param covariate_x: Covariate in location model to partition x-axis by.
1272+
:param covariate_hue: Covariate in location model to stack boxplots by.
1273+
:param log1p_transform: Whether to log transform observations
1274+
before estimating the distribution with boxplot. Model estimates are adjusted accordingly.
1275+
:return summaries_genes: List with data frame for seabron in it.
1276+
"""
1277+
1278+
summaries_genes = []
1279+
for i, g in enumerate(gene_names):
1280+
assert g in self.model_estim.features, "gene %g not found" % g
1281+
g_idx = self.model_estim.features.tolist().index(g)
1282+
# Raw data for boxplot:
1283+
y = self.model_estim.X[:, g_idx]
1284+
# Model fits:
1285+
loc = self.model_estim.location[:, g_idx]
1286+
scale = self.model_estim.scale[:, g_idx]
1287+
if self.noise_model == "nb":
1288+
yhat = np.random.negative_binomial(
1289+
n=scale,
1290+
p=1 - loc / (scale + loc)
1291+
)
1292+
elif self.noise_model == "norm":
1293+
yhat = np.random.normal(
1294+
loc=loc,
1295+
scale=scale
1296+
)
1297+
else:
1298+
raise ValueError("noise model %s not yet supported for plot_gene_fits" % self.noise_model)
1299+
1300+
# Transform observed data:
1301+
if log1p_transform:
1302+
y = np.log(y + 1)
1303+
yhat = np.log(yhat + 1)
1304+
1305+
# Build DataFrame which contains all information for raw data:
1306+
summary_raw = pd.DataFrame({"y": y, "data": "obs"})
1307+
summary_fit = pd.DataFrame({"y": yhat, "data": "fit"})
1308+
if covariate_x is not None:
1309+
assert self.sample_description is not None, "sample_description was not provided to test.wald()"
1310+
if covariate_x in self.sample_description.columns:
1311+
summary_raw["x"] = self.sample_description[covariate_x].values.astype(str)
1312+
summary_fit["x"] = self.sample_description[covariate_x].values.astype(str)
1313+
else:
1314+
raise ValueError("covariate_x=%s not found in location model" % covariate_x)
1315+
else:
1316+
summary_raw["x"] = " "
1317+
summary_fit["x"] = " "
1318+
1319+
if covariate_hue is not None:
1320+
assert self.sample_description is not None, "sample_description was not provided to test.wald()"
1321+
if covariate_hue in self.sample_description.columns:
1322+
summary_raw["hue"] = [str(x)+"_obs" for x in self.sample_description[covariate_hue].values]
1323+
summary_fit["hue"] = [str(x)+"_fit" for x in self.sample_description[covariate_hue].values]
1324+
else:
1325+
raise ValueError("covariate_x=%s not found in location model" % covariate_x)
1326+
else:
1327+
summary_raw["hue"] = "obs"
1328+
summary_fit["hue"] = "fit"
1329+
1330+
summaries = pd.concat([summary_raw, summary_fit])
1331+
summaries.x = pd.Categorical(summaries.x, ordered=True)
1332+
summaries.hue = pd.Categorical(summaries.hue, ordered=True)
1333+
1334+
summaries_genes.append(summaries)
1335+
1336+
return summaries_genes
1337+
1338+
def plot_gene_fits_boxplots(
12541339
self,
12551340
gene_names: Tuple,
12561341
covariate_x: str = None,
12571342
covariate_hue: str = None,
1258-
size_factor_normalize: bool = False,
1259-
log_transform: bool = False,
1343+
log1p_transform: bool = False,
12601344
show: bool = True,
12611345
save: Union[str, None] = None,
1262-
suffix: str = "_genes.png",
1346+
suffix: str = "_genes_boxplot.png",
12631347
ncols=3,
12641348
row_gap=0.3,
12651349
col_gap=0.25,
1350+
xtick_rotation=0,
12661351
return_axs: bool = False
12671352
):
12681353
"""
@@ -1273,9 +1358,7 @@ def plot_gene_fits(
12731358
:param gene_names: Genes to generate plots for.
12741359
:param covariate_x: Covariate in location model to partition x-axis by.
12751360
:param covariate_hue: Covariate in location model to stack boxplots by.
1276-
:param size_factor_normalize: Whether to size_factor normalize observations
1277-
before estimating the distribution with boxplot.
1278-
:param log_transform: Whether to log transform observations
1361+
:param log1p_transform: Whether to log transform observations
12791362
before estimating the distribution with boxplot. Model estimates are adjusted accordingly.
12801363
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
12811364
:param save: Path+file name stem to save plots to.
@@ -1284,6 +1367,7 @@ def plot_gene_fits(
12841367
:param ncols: Number of columns in plot grid if multiple genes are plotted.
12851368
:param row_gap: Vertical gap between panel rows relative to panel height.
12861369
:param col_gap: Horizontal gap between panel columns relative to panel width.
1370+
:param xtick_rotation: Angle to rotate x-ticks by.
12871371
:param return_axs: Whether to return axis objects.
12881372
12891373
:return: Matplotlib axis objects.
@@ -1310,51 +1394,137 @@ def plot_gene_fits(
13101394
)
13111395

13121396
axs = []
1397+
summaries = self._assemble_gene_fits(
1398+
gene_names=gene_names,
1399+
covariate_x=covariate_x,
1400+
covariate_hue=covariate_hue,
1401+
log1p_transform=log1p_transform
1402+
)
13131403
for i, g in enumerate(gene_names):
1314-
# Prepare data boxplot.
1315-
y = self.model_estim.X[:, g]
1316-
if size_factor_normalize:
1317-
pass
1318-
if log_transform:
1319-
pass
1320-
1321-
summary_fit = pd.DataFrame({
1322-
"y": y
1323-
})
1324-
if covariate_x is not None:
1325-
if covariate_x in self.model_estim.design_loc.columns:
1326-
summary_fit["x"] = self.model_estim.design_loc[covariate_x]
1327-
else:
1328-
raise ValueError("covariate_x=%s not found in location model" % covariate_x)
1329-
1330-
if covariate_hue is not None:
1331-
if covariate_hue in self.model_estim.design_loc.columns:
1332-
summary_fit["hue"] = self.model_estim.design_loc[covariate_hue]
1333-
else:
1334-
raise ValueError("covariate_x=%s not found in location model" % covariate_x)
1335-
1336-
# Prepare model fit plot.
1337-
if self.noise_model == "nb":
1338-
pass
1339-
elif self.noise_model == "norm":
1340-
pass
1341-
else:
1342-
raise ValueError("noise model %s not yet supported for plot_gene_fits" % self.noise_model)
1343-
13441404
ax = plt.subplot(gs[i])
13451405
axs.append(ax)
13461406

1407+
if log1p_transform:
1408+
ylabel = "log1p expression"
1409+
else:
1410+
ylabel = "expression"
1411+
13471412
sns.boxplot(
13481413
x="x",
13491414
y="y",
13501415
hue="hue",
1351-
data=summary_fit,
1416+
data=summaries[i],
13521417
ax=ax
13531418
)
13541419

1355-
ax.set(xlabel="covariate", ylabel="expression")
1420+
ax.set(xlabel="covariate", ylabel=ylabel)
1421+
ax.set_title(g)
1422+
1423+
plt.xticks(rotation=xtick_rotation)
1424+
1425+
# Save, show and return figure.
1426+
if save is not None:
1427+
plt.savefig(save + suffix)
1428+
1429+
if show:
1430+
plt.show()
1431+
1432+
plt.close(fig)
1433+
plt.ion()
1434+
1435+
if return_axs:
1436+
return axs
1437+
else:
1438+
return
1439+
1440+
def plot_gene_fits_violins(
1441+
self,
1442+
gene_names: Tuple,
1443+
covariate_x: str = None,
1444+
log1p_transform: bool = False,
1445+
show: bool = True,
1446+
save: Union[str, None] = None,
1447+
suffix: str = "_genes_violin.png",
1448+
ncols=3,
1449+
row_gap=0.3,
1450+
col_gap=0.25,
1451+
xtick_rotation=0,
1452+
return_axs: bool = False,
1453+
**kwargs
1454+
):
1455+
"""
1456+
Plot gene-wise model fits and observed distribution by covariates as violins.
1457+
1458+
Use this to inspect fitting performance on individual genes.
1459+
1460+
:param gene_names: Genes to generate plots for.
1461+
:param covariate_x: Covariate in location model to partition x-axis by.
1462+
:param log1p_transform: Whether to log transform observations
1463+
before estimating the distribution with boxplot. Model estimates are adjusted accordingly.
1464+
:param show: Whether (if save is not None) and where (save indicates dir and file stem) to display plot.
1465+
:param save: Path+file name stem to save plots to.
1466+
File will be save+suffix. Does not save if save is None.
1467+
:param suffix: Suffix for file name to save plot to. Also use this to set the file type.
1468+
:param ncols: Number of columns in plot grid if multiple genes are plotted.
1469+
:param row_gap: Vertical gap between panel rows relative to panel height.
1470+
:param col_gap: Horizontal gap between panel columns relative to panel width.
1471+
:param xtick_rotation: Angle to rotate x-ticks by.
1472+
:param return_axs: Whether to return axis objects.
1473+
1474+
:return: Matplotlib axis objects.
1475+
"""
1476+
import seaborn as sns
1477+
import matplotlib.pyplot as plt
1478+
from matplotlib import gridspec
1479+
from matplotlib import rcParams
1480+
1481+
plt.ioff()
1482+
nrows = len(gene_names) // ncols + int((len(gene_names) % ncols) > 0)
1483+
1484+
gs = gridspec.GridSpec(
1485+
nrows=nrows,
1486+
ncols=ncols,
1487+
hspace=row_gap,
1488+
wspace=col_gap
1489+
)
1490+
fig = plt.figure(
1491+
figsize=(
1492+
ncols * rcParams['figure.figsize'][0], # width in inches
1493+
nrows * rcParams['figure.figsize'][1] * (1 + row_gap) # height in inches
1494+
)
1495+
)
1496+
1497+
axs = []
1498+
summaries = self._assemble_gene_fits(
1499+
gene_names=gene_names,
1500+
covariate_x=covariate_x,
1501+
covariate_hue=None,
1502+
log1p_transform=log1p_transform
1503+
)
1504+
for i, g in enumerate(gene_names):
1505+
ax = plt.subplot(gs[i])
1506+
axs.append(ax)
1507+
1508+
if log1p_transform:
1509+
ylabel = "log1p expression"
1510+
else:
1511+
ylabel = "expression"
1512+
1513+
sns.violinplot(
1514+
x="x",
1515+
y="y",
1516+
hue="data",
1517+
split=True,
1518+
data=summaries[i],
1519+
ax=ax,
1520+
**kwargs
1521+
)
1522+
1523+
ax.set(xlabel="covariate", ylabel=ylabel)
13561524
ax.set_title(g)
13571525

1526+
plt.xticks(rotation=xtick_rotation)
1527+
13581528
# Save, show and return figure.
13591529
if save is not None:
13601530
plt.savefig(save + suffix)
@@ -1376,9 +1546,17 @@ class DifferentialExpressionTestTT(_DifferentialExpressionTestSingle):
13761546
Single t-test test per gene.
13771547
"""
13781548

1379-
def __init__(self, data, grouping, gene_names, is_logged):
1549+
def __init__(
1550+
self,
1551+
data,
1552+
sample_description: pd.DataFrame,
1553+
grouping,
1554+
gene_names,
1555+
is_logged
1556+
):
13801557
super().__init__()
13811558
self._X = data
1559+
self.sample_description = sample_description
13821560
self.grouping = grouping
13831561
self._gene_names = np.asarray(gene_names)
13841562

@@ -1484,9 +1662,17 @@ class DifferentialExpressionTestRank(_DifferentialExpressionTestSingle):
14841662
Single rank test per gene (Mann-Whitney U test).
14851663
"""
14861664

1487-
def __init__(self, data, grouping, gene_names, is_logged):
1665+
def __init__(
1666+
self,
1667+
data,
1668+
sample_description: pd.DataFrame,
1669+
grouping,
1670+
gene_names,
1671+
is_logged
1672+
):
14881673
super().__init__()
14891674
self._X = data
1675+
self.sample_description = sample_description
14901676
self.grouping = grouping
14911677
self._gene_names = np.asarray(gene_names)
14921678

diffxpy/testing/tests.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,8 @@ def wald(
612612
de_test = DifferentialExpressionTestWald(
613613
model_estim=model,
614614
col_indices=col_indices,
615-
noise_model=noise_model
615+
noise_model=noise_model,
616+
sample_description=sample_description
616617
)
617618

618619
return de_test
@@ -650,6 +651,7 @@ def t_test(
650651

651652
de_test = DifferentialExpressionTestTT(
652653
data=X.astype(dtype),
654+
sample_description=sample_description,
653655
grouping=grouping,
654656
gene_names=gene_names,
655657
is_logged=is_logged
@@ -690,6 +692,7 @@ def rank_test(
690692

691693
de_test = DifferentialExpressionTestRank(
692694
data=X.astype(dtype),
695+
sample_description=sample_description,
693696
grouping=grouping,
694697
gene_names=gene_names,
695698
is_logged=is_logged

0 commit comments

Comments
 (0)