diff --git a/docs/source/api.rst b/docs/source/api.rst index 16f1f1c2..0165ad49 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -201,6 +201,7 @@ Functions :toctree: generated/ :nosignatures: + batch_ensemble deep_ensembles mc_dropout @@ -212,6 +213,7 @@ Classes :nosignatures: :template: class.rst + BatchEnsemble CheckpointEnsemble EMA MCDropout diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml new file mode 100644 index 00000000..d385b100 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -0,0 +1,68 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: batch_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # BatchEnsemble + class_path: torch_uncertainty.models.lenet.batchensemble_lenet + init_args: + in_channels: 1 + num_classes: 10 + num_estimators: 5 + activation: torch.nn.ReLU + norm: torch.nn.BatchNorm2d + groups: 1 + dropout_rate: 0 + repeat_training_inputs: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index c5398a87..10ea9fe8 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -67,7 +67,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml new file mode 100644 index 00000000..1d47b782 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -0,0 +1,78 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: deep_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # DeepEnsemble + class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles + init_args: + models: + # LeNet + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + # last_layer_dropout: false + layer_args: {} + num_estimators: 5 + task: classification + probabilistic: false + reset_model_parameters: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 363461c6..713bf607 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -55,7 +55,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/tests/layers/test_batch.py b/tests/layers/test_batch.py index bb7ca6e9..75e54485 100644 --- a/tests/layers/test_batch.py +++ b/tests/layers/test_batch.py @@ -33,6 +33,17 @@ def test_linear_one_estimator_no_bias(self, feat_input: torch.Tensor): out = layer(feat_input) assert out.shape == torch.Size([4, 2]) + def test_convert_from_linear(self, feat_input: torch.Tensor): + linear = torch.nn.Linear(6, 3) + layer = BatchLinear.from_linear(linear, num_estimators=2) + assert layer.linear.weight.shape == torch.Size([3, 6]) + assert layer.linear.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(feat_input) + assert out.shape == torch.Size([4, 3]) + class TestBatchConv2d: """Testing the BatchConv2d layer class.""" @@ -47,3 +58,14 @@ def test_conv_two_estimators(self, img_input: torch.Tensor): layer = BatchConv2d(6, 2, num_estimators=2, kernel_size=1) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) + + def test_convert_from_conv2d(self, img_input: torch.Tensor): + conv = torch.nn.Conv2d(6, 3, 1) + layer = BatchConv2d.from_conv2d(conv, num_estimators=2) + assert layer.conv.weight.shape == torch.Size([3, 6, 1, 1]) + assert layer.conv.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(img_input) + assert out.shape == torch.Size([5, 3, 3, 3]) diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index 8519ffdf..c6a08180 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet +from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet class TestLeNet: @@ -18,6 +18,7 @@ def test_main(self): model.eval() model(torch.randn(1, 1, 20, 20)) + batchensemble_lenet(1, 1) packed_lenet(1, 1) bayesian_lenet(1, 1) bayesian_lenet( diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py new file mode 100644 index 00000000..8c1b675e --- /dev/null +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -0,0 +1,82 @@ +import pytest +import torch +from torch import nn + +from torch_uncertainty.layers import BatchConv2d, BatchLinear +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble + + +@pytest.fixture() +def img_input() -> torch.Tensor: + return torch.rand((5, 6, 3, 3)) + + +# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) +class _DummyModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.conv = nn.Conv2d(in_features, out_features, 3) + self.fc = nn.Linear(out_features, out_features) + + def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class _DummyBEModel(nn.Module): + def __init__(self, in_features, out_features, num_estimators): + super().__init__() + self.conv = BatchConv2d(in_features, out_features, 3, num_estimators) + self.fc = BatchLinear(out_features, out_features, num_estimators=num_estimators) + + def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class TestBatchEnsembleModel: + def test_convert_layers(self): + in_features = 6 + out_features = 4 + num_estimators = 3 + + model = _DummyModel(in_features, out_features) + wrapped_model = BatchEnsemble(model, num_estimators, convert_layers=True) + assert wrapped_model.num_estimators == num_estimators + assert isinstance(wrapped_model.model.conv, BatchConv2d) + assert isinstance(wrapped_model.model.fc, BatchLinear) + + def test_forward_pass(self, img_input): + batch_size = img_input.size(0) + in_features = img_input.size(1) + out_features = 4 + num_estimators = 3 + model = _DummyBEModel(in_features, out_features, num_estimators) + # with repeat_training_inputs=False + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=False) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (img_input.size(0), out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # with repeat_training_inputs=True + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=True) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + + def test_errors(self): + with pytest.raises(ValueError): + BatchEnsemble(_DummyBEModel(10, 5, 1), 0) + with pytest.raises(ValueError): + BatchEnsemble(_DummyModel(10, 5), 1) + with pytest.raises(ValueError): + BatchEnsemble(nn.Identity(), 2, convert_layers=True) diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index dde72139..abb481f8 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -1,6 +1,7 @@ import math import torch +from einops import repeat from torch import Tensor, nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair @@ -79,11 +80,18 @@ def __init__( :math:`H_{out} = \text{out_features}`. Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. + It is advised to ensure that `batch_size` is divisible by :attr:`num_estimators` when + calling :func:`forward()`, so each estimator receives the same number of examples. + In a BatchEnsemble architecture, the input is typically **repeated** `num_estimators` + times along the batch dimension. Incorrect batch size may lead to unexpected results. + + To simplify batch handling, wrap your model with `torch_uncertainty.wrappers.BatchEnsemble`, + which automatically repeats the batch before passing it through the network. + Examples: >>> # With three estimators - >>> m = LinearBE(20, 30, 3) + >>> m = BatchLinear(20, 30, 3) >>> input = torch.randn(8, 20) >>> output = m(input) >>> print(output.size()) @@ -110,6 +118,30 @@ def __init__( self.register_parameter("bias", None) self.reset_parameters() + @classmethod + def from_linear(cls, linear: nn.Linear, num_estimators: int) -> "BatchLinear": + r"""Create a BatchEnsemble-style Linear layer from an existing Linear layer. + + Args: + linear (nn.Linear): The Linear layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchLinear: The converted BatchEnsemble-style Linear layer. + + Example: + >>> linear = nn.Linear(20, 30) + >>> be_linear = BatchLinear.from_linear(linear, num_estimators=3) + """ + return cls( + in_features=linear.in_features, + out_features=linear.out_features, + num_estimators=num_estimators, + bias=linear.bias is not None, + device=linear.weight.device, + dtype=linear.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -125,16 +157,12 @@ def forward(self, inputs: Tensor) -> Tensor: ) extra = batch_size % self.num_estimators - r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) + r_group = repeat(self.r_group, "m h -> (m b) h", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = torch.repeat_interleave( - self.bias, - examples_per_estimator, - dim=0, - ) + bias = repeat(self.bias, "m h -> (m b) h", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None @@ -273,12 +301,16 @@ def __init__( {\text{stride}[1]} + 1\right\rfloor Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`. + Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`. + In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators` + times along the first axis. Incorrect batch size may lead to unexpected results. + To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically + repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details. Examples: >>> # With square kernels, four estimators and equal stride - >>> m = Conv2dBE(3, 32, 3, 4, stride=1) + >>> m = BatchConv2d(3, 32, 3, 4, stride=1) >>> input = torch.randn(8, 3, 16, 16) >>> output = m(input) >>> print(output.size()) @@ -315,6 +347,38 @@ def __init__( self.reset_parameters() + @classmethod + def from_conv2d(cls, conv2d: nn.Conv2d, num_estimators: int) -> "BatchConv2d": + r"""Create a BatchEnsemble-style Conv2d layer from an existing Conv2d layer. + + Args: + conv2d (nn.Conv2d): The Conv2d layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchConv2d: The converted BatchEnsemble-style Conv2d layer. + + Warning: + All parameters of the original Conv2d layer will be discarded. + + Example: + >>> conv2d = nn.Conv2d(3, 32, kernel_size=3) + >>> be_conv2d = BatchConv2d.from_conv2d(conv2d, num_estimators=3) + """ + return cls( + in_channels=conv2d.in_channels, + out_channels=conv2d.out_channels, + kernel_size=conv2d.kernel_size, + stride=conv2d.stride, + padding=conv2d.padding, + dilation=conv2d.dilation, + groups=conv2d.groups, + bias=conv2d.bias is not None, + num_estimators=num_estimators, + device=conv2d.weight.device, + dtype=conv2d.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -328,50 +392,13 @@ def forward(self, inputs: Tensor) -> Tensor: examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators - r_group = ( - torch.repeat_interleave( - self.r_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.r_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = ( - torch.repeat_interleave( - self.s_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.s_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # + r_group = repeat(self.r_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = ( - torch.repeat_interleave( - self.bias, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.bias.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - + bias = repeat(self.bias, "m h -> (m b) h 1 1", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 4b8964bc..a3faf1f4 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -5,9 +5,11 @@ STEP_UPDATE_MODEL, SWA, SWAG, + BatchEnsemble, CheckpointEnsemble, MCDropout, StochasticModel, + batch_ensemble, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 4b76c9e6..b31ba133 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -9,6 +9,7 @@ from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models import StochasticModel +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ -119,6 +120,32 @@ def lenet( ) +def batchensemble_lenet( + in_channels: int, + num_classes: int, + num_estimators: int = 4, + activation: Callable = F.relu, + norm: type[nn.Module] = nn.BatchNorm2d, + groups: int = 1, + dropout_rate: float = 0.0, + repeat_training_inputs: bool = False, +) -> _LeNet: + model = lenet( + in_channels=in_channels, + num_classes=num_classes, + activation=activation, + norm=norm, + groups=groups, + dropout_rate=dropout_rate, + ) + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=True, + ) + + def packed_lenet( in_channels: int, num_classes: int, diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index 75f37e66..fb4ff50c 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .batch_ensemble import BatchEnsemble, batch_ensemble from .checkpoint_ensemble import ( CheckpointEnsemble, ) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py new file mode 100644 index 00000000..1e9bc46f --- /dev/null +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -0,0 +1,156 @@ +import torch +from einops import repeat +from torch import nn + +from torch_uncertainty.layers import BatchConv2d, BatchLinear + + +class BatchEnsemble(nn.Module): + """Wrap a BatchEnsemble model to ensure correct batch replication. + + In a BatchEnsemble architecture, each estimator operates on a **sub-batch** + of the input. This means that the input batch must be **repeated** + :attr:`num_estimators` times before being processed. + + This wrapper automatically **duplicates the input batch** along the first axis, + ensuring that each estimator receives the correct data format. + + Args: + model (nn.Module): The BatchEnsemble model. + num_estimators (int): Number of ensemble members. + repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. + If `True`, the input batch is repeated during both training and evaluation. If `False`, + the input batch is repeated only during evaluation. Default is `False`. + convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. + If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their + BatchEnsemble counterparts. Default is `False`. + + Raises: + ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the + end of initialization. + ValueError: If ``num_estimators`` is less than or equal to ``0``. + ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are + found in the model. + + Warning: + If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` + layers in the model to their BatchEnsemble counterparts. If the model contains other types of + layers, the conversion won't happen for these layers. If don't have any ``nn.Linear`` or ``nn.Conv2d`` + layers in the model, the wrapper will raise an error during conversion. + + Warning: + If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` + for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when + initializing the routine. + + Example: + >>> model = nn.Sequential( + ... nn.Linear(10, 5), + ... nn.ReLU(), + ... nn.Linear(5, 2) + ... ) + >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) + >>> model + BatchEnsemble( + (model): Sequential( + (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) + (1): ReLU() + (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) + ) + ) + """ + + def __init__( + self, + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, + ) -> None: + super().__init__() + self.model = model + self.num_estimators = num_estimators + self.repeat_training_inputs = repeat_training_inputs + + if convert_layers: + self._convert_layers() + + filtered_modules = [ + module + for module in self.model.modules() + if isinstance(module, BatchLinear | BatchConv2d) + ] + _batch_ensemble_checks(filtered_modules, num_estimators) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and + pass it through the model. + """ + if not self.training or self.repeat_training_inputs: + x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators) + return self.model(x) + + def _convert_layers(self) -> None: + """Converts the model's layers to BatchEnsemble layers.""" + no_valid_layers = True + for name, layer in self.model.named_modules(): + if isinstance(layer, nn.Linear): + setattr( + self.model, + name, + BatchLinear.from_linear(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + elif isinstance(layer, nn.Conv2d): + setattr( + self.model, + name, + BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + if no_valid_layers: + raise ValueError( + "No valid layers found in the model. " + "Please use `nn.Linear` or `nn.Conv2d` layers to apply BatchEnsemble." + ) + + +def _batch_ensemble_checks(filtered_modules, num_estimators): + """Check if the model contains the required number of dropout modules.""" + if len(filtered_modules) == 0: + raise ValueError( + "No BatchEnsemble layers found in the model. " + "Please use `BatchLinear` or `BatchConv2d` layers in your model " + "or set `convert_layers=True` when initializing the wrapper." + ) + if num_estimators <= 0: + raise ValueError("`num_estimators` must be greater than 0.") + + +def batch_ensemble( + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, +) -> BatchEnsemble: + """BatchEnsemble wrapper for a model. + + Args: + model (nn.Module): model to wrap + num_estimators (int): number of ensemble members + repeat_training_inputs (bool, optional): whether to repeat the input batch during training. + If `True`, the input batch is repeated during both training and evaluation. If `False`, + the input batch is repeated only during evaluation. Default is `False`. + convert_layers (bool, optional): whether to convert the model's layers to BatchEnsemble layers. + If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their + BatchEnsemble counterparts. Default is `False`. + + Returns: + BatchEnsemble: BatchEnsemble wrapper for the model + """ + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=convert_layers, + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..82189fbf 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -123,6 +123,14 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the Lightning CLI. + Warning: + When using an ensemble model, you must: + 1. Set :attr:`is_ensemble` to ``True``. + 2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`. + 3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes. + + For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`. + Note: :attr:`optim_recipe` can be anything that can be returned by :meth:`LightningModule.configure_optimizers()`. Find more details @@ -475,7 +483,7 @@ def test_step( """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 1426992a..19e0edab 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +from einops import rearrange, repeat from torch import Tensor, nn @@ -21,7 +21,7 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat(self.num_repeats, *[1] * (targets.ndim - 1)) + return inputs, repeat(targets, "b ... -> (m b) ...", m=self.num_repeats) class MIMOBatchFormat(nn.Module): @@ -79,6 +79,6 @@ def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: [torch.index_select(targets, dim=0, index=indices) for indices in shuffle_indices], dim=0, ) - inputs = rearrange(inputs, "m b c h w -> (m b) c h w", m=self.num_estimators) + inputs = rearrange(inputs, "m b ... -> (m b) ...", m=self.num_estimators) targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators) return inputs, targets