This is the repository associated with the paper Multivariate Latent Recalibration for Conditional Normalizing Flows.
Abstract: Reliably characterizing the full conditional distribution of a multivariate response variable given a set of covariates is crucial for trustworthy decision-making. However, misspecified or miscalibrated multivariate models may yield a poor approximation of the joint distribution of the response variables, leading to unreliable predictions and suboptimal decisions. Furthermore, standard recalibration methods are primarily limited to univariate settings, while conformal prediction techniques, despite generating multivariate prediction regions with coverage guarantees, do not provide a full probability density function. We address this gap by first introducing a novel notion of latent calibration, which assesses probabilistic calibration in the latent space of a conditional normalizing flow. Second, we propose latent recalibration (LR), a novel post-hoc model recalibration method that learns a transformation of the latent space with finite-sample bounds on latent calibration. Unlike existing methods, LR produces a recalibrated distribution with an explicit multivariate density function while remaining computationally efficient. Extensive experiments on both tabular and image datasets show that LR consistently improves latent calibration error and the negative log-likelihood of the recalibrated models.
For convenience, all datasets except MEPS and AFHQ are directly provided in this repository. MEPS requires accepting data usage agreements [1] and [2] (see step 3 of the installation). Licenses are detailed below.
The following snippet shows how to train a model, recalibrate it, and evaluate it on one test batch with the code in this repository.
import torch
from moc.configs.config import get_config
from moc.recalibration import LatentRecalibrator
from moc.datamodules.real_datamodule import RealDataModule
from moc.metrics.distribution_metrics import nll
from moc.metrics.calibration import latent_calibration_error
from moc.models.mqf2.lightning_module import MQF2LightningModule
from moc.models.trainers.lightning_trainer import get_lightning_trainer
from moc.utils.run_config import RunConfig
# Data loading and model training
config = get_config()
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
rc = RunConfig(config, 'mulan', 'sf2')
datamodule = RealDataModule(rc)
p, q = datamodule.input_dim, datamodule.output_dim
model = MQF2LightningModule(p, q)
trainer = get_lightning_trainer(rc)
trainer.fit(model, datamodule)
model.eval()
# Recalibration
recalibrated_model = LatentRecalibrator(model, datamodule)
# Evaluation on one test batch
test_batch = next(iter(datamodule.test_dataloader()))
x, y = test_batch
x, y = x.to(model.device), y.to(model.device)
dist = recalibrated_model.predict(x)
with torch.no_grad():
    nll_value = nll(dist, y).mean()
    latent_calibration = latent_calibration_error(dist, y)
print(nll_value)
print(latent_calibration)
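To give some intuition for what a latent calibration check measures, here is a minimal, self-contained sketch that does not use this repository's code. It assumes a 2-dimensional standard normal latent space, for which the squared latent norm follows a chi-squared distribution with 2 degrees of freedom, so the probability integral transform (PIT) of the norm should be uniform for a well-specified model. The function names (`latent_pit`, `calibration_error`, `fit_recalibration_map`, `sample_latents`) are hypothetical, and the empirical-CDF map is only a simplified stand-in for LR, which according to the abstract learns a transformation of the latent space.

```python
import bisect
import math
import random


def latent_pit(z):
    """PIT of the latent norm, assuming a 2-D standard normal latent.

    For Z ~ N(0, I_2), ||Z||^2 is chi-squared with 2 degrees of freedom,
    whose CDF is 1 - exp(-r^2 / 2).
    """
    r2 = sum(v * v for v in z)
    return 1.0 - math.exp(-r2 / 2.0)


def calibration_error(pits):
    """Mean absolute gap between the empirical CDF of the PITs and the uniform CDF."""
    ordered = sorted(pits)
    n = len(ordered)
    return sum(abs((i + 1) / n - u) for i, u in enumerate(ordered)) / n


def fit_recalibration_map(calibration_pits):
    """Return the empirical CDF of the calibration PITs as a callable (illustrative only)."""
    ordered = sorted(calibration_pits)

    def recalibrate(u):
        return bisect.bisect_right(ordered, u) / len(ordered)

    return recalibrate


def sample_latents(rng, n, scale):
    """Draw n 2-D Gaussian latents; scale != 1 mimics a misspecified model."""
    return [(scale * rng.gauss(0, 1), scale * rng.gauss(0, 1)) for _ in range(n)]


rng = random.Random(0)
well_specified = [latent_pit(z) for z in sample_latents(rng, 2000, 1.0)]
calibration_pits = [latent_pit(z) for z in sample_latents(rng, 2000, 1.5)]
test_pits = [latent_pit(z) for z in sample_latents(rng, 2000, 1.5)]

recal_map = fit_recalibration_map(calibration_pits)
recalibrated_pits = [recal_map(u) for u in test_pits]

print("well specified:", calibration_error(well_specified))
print("misspecified:  ", calibration_error(test_pits))
print("recalibrated:  ", calibration_error(recalibrated_pits))
```

In this toy setting, the misspecified latents yield PIT values far from uniform, and mapping them through the empirical CDF learned on a held-out calibration set brings the error back down.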
- Python >= 3.9
- (Optional) Create and activate a Python virtual environment:
python -m venv venv
source venv/bin/activate
- Install the package with its dependencies:
pip install .
To install the exact dependency versions used for reproducibility, use instead:
pip install -r requirements.txt
- (Optional) For running experiments on MEPS datasets, read the data usage agreements and download the datasets according to these instructions, summarized below:
git clone https://github.com/yromano/cqr
cd cqr/get_meps_data/
Rscript download_data.R
python main_clean_and_save_to_csv.py
cd ../../
for id in 19 20 21; do mv "cqr/get_meps_data/meps_${id}_reg.csv" "data/feldman/meps_${id}.csv"; done
rm -rf cqr
- (Optional) For running experiments on AFHQ, follow the instructions from clovaai/stargan-v2, summarized below:
URL=https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0
ZIP_FILE=./data/afhq.zip
wget -N $URL -O $ZIP_FILE
unzip $ZIP_FILE -d ./data
rm $ZIP_FILE
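After the optional download steps, it can be handy to confirm that the files ended up where the commands above put them. The sketch below is a hypothetical helper (`missing_optional_datasets` is not part of this repository). The MEPS paths are taken from the `mv` command above; the `data/afhq` directory is an assumption about how the archive unpacks.

```python
from pathlib import Path


def missing_optional_datasets(root="."):
    """Return the optional dataset paths that are not present under `root`."""
    root = Path(root)
    # Paths written by the MEPS instructions (data/feldman/meps_<id>.csv).
    expected = [root / "data" / "feldman" / f"meps_{i}.csv" for i in (19, 20, 21)]
    # Assumption: afhq.zip unpacks into a top-level "afhq" directory under data/.
    expected.append(root / "data" / "afhq")
    return [str(path) for path in expected if not path.exists()]


for path in missing_optional_datasets():
    print("missing:", path)
```

Run it from the repository root; an empty output means all optional datasets were found.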
To compute the main results of the paper:
python run.py name="lr" datasets="real" repeat_tuning=10 tuning_type="lr_mqf2" device="cuda" default_batch_size=1024
Other results can be computed using:
# Masked auto-regressive flow
python run.py name="lr" datasets="real" repeat_tuning=10 tuning_type="lr_arflow" device="cuda" default_batch_size=512
# Misspecified model
python run.py name="lr" datasets="real" repeat_tuning=10 tuning_type="lr_mqf2" device="cuda" default_batch_size=1024
# TarFlow with noisy targets
python run.py name="lr_tarflow_noisy" datasets="afhq" repeat_tuning=20 tuning_type="lr_tarflow" default_batch_size=256 device="cuda" only_cheap_metrics=True afhq_noise=0.07
# TarFlow without noise on the targets
python run.py name="lr_tarflow_no_noise" datasets="afhq" tuning_type="lr_tarflow" default_batch_size=256 device="cuda" only_cheap_metrics=True afhq_noise=0
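When launching several of the runs above from a script, the `key=value` overrides can be assembled programmatically. This is a minimal sketch: `build_run_command` is a hypothetical helper, not part of this repository, and it only assumes that `run.py` accepts overrides in the `key=value` form shown in the commands above.

```python
import shlex
import sys


def build_run_command(**overrides):
    """Build the argument list for `python run.py key=value ...`."""
    return [sys.executable, "run.py"] + [f"{key}={value}" for key, value in overrides.items()]


command = build_run_command(
    name="lr",
    datasets="real",
    repeat_tuning=10,
    tuning_type="lr_mqf2",
    device="cuda",
    default_batch_size=1024,
)
# Print the command; to execute it, pass the list to subprocess.run(command, check=True).
print(shlex.join(command))
```

Passing the arguments as a list avoids shell quoting issues, since each override is forwarded to `run.py` as a single argument.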
Then, plots and tables can be generated using analysis_lr.ipynb.
Other figures can be generated using visualizations_lr.ipynb.
This project utilizes data from various sources, each with its own licensing terms. The table below details the origin and license for each dataset.
Dataset Group / Name | Data Source(s) | License(s) |
---|---|---|
Mulan Datasets | Main repository: tsoumakas/mulan. Original sources: README | GPL |
house, bio, blog_data | bio, house, blog_data | CC BY 4.0 |
meps_19, meps_20, meps_21 | Data from the Agency for Healthcare Research and Quality (MEPS) | Requires accepting Data Usage Agreements 1 and 2 |
ansur2 | Data from the Open Design Lab | Public Domain |
births1, births2 | Data from the Centers for Disease Control and Prevention. Code for data processing: lorismichel/drf | Data: Public Domain. Code: GPL 3.0 |
wage | Data from American Community Survey Data. Code for data processing: lorismichel/drf | Data: Public Domain. Code: GPL 3.0 |
air | Data from U.S. Environmental Protection Agency. Code for data processing: lorismichel/drf | Data: Public Domain. Code: GPL 3.0 |
taxi | Data: NYC Open Data. Code for data processing: Zhendong-Wang/Probabilistic-Conformal-Prediction | Data: Public Domain. Code: MIT |
calcofi | Data from the California Cooperative Oceanic Fisheries Investigations | CC BY 4.0 |
households | Data: U.S. Bureau of Labor Statistics (Consumer Expenditure Survey PUMD). Code for data processing: aschnuecker/Superlevel-sets | Data: Public Domain. Code: GPL 3.0 |
afhq | Data from clovaai/stargan-v2 | CC BY-NC 4.0 |