Post-processing fit() update #146

Merged · 2 commits · Mar 17, 2025
13 changes: 8 additions & 5 deletions auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -26,6 +26,7 @@
 from pathlib import Path

 from torch import nn
+from torch.utils.data import DataLoader

 from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
@@ -84,15 +85,17 @@
 # We can now wrap the model in a MCBatchNorm to add stochasticity to the
 # predictions. We specify that the BatchNorm layers are to be converted to
 # MCBatchNorm layers, and that we want to use 8 stochastic estimators.
-# The amount of stochasticity is controlled by the ``mc_batch_size`` argument.
-# The larger the ``mc_batch_size``, the more stochastic the predictions will be.
-# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here
+# The amount of stochasticity is controlled by the ``batch_size`` parameter
+# of the DataLoader used to train the model.
+# The larger the ``batch_size``, the more stochastic the predictions will be.
+# The authors suggest 32 as a good value for ``batch_size`` but we use 16 here
 # to highlight the effect of stochasticity on the predictions.

 routine.model = MCBatchNorm(
-    routine.model, num_estimators=8, convert=True, mc_batch_size=16
+    routine.model, num_estimators=8, convert=True
 )
-routine.model.fit(datamodule.train)
+mc_batch_norm_dl = DataLoader(datamodule.train, batch_size=16, shuffle=True)
+routine.model.fit(dataloader=mc_batch_norm_dl)
 routine = routine.eval()  # To avoid prints

 # %%
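Reviewer note (not part of the diff): a quick sanity check of the new fit path. This sketch reuses ``routine.model`` and ``mc_batch_norm_dl`` from above and assumes the TorchUncertainty convention that an ensemble wrapper stacks its ``num_estimators`` predictions along the batch dimension in eval mode:

import torch

x, _ = next(iter(mc_batch_norm_dl))
routine.model.eval()
with torch.no_grad():
    probs = routine.model(x).softmax(-1)  # assumed shape: (8 * batch, classes)
per_estimator = probs.reshape(8, x.size(0), -1)  # 8 = num_estimators above
print(per_estimator.std(dim=0).mean())  # > 0 means the estimators disagree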
3 changes: 2 additions & 1 deletion auto_tutorials_source/tutorial_scaler.py
@@ -84,6 +84,7 @@
     dataset, [1000, 1000, len(dataset) - 2000]
 )
 test_dataloader = DataLoader(test_dataset, batch_size=32)
+calibration_dataloader = DataLoader(cal_dataset, batch_size=32)

 # Initialize the ECE
 ece = CalibrationError(task="multiclass", num_classes=100)
@@ -114,7 +115,7 @@

 # Fit the scaler on the calibration dataset
 scaled_model = TemperatureScaler(model=model)
-scaled_model.fit(calibration_set=cal_dataset)
+scaled_model.fit(dataloader=calibration_dataloader)

 # %%
 # 6. Iterating Again to Compute the Improved ECE
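Reviewer note (not part of the diff): migrating callers only need to wrap their calibration split in a DataLoader, since batching now belongs to the caller. A minimal before/after sketch using the tutorial's names (``model`` and ``cal_dataset`` are assumed from above):

from torch.utils.data import DataLoader

from torch_uncertainty.post_processing import TemperatureScaler

scaled_model = TemperatureScaler(model=model)
# Old API (removed by this PR): scaled_model.fit(calibration_set=cal_dataset, batch_size=32)
# New API: the caller owns the DataLoader and therefore the batch size.
calibration_dataloader = DataLoader(cal_dataset, batch_size=32)
scaled_model.fit(dataloader=calibration_dataloader)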
6 changes: 3 additions & 3 deletions tests/post_processing/test_laplace.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torch.utils.data import TensorDataset
+from torch.utils.data import DataLoader, TensorDataset

 from tests._dummies.model import dummy_model
 from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing
@@ -20,12 +20,12 @@ class TestLaplace:
"""Testing the LaplaceApprox class."""

def test_training(self):
ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10))
dl = DataLoader(TensorDataset(torch.randn(16, 1), torch.randn(16, 10)), batch_size=5)
la = LaplaceApprox(
task="classification",
model=dummy_model(1, 10),
)
la.fit(ds)
la.fit(dl)
la(torch.randn(1, 1))
la = LaplaceApprox(task="classification")
la.set_model(dummy_model(1, 10))
16 changes: 8 additions & 8 deletions tests/post_processing/test_mc_batch_norm.py
@@ -4,6 +4,7 @@
 import torch
 import torchvision.transforms as T
 from torch import nn
+from torch.utils.data import DataLoader

 from tests._dummies.dataset import DummyClassificationDataset
 from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
@@ -17,14 +18,13 @@ class TestMCBatchNorm:
     def test_main(self):
         """Test initialization."""
         mc_model = lenet(1, 1, norm=partial(MCBatchNorm2d, num_estimators=2))
-        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False, mc_batch_size=1)
+        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False)

         model = lenet(1, 1, norm=nn.BatchNorm2d)
         stoch_model = MCBatchNorm(
             nn.Sequential(model),
             num_estimators=2,
             convert=True,
-            mc_batch_size=1,
         )
         dataset = DummyClassificationDataset(
             "./",
@@ -34,28 +34,27 @@ def test_main(self):
             num_images=2,
             transform=T.ToTensor(),
         )
-        stoch_model.fit(dataset=dataset)
+        dl = DataLoader(dataset, batch_size=1, shuffle=True)
+        stoch_model.fit(dataloader=dl)
         stoch_model.train()
         stoch_model(torch.randn(1, 1, 20, 20))
         stoch_model.eval()
         stoch_model(torch.randn(1, 1, 20, 20))

-        stoch_model = MCBatchNorm(num_estimators=2, convert=False, mc_batch_size=1)
+        stoch_model = MCBatchNorm(num_estimators=2, convert=False)
         stoch_model.set_model(mc_model)

     def test_errors(self):
         """Test errors."""
         model = nn.Identity()
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=0, convert=True)
-        with pytest.raises(ValueError, match="mc_batch_size must be a positive integer"):
-            MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1)
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=1, convert=False)
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=1, convert=True)
         model = lenet(1, 1, norm=nn.BatchNorm2d)
-        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True, mc_batch_size=1)
+        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True)
         dataset = DummyClassificationDataset(
             "./",
             num_channels=1,
@@ -64,9 +63,10 @@
             num_images=2,
             transform=T.ToTensor(),
         )
+        dl = DataLoader(dataset, batch_size=2, shuffle=True)
         stoch_model.eval()
         with pytest.raises(RuntimeError):
             stoch_model(torch.randn(1, 1, 20, 20))

         with pytest.raises(ValueError):
-            stoch_model.fit(dataset=dataset)
+            stoch_model.fit(dataloader=dl)
6 changes: 4 additions & 2 deletions tests/post_processing/test_scalers.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from torch import nn, softmax
+from torch.utils.data import DataLoader

 from torch_uncertainty.post_processing import (
     MatrixScaler,
@@ -26,10 +27,11 @@ def test_fit_biased(self):
         labels = torch.as_tensor([0.5, 0.5]).repeat(10, 1)

         calibration_set = list(zip(inputs, labels, strict=True))
+        dl = DataLoader(calibration_set, batch_size=10)

         scaler = TemperatureScaler(model=nn.Identity(), init_val=2, lr=1, max_iter=10)
         assert scaler.temperature[0] == 2.0
-        scaler.fit(calibration_set)
+        scaler.fit(dl)
         assert scaler.temperature[0] > 10  # best is +inf
         assert (
             torch.sum(
@@ -39,7 +41,7 @@
                 ** 2
                 < 0.001
         )
-        scaler.fit_predict(calibration_set, progress=False)
+        scaler.fit_predict(dl, progress=False)

     def test_errors(self):
         with pytest.raises(ValueError):
12 changes: 3 additions & 9 deletions torch_uncertainty/post_processing/abnn.py
@@ -2,7 +2,7 @@

 import torch
 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d
 from torch_uncertainty.models import deep_ensembles
@@ -25,7 +25,6 @@ def __init__(
         device: torch.device | str,
         max_epochs: int = 5,
         use_original_model: bool = True,
-        batch_size: int = 128,
         precision: str = "32",
         model: nn.Module | None = None,
     ):
@@ -45,8 +44,6 @@
                 to 5.
             use_original_model (bool, optional): Use original model during
                 evaluation. Defaults to True.
-            batch_size (int, optional): Batch size for the training of ABNN.
-                Defaults to 128.
             precision (str, optional): Machine precision for training & eval.
                 Defaults to "32".
             model (nn.Module | None, optional): Model to use. Defaults to None.
@@ -63,7 +60,6 @@
             num_models=num_models,
             num_samples=num_samples,
             base_lr=base_lr,
-            batch_size=batch_size,
         )
         self.num_classes = num_classes
         self.alpha = alpha
@@ -74,7 +70,6 @@
         self.use_original_model = use_original_model
         self.max_epochs = max_epochs

-        self.batch_size = batch_size
         self.precision = precision
         self.device = device

@@ -88,10 +83,9 @@ def __init__(
             weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1
             self.weights.append(weight)

-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         if self.model is None:
             raise ValueError("Model must be set before fitting.")
-        dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

         source_model = copy.deepcopy(self.model)
         _replace_bn_layers(source_model, self.alpha)
@@ -119,7 +113,7 @@ def fit(self, dataset: Dataset) -> None:
             logger=None,
             enable_model_summary=False,
         )
-        trainer.fit(model=baseline, train_dataloaders=dl)
+        trainer.fit(model=baseline, train_dataloaders=dataloader)

         final_models = (
             [copy.deepcopy(source_model) for _ in range(self.num_samples)]
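Reviewer note (not part of the diff): ABNN callers now provide the training DataLoader themselves. A hypothetical usage sketch; the class is assumed to be exported as ``ABNN``, argument names are taken from this diff, and all values (plus ``train_dataset`` and ``pretrained_model``) are illustrative:

from torch.utils.data import DataLoader

from torch_uncertainty.post_processing import ABNN  # assumed export

# Batching moved from an internal batch_size argument to the caller's loader.
train_dl = DataLoader(train_dataset, batch_size=128, shuffle=True)

abnn = ABNN(
    num_classes=10,
    random_prior=1.2,  # illustrative
    alpha=0.01,        # illustrative
    num_models=3,
    num_samples=4,
    base_lr=1e-3,
    device="cuda",
    model=pretrained_model,
)
abnn.fit(dataloader=train_dl)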
4 changes: 2 additions & 2 deletions torch_uncertainty/post_processing/abstract.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod

 from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader


 class PostProcessing(ABC, nn.Module):
@@ -14,7 +14,7 @@ def set_model(self, model: nn.Module) -> None:
         self.model = model

     @abstractmethod
-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         pass

     @abstractmethod
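Reviewer note (not part of the diff): downstream subclasses must adopt the new abstract signature. A minimal conforming sketch; the class is hypothetical, and the base constructor is assumed to take no required arguments (``set_model`` attaches the model, as shown above):

from torch import Tensor
from torch.utils.data import DataLoader

from torch_uncertainty.post_processing import PostProcessing


class NoOpPostProcessing(PostProcessing):
    """Hypothetical subclass illustrating the new fit contract."""

    def fit(self, dataloader: DataLoader) -> None:
        # A real implementation would estimate its parameters from the
        # batches yielded by the caller-provided dataloader.
        for _inputs, _targets in dataloader:
            break

    def forward(self, inputs: Tensor) -> Tensor:
        return self.model(inputs)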
21 changes: 7 additions & 14 deletions torch_uncertainty/post_processing/calibration/scaler.py
@@ -3,7 +3,7 @@

 import torch
 from torch import Tensor, nn, optim
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 from tqdm import tqdm

 from torch_uncertainty.post_processing import PostProcessing
@@ -47,20 +47,14 @@

     def fit(
         self,
-        calibration_set: Dataset,
-        batch_size: int = 32,
-        shuffle: bool = False,
-        drop_last: bool = False,
+        dataloader: DataLoader,
         save_logits: bool = False,
         progress: bool = True,
     ) -> None:
         """Fit the temperature parameters to the calibration data.

         Args:
-            calibration_set (Dataset): Calibration dataset.
-            batch_size (int, optional): Batch size for the calibration dataset. Defaults to 32.
-            shuffle (bool, optional): Whether to shuffle the calibration dataset. Defaults to False.
-            drop_last (bool, optional): Whether to drop the last batch if it's smaller than batch_size. Defaults to False.
+            dataloader (DataLoader): Dataloader with the calibration data.
             save_logits (bool, optional): Whether to save the logits and
                 labels. Defaults to False.
             progress (bool, optional): Whether to show a progress bar.
@@ -73,9 +67,7 @@

         all_logits = []
         all_labels = []
-        calibration_dl = DataLoader(
-            calibration_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
-        )
+        calibration_dl = dataloader
         with torch.no_grad():
             for inputs, labels in tqdm(calibration_dl, disable=not progress):
                 logits = self.model(inputs.to(self.device))
@@ -119,10 +111,11 @@ def _scale(self, logits: Tensor) -> Tensor:

     def fit_predict(
         self,
-        calibration_set: Dataset,
+        # calibration_set: Dataset,
+        dataloader: DataLoader,
         progress: bool = True,
     ) -> Tensor:
-        self.fit(calibration_set, save_logits=True, progress=progress)
+        self.fit(dataloader, save_logits=True, progress=progress)
         return self(self.logits)

     @property
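Reviewer note (not part of the diff): ``fit_predict`` keeps its one-call convenience under the new signature. A short sketch, assuming a trained classifier ``model`` and a calibration split ``cal_dataset`` as in the tutorial:

from torch.utils.data import DataLoader

from torch_uncertainty.post_processing import TemperatureScaler

calibration_dl = DataLoader(cal_dataset, batch_size=32)

scaler = TemperatureScaler(model=model)
# Fits the temperature (saving the logits internally) and returns the
# calibrated module's output on those saved logits.
probs = scaler.fit_predict(calibration_dl, progress=False)
print(scaler.temperature[0])  # 1.0 would mean the logits were left unchanged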
11 changes: 3 additions & 8 deletions torch_uncertainty/post_processing/laplace.py
@@ -2,7 +2,7 @@
 from typing import Literal

 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from .abstract import PostProcessing

@@ -23,7 +23,6 @@
         hessian_struct="kron",
         pred_type: Literal["glm", "nn"] = "glm",
         link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit",
-        batch_size: int = 256,
         optimize_prior_precision: bool = True,
     ) -> None:
         """Laplace approximation for uncertainty estimation.
@@ -42,8 +41,6 @@
             link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional):
                 how to approximate the classification link function for the `'glm'`.
                 See the Laplace library for more details. Defaults to "probit".
-            batch_size (int, optional): batch size for the Laplace approximation.
-                Defaults to 256.
             optimize_prior_precision (bool, optional): whether to optimize the prior
                 precision. Defaults to True.

@@ -63,7 +60,6 @@
         self.task = task
         self.weight_subset = weight_subset
         self.hessian_struct = hessian_struct
-        self.batch_size = batch_size
         self.optimize_prior_precision = optimize_prior_precision

         if model is not None:
@@ -78,9 +74,8 @@ def set_model(self, model: nn.Module) -> None:
             hessian_structure=self.hessian_struct,
         )

-    def fit(self, dataset: Dataset) -> None:
-        dl = DataLoader(dataset, batch_size=self.batch_size)
-        self.la.fit(train_loader=dl)
+    def fit(self, dataloader: DataLoader) -> None:
+        self.la.fit(train_loader=dataloader)
         if self.optimize_prior_precision:
             self.la.optimize_prior_precision(method="marglik")

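Reviewer note (not part of the diff): ``LaplaceApprox.fit`` now hands the caller's loader straight to the Laplace library, so the caller also controls the batch size. A self-contained sketch mirroring the updated test:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from torch_uncertainty.post_processing import LaplaceApprox

# Dummy 10-class problem; the batch size now lives in the caller's DataLoader.
ds = TensorDataset(torch.randn(64, 1), torch.randint(0, 10, (64,)))
dl = DataLoader(ds, batch_size=32)

la = LaplaceApprox(task="classification", model=nn.Linear(1, 10))
la.fit(dataloader=dl)  # also optimizes the prior precision by default
preds = la(torch.randn(1, 1))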