Skip to content

✨ LeNet BatchEnsemble and Deep Ensemble + minor bugfixes #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ Functions
:toctree: generated/
:nosignatures:

batch_ensemble
deep_ensembles
mc_dropout

Expand All @@ -212,6 +213,7 @@ Classes
:nosignatures:
:template: class.rst

BatchEnsemble
CheckpointEnsemble
EMA
MCDropout
Expand Down
68 changes: 68 additions & 0 deletions experiments/classification/mnist/configs/lenet_batch_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
fast_dev_run: false
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 10
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/lenet_trajectory
name: batch_ensemble
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/cls/Acc
patience: 1000
check_finite: true
model:
# ClassificationRoutine
model:
# BatchEnsemble
class_path: torch_uncertainty.models.lenet.batchensemble_lenet
init_args:
in_channels: 1
num_classes: 10
num_estimators: 5
activation: torch.nn.ReLU
norm: torch.nn.BatchNorm2d
groups: 1
dropout_rate: 0
repeat_training_inputs: true
num_classes: 10
loss: CrossEntropyLoss
is_ensemble: true
format_batch_fn:
class_path: torch_uncertainty.transforms.batch.RepeatTarget
init_args:
num_repeats: 5
data:
root: ./data
batch_size: 128
num_workers: 127
eval_ood: true
eval_shift: true
optimizer:
lr: 0.05
momentum: 0.9
weight_decay: 5e-4
nesterov: true
lr_scheduler:
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ optimizer:
weight_decay: 5e-4
nesterov: true
lr_scheduler:
milestones:
- 25
- 50
gamma: 0.1
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
78 changes: 78 additions & 0 deletions experiments/classification/mnist/configs/lenet_deep_ensemble.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
fast_dev_run: false
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 10
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/lenet_trajectory
name: deep_ensemble
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/cls/Acc
patience: 1000
check_finite: true
model:
# ClassificationRoutine
model:
# DeepEnsemble
class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles
init_args:
models:
# LeNet
class_path: torch_uncertainty.models.lenet._LeNet
init_args:
in_channels: 1
num_classes: 10
linear_layer: torch.nn.Linear
conv2d_layer: torch.nn.Conv2d
activation: torch.nn.ReLU
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
# last_layer_dropout: false
layer_args: {}
num_estimators: 5
task: classification
probabilistic: false
reset_model_parameters: true
num_classes: 10
loss: CrossEntropyLoss
is_ensemble: true
format_batch_fn:
class_path: torch_uncertainty.transforms.batch.RepeatTarget
init_args:
num_repeats: 5
data:
root: ./data
batch_size: 128
num_workers: 127
eval_ood: true
eval_shift: true
optimizer:
lr: 0.05
momentum: 0.9
weight_decay: 5e-4
nesterov: true
lr_scheduler:
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
10 changes: 6 additions & 4 deletions experiments/classification/mnist/configs/lenet_ema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ optimizer:
weight_decay: 5e-4
nesterov: true
lr_scheduler:
milestones:
- 25
- 50
gamma: 0.1
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
22 changes: 22 additions & 0 deletions tests/layers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def test_linear_one_estimator_no_bias(self, feat_input: torch.Tensor):
out = layer(feat_input)
assert out.shape == torch.Size([4, 2])

def test_convert_from_linear(self, feat_input: torch.Tensor):
linear = torch.nn.Linear(6, 3)
layer = BatchLinear.from_linear(linear, num_estimators=2)
assert layer.linear.weight.shape == torch.Size([3, 6])
assert layer.linear.bias is None
assert layer.r_group.shape == torch.Size([2, 6])
assert layer.s_group.shape == torch.Size([2, 3])
assert layer.bias.shape == torch.Size([2, 3])
out = layer(feat_input)
assert out.shape == torch.Size([4, 3])


class TestBatchConv2d:
"""Testing the BatchConv2d layer class."""
Expand All @@ -47,3 +58,14 @@ def test_conv_two_estimators(self, img_input: torch.Tensor):
layer = BatchConv2d(6, 2, num_estimators=2, kernel_size=1)
out = layer(img_input)
assert out.shape == torch.Size([5, 2, 3, 3])

def test_convert_from_conv2d(self, img_input: torch.Tensor):
conv = torch.nn.Conv2d(6, 3, 1)
layer = BatchConv2d.from_conv2d(conv, num_estimators=2)
assert layer.conv.weight.shape == torch.Size([3, 6, 1, 1])
assert layer.conv.bias is None
assert layer.r_group.shape == torch.Size([2, 6])
assert layer.s_group.shape == torch.Size([2, 3])
assert layer.bias.shape == torch.Size([2, 3])
out = layer(img_input)
assert out.shape == torch.Size([5, 3, 3, 3])
3 changes: 2 additions & 1 deletion tests/models/test_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch import nn

from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet
from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet


class TestLeNet:
Expand All @@ -18,6 +18,7 @@ def test_main(self):
model.eval()
model(torch.randn(1, 1, 20, 20))

batchensemble_lenet(1, 1)
packed_lenet(1, 1)
bayesian_lenet(1, 1)
bayesian_lenet(
Expand Down
82 changes: 82 additions & 0 deletions tests/models/wrappers/test_batch_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
import torch
from torch import nn

from torch_uncertainty.layers import BatchConv2d, BatchLinear
from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble


@pytest.fixture()
def img_input() -> torch.Tensor:
return torch.rand((5, 6, 3, 3))


# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture)
class _DummyModel(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.conv = nn.Conv2d(in_features, out_features, 3)
self.fc = nn.Linear(out_features, out_features)

def forward(self, x):
x = self.conv(x)
x = x.flatten(1)
return self.fc(x)


class _DummyBEModel(nn.Module):
def __init__(self, in_features, out_features, num_estimators):
super().__init__()
self.conv = BatchConv2d(in_features, out_features, 3, num_estimators)
self.fc = BatchLinear(out_features, out_features, num_estimators=num_estimators)

def forward(self, x):
x = self.conv(x)
x = x.flatten(1)
return self.fc(x)


class TestBatchEnsembleModel:
def test_convert_layers(self):
in_features = 6
out_features = 4
num_estimators = 3

model = _DummyModel(in_features, out_features)
wrapped_model = BatchEnsemble(model, num_estimators, convert_layers=True)
assert wrapped_model.num_estimators == num_estimators
assert isinstance(wrapped_model.model.conv, BatchConv2d)
assert isinstance(wrapped_model.model.fc, BatchLinear)

def test_forward_pass(self, img_input):
batch_size = img_input.size(0)
in_features = img_input.size(1)
out_features = 4
num_estimators = 3
model = _DummyBEModel(in_features, out_features, num_estimators)
# with repeat_training_inputs=False
wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=False)
# test forward pass for training
logits = wrapped_model(img_input)
assert logits.shape == (img_input.size(0), out_features)
# test forward pass for evaluation
wrapped_model.eval()
logits = wrapped_model(img_input)
assert logits.shape == (batch_size * num_estimators, out_features)
# with repeat_training_inputs=True
wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=True)
# test forward pass for training
logits = wrapped_model(img_input)
assert logits.shape == (batch_size * num_estimators, out_features)
# test forward pass for evaluation
wrapped_model.eval()
logits = wrapped_model(img_input)
assert logits.shape == (batch_size * num_estimators, out_features)

def test_errors(self):
with pytest.raises(ValueError):
BatchEnsemble(_DummyBEModel(10, 5, 1), 0)
with pytest.raises(ValueError):
BatchEnsemble(_DummyModel(10, 5), 1)
with pytest.raises(ValueError):
BatchEnsemble(nn.Identity(), 2, convert_layers=True)
Loading