Add metric avg batch r2 #36

Merged
merged 4 commits on Mar 12, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,8 @@

* Added `methods/cycombine_nocontrol` (PR #35).

* Added `metrics/average_batch_r2` and a helper function (PR #36).

## MAJOR CHANGES

* Updated file schema (PR #18):
92 changes: 92 additions & 0 deletions src/metrics/average_batch_r2/config.vsh.yaml
@@ -0,0 +1,92 @@

__merge__: ../../api/comp_metric.yaml

name: average_batch_r2

# Metadata for your component
info:
metrics:
# A unique identifier for your metric (required).
# Can contain only lowercase letters or underscores.
- name: average_batch_r2
# A relatively short label, used when rendering visualisations (required)
label: Average Batch R-squared ($\overline{R^2_B}$)
# A one sentence summary of how this metric works (required). Used when
# rendering summary tables.
summary: "The average batch R-squared quantifies, on average, how strongly the batch variable B explains the variance in the data."
# A multi-line description of how this component works (required). Used
# when rendering reference documentation.
description: |
First, a simple linear model (`sklearn.linear_model.LinearRegression`) is fitted for each paired sample and marker (and cell type) to determine the fraction of variance ($R^2$) explained by the batch covariate $B$.
The average batch R-squared is then computed as the mean of these $R^2$ values across all paired samples and markers (and cell types).
As a result, $\overline{R^2_B}$ quantifies how much of the total variability in the data is driven by batch effects; consequently, lower values are desirable.

$\overline{R^2_B} = \frac{1}{N \cdot C \cdot M}\sum_{\substack{(x_{\mathrm{int}},\,x_{\mathrm{val}})\\ \text{paired samples}}}^{N} \sum_{j=1}^{C} \sum_{i=1}^{M} R^2\!\bigl(\mathrm{marker}_i \mid B\bigr)$

Where:
- $N$ is the number of paired samples, where $x_{\mathrm{int}}$ is the replicate that has been batch-corrected and $x_{\mathrm{val}}$ is the replicate used for validation. Paired samples belong to different batches.
- $C$ is the number of cell types
- $M$ is the number of markers
- $B$ is the batch covariate

$\overline{R^2_B}_{global}$ is a variation of the above metric, where the average is computed across paired samples and markers only, without taking cell types into account.

$\overline{R^2_B}_{global} = \frac{1}{N \cdot M}\sum_{\substack{(x_{\mathrm{int}},\,x_{\mathrm{val}})\\ \text{paired samples}}}^{N} \sum_{i=1}^{M} R^2\!\bigl(\mathrm{marker}_i \mid B\bigr)$

A higher value of $\overline{R^2_B}$ indicates that the batch variable explains more of the variance in the data, i.e. a higher level of batch effects.

Good performance on $\overline{R^2_B}_{global}$ but not on $\overline{R^2_B}$ might indicate that the batch effect correction is not addressing cell-type-specific batch effects.

references:
bibtex:
- |
@book{draper1998applied,
title={Applied regression analysis},
author={Draper, Norman R and Smith, Harry},
publisher={John Wiley \& Sons}
}
links:
# URL to the documentation for this metric (required).
documentation: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html
# URL to the code repository for this metric (required).
repository: https://github.com/openproblems-bio/task_cyto_batch_integration/tree/add_metric_avg_batch_r2/src/metrics/average_batch_r2
# The minimum possible value for this metric (required)
min: -0.001
# The maximum possible value for this metric (required)
max: 1
# Whether a higher value represents a 'better' solution (required)
maximize: false

# Component-specific parameters (optional)
# arguments:
# - name: "--n_neighbors"
# type: "integer"
# default: 5
# description: Number of neighbors to use.

# Resources required to run the component
resources:
# The script of your component (required)
- type: python_script
path: script.py
- path: helper.py
- path: /src/utils/helper_functions.py


engines:
# Specifications for the Docker image for this component.
- type: docker
image: openproblems/base_python:1.0.0
# Add custom dependencies here (optional). For more information, see
# https://viash.io/reference/config/engines/docker/#setup .
# setup:
# - type: python
# packages: numpy<2

runners:
# This platform allows running the component natively
- type: executable
# Allows turning the component into a Nextflow module / pipeline.
- type: nextflow
directives:
label: [midtime,midmem,midcpu]
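As a side note on the per-marker term in the formulas above: for a single marker $x$ and a binary batch covariate $B$ (two batches, as in one pair of replicates), the ordinary least squares fit predicts the per-batch means, so the per-marker $R^2$ reduces to the between-batch share of the total variance. A sketch of that identity, shown here only for clarity and not part of the config:

$R^2(\mathrm{marker} \mid B) = 1 - \frac{\sum_{b}\sum_{k \in b} (x_k - \bar{x}_b)^2}{\sum_{k} (x_k - \bar{x})^2} = \frac{\sum_{b} n_b\,(\bar{x}_b - \bar{x})^2}{\sum_{k} (x_k - \bar{x})^2}$

where $\bar{x}_b$ is the mean of the marker in batch $b$, $n_b$ the number of cells in batch $b$, and $\bar{x}$ the overall mean.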
88 changes: 88 additions & 0 deletions src/metrics/average_batch_r2/helper.py
@@ -0,0 +1,88 @@
import numpy as np
import pandas as pd
import anndata as ad


def concat_paired_samples(adata_i: ad.AnnData, adata_v: ad.AnnData) -> pd.DataFrame:
'''
Concatenate the integrated and validation datasets into a single dataframe. Columns are markers, except for the last column ('batch').

Inputs:
adata_i: AnnData object, batch-integrated dataset
adata_v: AnnData object, validation dataset

Returns:
df: pd.DataFrame with the integrated and validation datasets concatenated
'''
assert adata_i.shape[1] == adata_v.shape[1], "The number of markers in the integrated and validation datasets do not match"

df_int = adata_i.to_df(layer='integrated')
df_int['batch'] = adata_i.obs['batch']

df_val = adata_v.to_df(layer='preprocessed')
df_val['batch'] = adata_v.obs['batch']

df = pd.concat([df_int, df_val], axis=0)

return df

def fit_r2(df: pd.DataFrame, markername: str) -> float:
'''
Fit a linear regression model to the marker expression data and calculate the R^2 value.

Inputs:
df: pd.DataFrame from `concat_paired_samples`
markername: str, name of the marker to fit the R^2 value for

Outputs:
r2: float, the R^2 value, i.e. the proportion of variance in the marker expression explained by the batch variable.
'''
from sklearn.linear_model import LinearRegression

data = df[[markername, 'batch']]
X = pd.get_dummies(data['batch'], drop_first=True)
Y = data[markername]
model = LinearRegression().fit(X,Y)
r2 = model.score(X,Y)

# #Uncomment for debugging
# import seaborn as sns
# import matplotlib.pyplot as plt
# sns.scatterplot(x='batch', y=markername, data=data)
# plt.show()
# print(markername,"r2 =",r2)

return r2


def batch_r2(adata_i: ad.AnnData, adata_v: ad.AnnData) -> tuple[list, list]:
'''
Calculate the batch R^2 metric given 2 paired samples (integrated and validation).
For each marker, the function calculates the R^2 value between the marker expression and batch covariate.
Note: since adata_i and adata_v are paired samples, they have to come from the same donor.

Inputs:
adata_i: AnnData object, batch-integrated dataset
adata_v: AnnData object, validation dataset

Outputs:
markers_r2: list of floats, R^2 values for each marker
markerlist: list of str, marker names
'''

assert np.array_equal(np.unique(adata_i.obs['donor']), np.unique(adata_v.obs['donor'])), "The donors in the integrated and validation datasets do not match"

df = concat_paired_samples(adata_i, adata_v)

markers_r2 = []
markerlist = []
for marker in df.columns:
if marker != 'batch':
r2 = fit_r2(df, marker)
markers_r2.append(r2)
markerlist.append(marker)


return markers_r2, markerlist
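A minimal usage sketch for the helper above, built on two toy paired replicates from the same donor. The marker names, batch labels and random values are illustrative only, and `helper` is assumed to be importable from the working directory:

import anndata as ad
import numpy as np
import pandas as pd

from helper import batch_r2

rng = np.random.default_rng(0)
n_cells, markers = 50, ["CD3", "CD19"]

def toy_adata(layer: str, batch: str) -> ad.AnnData:
    # Build a small AnnData with the expression matrix stored in the given layer
    adata = ad.AnnData(
        X=rng.normal(size=(n_cells, len(markers))),
        obs=pd.DataFrame({"batch": [batch] * n_cells, "donor": ["donor_a"] * n_cells}),
        var=pd.DataFrame(index=markers),
    )
    adata.layers[layer] = adata.X.copy()
    return adata

adata_i = toy_adata("integrated", "1")    # batch-integrated replicate
adata_v = toy_adata("preprocessed", "2")  # validation replicate, other batch, same donor

r2_values, marker_names = batch_r2(adata_i, adata_v)
print(dict(zip(marker_names, r2_values)))  # R^2 per marker; near zero for random data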
96 changes: 96 additions & 0 deletions src/metrics/average_batch_r2/script.py
@@ -0,0 +1,96 @@
import anndata as ad
import sys
import numpy as np

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
'input_validation': 'resources_test/.../validation.h5ad',
'input_unintegrated': 'resources_test/.../unintegrated.h5ad',
'input_integrated': 'resources_test/.../integrated.h5ad',
'output': 'output.h5ad'
}
meta = {
'name': 'average_batch_r2'
}
## VIASH END

sys.path.append(meta["resources_dir"])
from helper import concat_paired_samples, fit_r2, batch_r2
from helper_functions import get_obs_var_for_integrated, subset_nocontrols, subset_markers_tocorrect, remove_unlabelled

print('Reading input files', flush=True)
input_validation = ad.read_h5ad(par['input_validation'])
input_unintegrated = ad.read_h5ad(par['input_unintegrated'])
input_integrated = ad.read_h5ad(par['input_integrated'])

print('Formatting input files', flush=True)
# Format integrated data
input_integrated = get_obs_var_for_integrated(input_integrated,input_validation,input_unintegrated)
input_integrated = subset_markers_tocorrect(input_integrated)
input_integrated = subset_nocontrols(input_integrated)
#Format validation data
input_validation = subset_markers_tocorrect(input_validation)

#### TEMPORARY SOLUTION: change get_obs_var_for_integrated to return a different batch in case of perfect integration
## add 3 to the batch number, otherwise the batch number is the same for the integrated and validation data
if input_integrated.uns['method_id'] == 'perfect_integration':
input_integrated.obs['batch'] = input_integrated.obs['batch'] + 3
######################################################################################################################

print('Computing average_batch_r2 global', flush=True)

donor_list = input_validation.obs['donor'].unique()

r2_values = []
for donor in donor_list:
integrated_view = input_integrated[input_integrated.obs['donor'] == donor]
validation_view = input_validation[input_validation.obs['donor'] == donor]

if integrated_view.shape[0] < 10 or validation_view.shape[0] < 10:  # Skip R^2 calculation if there are fewer than 10 cells
print(f"Warning: R^2 not computed for donor {donor}. Too few cells were present: {integrated_view.shape[0]} for integrated and {validation_view.shape[0]} for validation")
continue

r2_list,_ = batch_r2(integrated_view, validation_view)
r2_values = [*r2_values, *r2_list]

average_batch_r2_global = np.mean(r2_values)


print('Computing average_batch_r2 cell-type specific', flush=True)

r2_values = []
for donor in donor_list:
integrated_view = input_integrated[input_integrated.obs['donor'] == donor]
integrated_view = remove_unlabelled(integrated_view)
validation_view = input_validation[input_validation.obs['donor'] == donor]
validation_view = remove_unlabelled(validation_view)

ct_list = validation_view.obs['cell_type'].unique()

for ct in ct_list:
integrated_view_ct = integrated_view[integrated_view.obs['cell_type'] == ct]
validation_view_ct = validation_view[validation_view.obs['cell_type'] == ct]
if integrated_view_ct.shape[0] < 10 or validation_view_ct.shape[0] < 10:  # Skip R^2 calculation if there are fewer than 10 cells
print(f"Warning: R^2 not computed for donor {donor} cell type {ct}. Too few cells were present: {integrated_view_ct.shape[0]} for integrated and {validation_view_ct.shape[0]} for validation")
continue

r2_list,_ = batch_r2(integrated_view_ct, validation_view_ct)
r2_values = [*r2_values, *r2_list]

average_batch_r2_ct = np.mean(r2_values)

uns_metric_ids = [ 'average_batch_r2_global', 'average_batch_r2_ct' ]
uns_metric_values = [ average_batch_r2_global, average_batch_r2_ct ]

print("Write output AnnData to file", flush=True)
output = ad.AnnData(
uns={
'dataset_id': input_integrated.uns['dataset_id'],
'method_id': input_integrated.uns['method_id'],
'metric_ids': uns_metric_ids,
'metric_values': uns_metric_values
}
)
output.write_h5ad(par['output'], compression='gzip')
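A quick sketch of how the resulting scores could be inspected after the component has run, assuming the script wrote output.h5ad to the working directory (the printed numbers are of course dataset-dependent):

import anndata as ad

out = ad.read_h5ad("output.h5ad")
# uns stores the two scores as parallel lists:
# metric_ids    -> ['average_batch_r2_global', 'average_batch_r2_ct']
# metric_values -> the corresponding averages
for metric_id, value in zip(out.uns["metric_ids"], out.uns["metric_values"]):
    print(f"{metric_id}: {value:.4f}")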
19 changes: 18 additions & 1 deletion src/utils/helper_functions.py
@@ -110,4 +110,21 @@ def subset_markers_tocorrect(adata)-> ad.AnnData:
adata = adata[:,adata.var['to_correct']].copy()

return adata


def remove_unlabelled(adata) -> ad.AnnData:
'''
Subsets the AnnData object to remove all cells whose cell type is not labelled.
This is determined by the column 'cell_type' in adata.obs.
Particularly useful when dealing with cell-type-specific metrics.

Inputs:
adata: AnnData object

Outputs:
adata: AnnData object with only the labeled cells
'''

adata = adata[adata.obs['cell_type'].str.lower() != 'unlabelled'].copy()
adata = adata[adata.obs['cell_type'].str.lower() != 'unlabeled'].copy()

return adata
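A tiny sketch of what the new helper does on a toy object, assuming remove_unlabelled from the module above is in scope (the cell type labels are made up for illustration):

import anndata as ad
import numpy as np
import pandas as pd

adata = ad.AnnData(
    X=np.zeros((4, 1)),
    obs=pd.DataFrame({"cell_type": ["T cell", "Unlabelled", "B cell", "unlabeled"]}),
)
adata = remove_unlabelled(adata)
print(adata.obs["cell_type"].tolist())  # ['T cell', 'B cell']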