Skip to content

✨ LeNet BatchEnsemble and Deep Ensemble + minor bugfixes #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions experiments/classification/mnist/configs/lenet_batch_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
fast_dev_run: false
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 10
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/lenet_trajectory
name: batch_ensemble
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/cls/Acc
patience: 1000
check_finite: true
model:
# ClassificationRoutine
model:
# BatchEnsemble
class_path: torch_uncertainty.models.lenet.batchensemble_lenet
init_args:
in_channels: 1
num_classes: 10
num_estimators: 5
activation: torch.nn.ReLU
norm: torch.nn.BatchNorm2d
groups: 1
dropout_rate: 0
num_classes: 10
loss: CrossEntropyLoss
is_ensemble: true
format_batch_fn:
class_path: torch_uncertainty.transforms.batch.RepeatTarget
init_args:
num_repeats: 5
data:
root: ./data
batch_size: 128
num_workers: 127
eval_ood: true
eval_shift: true
optimizer:
lr: 0.05
momentum: 0.9
weight_decay: 5e-4
nesterov: true
lr_scheduler:
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ optimizer:
weight_decay: 5e-4
nesterov: true
lr_scheduler:
milestones:
- 25
- 50
gamma: 0.1
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
78 changes: 78 additions & 0 deletions experiments/classification/mnist/configs/lenet_deep_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
fast_dev_run: false
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 10
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/lenet_trajectory
name: deep_ensemble
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/cls/Acc
patience: 1000
check_finite: true
model:
# ClassificationRoutine
model:
# DeepEnsemble
class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles
init_args:
models:
# LeNet
class_path: torch_uncertainty.models.lenet._LeNet
init_args:
in_channels: 1
num_classes: 10
linear_layer: torch.nn.Linear
conv2d_layer: torch.nn.Conv2d
activation: torch.nn.ReLU
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
# last_layer_dropout: false
layer_args: {}
num_estimators: 5
task: classification
probabilistic: false
reset_model_parameters: true
num_classes: 10
loss: CrossEntropyLoss
is_ensemble: true
format_batch_fn:
class_path: torch_uncertainty.transforms.batch.RepeatTarget
init_args:
num_repeats: 5
data:
root: ./data
batch_size: 128
num_workers: 127
eval_ood: true
eval_shift: true
optimizer:
lr: 0.05
momentum: 0.9
weight_decay: 5e-4
nesterov: true
lr_scheduler:
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
10 changes: 6 additions & 4 deletions experiments/classification/mnist/configs/lenet_ema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ optimizer:
weight_decay: 5e-4
nesterov: true
lr_scheduler:
milestones:
- 25
- 50
gamma: 0.1
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
3 changes: 2 additions & 1 deletion tests/models/test_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch import nn

from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet
from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet


class TestLeNet:
Expand All @@ -18,6 +18,7 @@ def test_main(self):
model.eval()
model(torch.randn(1, 1, 20, 20))

batchensemble_lenet(1, 1)
packed_lenet(1, 1)
bayesian_lenet(1, 1)
bayesian_lenet(
Expand Down
36 changes: 36 additions & 0 deletions tests/models/wrappers/test_batch_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
import torch
from torch import nn

from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble


# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture)
class SimpleModel(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.fc = nn.Linear(in_features, out_features)
self.r_group = nn.Parameter(torch.randn(in_features))
self.s_group = nn.Parameter(torch.randn(out_features))
self.bias = nn.Parameter(torch.randn(out_features))

def forward(self, x):
return self.fc(x)


# Test the BatchEnsemble wrapper
def test_batch_ensemble():
in_features = 10
out_features = 5
num_estimators = 3
model = SimpleModel(in_features, out_features)
wrapped_model = BatchEnsemble(model, num_estimators)

# Test forward pass
x = torch.randn(2, in_features) # Batch size of 2
logits = wrapped_model(x)
assert logits.shape == (2 * num_estimators, out_features)


if __name__ == "__main__":
pytest.main([__file__])
14 changes: 12 additions & 2 deletions torch_uncertainty/layers/batch_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ def __init__(
:math:`H_{out} = \text{out_features}`.

Warning:
Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`.
Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`.
In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators`
times along the first axis. Incorrect batch size may lead to unexpected results.

To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically
repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details.


Examples:
>>> # With three estimators
Expand Down Expand Up @@ -273,8 +279,12 @@ def __init__(
{\text{stride}[1]} + 1\right\rfloor

Warning:
Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`.
Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`.
In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators`
times along the first axis. Incorrect batch size may lead to unexpected results.

To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically
repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details.

Examples:
>>> # With square kernels, four estimators and equal stride
Expand Down
28 changes: 28 additions & 0 deletions torch_uncertainty/models/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import torch.nn.functional as F
from torch import nn

from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear
from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear
from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear
from torch_uncertainty.models import StochasticModel
from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble

__all__ = ["bayesian_lenet", "lenet", "packed_lenet"]

Expand Down Expand Up @@ -119,6 +121,32 @@ def lenet(
)


def batchensemble_lenet(
in_channels: int,
num_classes: int,
num_estimators: int = 4,
activation: Callable = F.relu,
norm: type[nn.Module] = nn.BatchNorm2d,
groups: int = 1,
dropout_rate: float = 0.0,
) -> _LeNet:
model = _lenet(
stochastic=False,
in_channels=in_channels,
num_classes=num_classes,
linear_layer=BatchLinear,
conv2d_layer=BatchConv2d,
layer_args={
"num_estimators": num_estimators,
},
activation=activation,
norm=norm,
groups=groups,
dropout_rate=dropout_rate,
)
return BatchEnsemble(model, num_estimators)


def packed_lenet(
in_channels: int,
num_classes: int,
Expand Down
36 changes: 36 additions & 0 deletions torch_uncertainty/models/wrappers/batch_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from torch import nn


class BatchEnsemble(nn.Module):
"""Wraps a BatchEnsemble model to ensure correct batch replication.

In a BatchEnsemble architecture, each estimator operates on a **sub-batch**
of the input. This means that the input batch must be **repeated**
:attr:`num_estimators` times before being processed.

This wrapper automatically **duplicates the input batch** along the first axis,
ensuring that each estimator receives the correct data format.

**Usage Example:**
```python
model = lenet(in_channels=1, num_classes=10)
wrapped_model = BatchEnsembleWrapper(model, num_estimators=5)
logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times
```

Args:
model (nn.Module): The BatchEnsemble model.
num_estimators (int): Number of ensemble members.
"""

def __init__(self, model: nn.Module, num_estimators: int):
super().__init__()
self.model = model
self.num_estimators = num_estimators

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Repeats the input batch and passes it through the model."""
repeat_shape = [self.num_estimators] + [1] * (x.dim() - 1)
x = x.repeat(repeat_shape)
return self.model(x)
10 changes: 9 additions & 1 deletion torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def __init__(
Warning:
You must define :attr:`optim_recipe` if you do not use the Lightning CLI.

Warning:
When using an ensemble model, you must:
1. Set :attr:`is_ensemble` to ``True``.
2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`.
3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes.

For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`.

Note:
:attr:`optim_recipe` can be anything that can be returned by
:meth:`LightningModule.configure_optimizers()`. Find more details
Expand Down Expand Up @@ -475,7 +483,7 @@ def test_step(
"""
inputs, targets = batch
logits = self.forward(inputs, save_feats=self.eval_grouping_loss)
logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0))
logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0))
probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1)
probs = probs_per_est.mean(dim=1)
confs = probs.max(-1)[0]
Expand Down