From 6df09149ac69ab3277feadb902f950f1a82b6610 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 11:43:49 +0000 Subject: [PATCH 01/17] :bug: Add missing learning rate scheduler class paths to lenet configs --- .../mnist/configs/lenet_checkpoint_ensemble.yaml | 10 ++++++---- .../classification/mnist/configs/lenet_ema.yaml | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index c5398a87..10ea9fe8 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -67,7 +67,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 363461c6..713bf607 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -55,7 +55,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 From 56ab6527239de90d3ab995dc7663766546642079 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 12:05:42 +0000 Subject: [PATCH 02/17] :sparkles: Add BatchEnsemble wrapper --- .../models/wrappers/batch_ensemble.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 torch_uncertainty/models/wrappers/batch_ensemble.py diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py new file mode 100644 index 00000000..d62adcb8 --- 
/dev/null +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -0,0 +1,34 @@ +import torch +from torch import nn + +class BatchEnsemble(nn.Module): + """Wraps a BatchEnsemble model to ensure correct batch replication. + + In a BatchEnsemble architecture, each estimator operates on a **sub-batch** + of the input. This means that the input batch must be **repeated** + :attr:`num_estimators` times before being processed. + + This wrapper automatically **duplicates the input batch** along the first axis, + ensuring that each estimator receives the correct data format. + + **Usage Example:** + ```python + model = lenet(in_channels=1, num_classes=10) + wrapped_model = BatchEnsembleWrapper(model, num_estimators=5) + logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times + ``` + + Args: + model (nn.Module): The BatchEnsemble model. + num_estimators (int): Number of ensemble members. + """ + + def __init__(self, model: nn.Module, num_estimators: int): + super().__init__() + self.model = model + self.num_estimators = num_estimators + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Repeats the input batch and passes it through the model.""" + x = x.repeat(self.num_estimators, 1, 1, 1) + return self.model(x) From 918de33527e9fa4fff9c1897d5a780d68c773c20 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 12:06:58 +0000 Subject: [PATCH 03/17] :books: Update documentation regarding (batch) ensemble usage --- torch_uncertainty/layers/batch_ensemble.py | 14 ++++++++++++-- torch_uncertainty/routines/classification.py | 10 +++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index dde72139..3f15ac21 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -79,7 +79,13 @@ def __init__( :math:`H_{out} = \text{out_features}`. 
Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. + Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`. + In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators` + times along the first axis. Incorrect batch size may lead to unexpected results. + + To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically + repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details. + Examples: >>> # With three estimators @@ -273,8 +279,12 @@ def __init__( {\text{stride}[1]} + 1\right\rfloor Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`. + Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`. + In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators` + times along the first axis. Incorrect batch size may lead to unexpected results. + To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically + repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details. Examples: >>> # With square kernels, four estimators and equal stride diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..82189fbf 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -123,6 +123,14 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the Lightning CLI. + Warning: + When using an ensemble model, you must: + 1. Set :attr:`is_ensemble` to ``True``. + 2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`. + 3. 
Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes. + + For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`. + Note: :attr:`optim_recipe` can be anything that can be returned by :meth:`LightningModule.configure_optimizers()`. Find more details @@ -475,7 +483,7 @@ def test_step( """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] From c5b62d8bc9e6c4b652b993daa21b27a7db7fc88d Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 12:09:32 +0000 Subject: [PATCH 04/17] :sparkles: Add LeNet BatchEnsemble and Deep Ensemble --- .../mnist/configs/lenet_batch_ensemble.yaml | 67 ++++++++++++++++ .../mnist/configs/lenet_deep_ensemble.yaml | 78 +++++++++++++++++++ torch_uncertainty/models/lenet.py | 28 +++++++ 3 files changed, 173 insertions(+) create mode 100644 experiments/classification/mnist/configs/lenet_batch_ensemble.yaml create mode 100644 experiments/classification/mnist/configs/lenet_deep_ensemble.yaml diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml new file mode 100644 index 00000000..0d2a94ee --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -0,0 +1,67 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + 
init_args: + save_dir: logs/lenet_trajectory + name: batch_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # BatchEnsemble + class_path: torch_uncertainty.models.lenet.batchensemble_lenet + init_args: + in_channels: 1 + num_classes: 10 + num_estimators: 5 + activation: torch.nn.ReLU + norm: torch.nn.BatchNorm2d + groups: 1 + dropout_rate: 0 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml new file mode 100644 index 00000000..1d47b782 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -0,0 +1,78 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: deep_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - 
class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # DeepEnsemble + class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles + init_args: + models: + # LeNet + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + # last_layer_dropout: false + layer_args: {} + num_estimators: 5 + task: classification + probabilistic: false + reset_model_parameters: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 4b76c9e6..aca916ee 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -5,10 +5,12 @@ import torch.nn.functional as F from torch import nn +from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models import StochasticModel +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ 
-119,6 +121,32 @@ def lenet( ) +def batchensemble_lenet( + in_channels: int, + num_classes: int, + num_estimators: int = 4, + activation: Callable = F.relu, + norm: type[nn.Module] = nn.BatchNorm2d, + groups: int = 1, + dropout_rate: float = 0.0, +) -> _LeNet: + model = _lenet( + stochastic=False, + in_channels=in_channels, + num_classes=num_classes, + linear_layer=BatchLinear, + conv2d_layer=BatchConv2d, + layer_args={ + "num_estimators": num_estimators, + }, + activation=activation, + norm=norm, + groups=groups, + dropout_rate=dropout_rate, + ) + return BatchEnsemble(model, num_estimators) + + def packed_lenet( in_channels: int, num_classes: int, From a92385eab837c4302f4db44a40288e95c6ac7eb5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 14:41:32 +0100 Subject: [PATCH 05/17] :shirt: Lint --- torch_uncertainty/models/wrappers/batch_ensemble.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index d62adcb8..88b2312b 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -1,6 +1,7 @@ import torch from torch import nn + class BatchEnsemble(nn.Module): """Wraps a BatchEnsemble model to ensure correct batch replication. 
From ca73b4ba2ca2d5f7269d74a8303e6223483b7145 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 14:51:24 +0100 Subject: [PATCH 06/17] :shirt: Also format --- torch_uncertainty/models/wrappers/batch_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 88b2312b..c6995ca8 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -18,7 +18,7 @@ class BatchEnsemble(nn.Module): wrapped_model = BatchEnsembleWrapper(model, num_estimators=5) logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times ``` - + Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. From 597691904fb1c1fa8f53b3c37b9dbfc604be94bc Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 15:46:55 +0000 Subject: [PATCH 07/17] :white_check_mark: Add test for BatchEnsemble wrapper and LeNet implementation --- tests/models/test_lenet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index 8519ffdf..c6a08180 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet +from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet class TestLeNet: @@ -18,6 +18,7 @@ def test_main(self): model.eval() model(torch.randn(1, 1, 20, 20)) + batchensemble_lenet(1, 1) packed_lenet(1, 1) bayesian_lenet(1, 1) bayesian_lenet( From 18dacb376545b3a6355a81f0ab827763647abc27 Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 18:16:46 +0000 Subject: [PATCH 08/17] :heavy_check_mark: Add test for BatchEnsemble wrapper and fix bug in batch replication --- 
tests/models/wrappers/test_batch_ensemble.py | 36 +++++++++++++++++++ .../models/wrappers/batch_ensemble.py | 3 +- 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 tests/models/wrappers/test_batch_ensemble.py diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py new file mode 100644 index 00000000..8e2271ee --- /dev/null +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -0,0 +1,36 @@ +import pytest +import torch +from torch import nn + +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble + + +# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) +class SimpleModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.fc = nn.Linear(in_features, out_features) + self.r_group = nn.Parameter(torch.randn(in_features)) + self.s_group = nn.Parameter(torch.randn(out_features)) + self.bias = nn.Parameter(torch.randn(out_features)) + + def forward(self, x): + return self.fc(x) + + +# Test the BatchEnsemble wrapper +def test_batch_ensemble(): + in_features = 10 + out_features = 5 + num_estimators = 3 + model = SimpleModel(in_features, out_features) + wrapped_model = BatchEnsemble(model, num_estimators) + + # Test forward pass + x = torch.randn(2, in_features) # Batch size of 2 + logits = wrapped_model(x) + assert logits.shape == (2 * num_estimators, out_features) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index c6995ca8..5d6355d3 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -31,5 +31,6 @@ def __init__(self, model: nn.Module, num_estimators: int): def forward(self, x: torch.Tensor) -> torch.Tensor: """Repeats the input batch and passes it through the model.""" - x = 
x.repeat(self.num_estimators, 1, 1, 1) + repeat_shape = [self.num_estimators] + [1] * (x.dim() - 1) + x = x.repeat(repeat_shape) return self.model(x) From 79b567dd3bf50640cae82cae559333692cf355f1 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 11:02:37 +0100 Subject: [PATCH 09/17] :ok_hand: Comply with PEP 257 --- torch_uncertainty/models/wrappers/batch_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 5d6355d3..c6066e74 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -3,7 +3,7 @@ class BatchEnsemble(nn.Module): - """Wraps a BatchEnsemble model to ensure correct batch replication. + """Wrap a BatchEnsemble model to ensure correct batch replication. In a BatchEnsemble architecture, each estimator operates on a **sub-batch** of the input. This means that the input batch must be **repeated** From dbf4df9f73f40a4af31991f2ff8a3a51da82e6ca Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 10 Mar 2025 11:40:57 +0000 Subject: [PATCH 10/17] :books: Add note that BatchEnsemble wrapper expects model to use BatchEnsemble layers --- torch_uncertainty/models/wrappers/batch_ensemble.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index c6066e74..e306dcb7 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -22,6 +22,9 @@ class BatchEnsemble(nn.Module): Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. + + Note: + This wrapper assumes that the model uses **BatchEnsemble layers** (see `torchensemble.layers.batch_ensemble`). 
""" def __init__(self, model: nn.Module, num_estimators: int): From 16a235b69401778dcd9571802670d12909ab07e9 Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 10 Mar 2025 11:54:29 +0000 Subject: [PATCH 11/17] :hammer: Refactor BatchEnsemble test case to use framework's test format --- tests/models/wrappers/test_batch_ensemble.py | 29 ++++++++------------ 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 8e2271ee..8230819e 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -1,4 +1,3 @@ -import pytest import torch from torch import nn @@ -6,7 +5,7 @@ # Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) -class SimpleModel(nn.Module): +class _DummyModel(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.fc = nn.Linear(in_features, out_features) @@ -18,19 +17,15 @@ def forward(self, x): return self.fc(x) -# Test the BatchEnsemble wrapper -def test_batch_ensemble(): - in_features = 10 - out_features = 5 - num_estimators = 3 - model = SimpleModel(in_features, out_features) - wrapped_model = BatchEnsemble(model, num_estimators) +class TestBatchEnsembleModel: + def test_forward_pass(self): + in_features = 10 + out_features = 5 + num_estimators = 3 + model = _DummyModel(in_features, out_features) + wrapped_model = BatchEnsemble(model, num_estimators) - # Test forward pass - x = torch.randn(2, in_features) # Batch size of 2 - logits = wrapped_model(x) - assert logits.shape == (2 * num_estimators, out_features) - - -if __name__ == "__main__": - pytest.main([__file__]) + # Test forward pass + x = torch.randn(2, in_features) # Batch size of 2 + logits = wrapped_model(x) + assert logits.shape == (2 * num_estimators, out_features) From 6e2596de0c7f6e1c873722ad0197a57faf6ee5de Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 
10 Mar 2025 17:00:50 +0100 Subject: [PATCH 12/17] :hammer: Use `einops.repeat` instead of `torch.repeat` in `RepeatTarget` --- torch_uncertainty/transforms/batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 1426992a..19e0edab 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +from einops import rearrange, repeat from torch import Tensor, nn @@ -21,7 +21,7 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat(self.num_repeats, *[1] * (targets.ndim - 1)) + return inputs, repeat(targets, "b ... -> (m b) ...", m=self.num_repeats) class MIMOBatchFormat(nn.Module): @@ -79,6 +79,6 @@ def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: [torch.index_select(targets, dim=0, index=indices) for indices in shuffle_indices], dim=0, ) - inputs = rearrange(inputs, "m b c h w -> (m b) c h w", m=self.num_estimators) + inputs = rearrange(inputs, "m b ... 
-> (m b) ...", m=self.num_estimators) targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators) return inputs, targets From 16c77059c9d473efcf3f487b6ce1fbbe6666eeb5 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:02:20 +0100 Subject: [PATCH 13/17] :hammer: Refine BatchEnsemble layers and add conversion methods --- torch_uncertainty/layers/batch_ensemble.py | 133 ++++++++++++--------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index 3f15ac21..abb481f8 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -1,6 +1,7 @@ import math import torch +from einops import repeat from torch import Tensor, nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair @@ -79,17 +80,18 @@ def __init__( :math:`H_{out} = \text{out_features}`. Warning: - Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`. - In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators` - times along the first axis. Incorrect batch size may lead to unexpected results. + It is advised to ensure that `batch_size` is divisible by :attr:`num_estimators` when + calling :func:`forward()`, so each estimator receives the same number of examples. + In a BatchEnsemble architecture, the input is typically **repeated** `num_estimators` + times along the batch dimension. Incorrect batch size may lead to unexpected results. - To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically - repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details. + To simplify batch handling, wrap your model with `torch_uncertainty.wrappers.BatchEnsemble`, + which automatically repeats the batch before passing it through the network. 
Examples: >>> # With three estimators - >>> m = LinearBE(20, 30, 3) + >>> m = BatchLinear(20, 30, 3) >>> input = torch.randn(8, 20) >>> output = m(input) >>> print(output.size()) @@ -116,6 +118,30 @@ def __init__( self.register_parameter("bias", None) self.reset_parameters() + @classmethod + def from_linear(cls, linear: nn.Linear, num_estimators: int) -> "BatchLinear": + r"""Create a BatchEnsemble-style Linear layer from an existing Linear layer. + + Args: + linear (nn.Linear): The Linear layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchLinear: The converted BatchEnsemble-style Linear layer. + + Example: + >>> linear = nn.Linear(20, 30) + >>> be_linear = BatchLinear.from_linear(linear, num_estimators=3) + """ + return cls( + in_features=linear.in_features, + out_features=linear.out_features, + num_estimators=num_estimators, + bias=linear.bias is not None, + device=linear.weight.device, + dtype=linear.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -131,16 +157,12 @@ def forward(self, inputs: Tensor) -> Tensor: ) extra = batch_size % self.num_estimators - r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) + r_group = repeat(self.r_group, "m h -> (m b) h", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = torch.repeat_interleave( - self.bias, - examples_per_estimator, - dim=0, - ) + bias = repeat(self.bias, "m h -> (m b) h", 
b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None @@ -288,7 +310,7 @@ def __init__( Examples: >>> # With square kernels, four estimators and equal stride - >>> m = Conv2dBE(3, 32, 3, 4, stride=1) + >>> m = BatchConv2d(3, 32, 3, 4, stride=1) >>> input = torch.randn(8, 3, 16, 16) >>> output = m(input) >>> print(output.size()) @@ -325,6 +347,38 @@ def __init__( self.reset_parameters() + @classmethod + def from_conv2d(cls, conv2d: nn.Conv2d, num_estimators: int) -> "BatchConv2d": + r"""Create a BatchEnsemble-style Conv2d layer from an existing Conv2d layer. + + Args: + conv2d (nn.Conv2d): The Conv2d layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchConv2d: The converted BatchEnsemble-style Conv2d layer. + + Warning: + All parameters of the original Conv2d layer will be discarded. + + Example: + >>> conv2d = nn.Conv2d(3, 32, kernel_size=3) + >>> be_conv2d = BatchConv2d.from_conv2d(conv2d, num_estimators=3) + """ + return cls( + in_channels=conv2d.in_channels, + out_channels=conv2d.out_channels, + kernel_size=conv2d.kernel_size, + stride=conv2d.stride, + padding=conv2d.padding, + dilation=conv2d.dilation, + groups=conv2d.groups, + bias=conv2d.bias is not None, + num_estimators=num_estimators, + device=conv2d.weight.device, + dtype=conv2d.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -338,50 +392,13 @@ def forward(self, inputs: Tensor) -> Tensor: examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators - r_group = ( - torch.repeat_interleave( - self.r_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.r_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = ( - torch.repeat_interleave( - 
self.s_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.s_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # + r_group = repeat(self.r_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = ( - torch.repeat_interleave( - self.bias, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.bias.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - + bias = repeat(self.bias, "m h -> (m b) h 1 1", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None From 9d4b4a49f1dbed6130b75b393f81b0e94a5265ad Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:04:02 +0100 Subject: [PATCH 14/17] :hammer: `BatchEnsemble` wrapper overhaul - now the input is repeated during training depending on `repeat_training_inputs` argument - with `convert_layers==True` all `nn.Linear` and `nn.Conv2d` layers are replaced by `BatchLinear` and `BatchConv2d` --- .../mnist/configs/lenet_batch_ensemble.yaml | 1 + tests/layers/test_batch.py | 22 +++ tests/models/wrappers/test_batch_ensemble.py | 75 ++++++++-- torch_uncertainty/models/__init__.py | 2 + torch_uncertainty/models/lenet.py | 17 +-- torch_uncertainty/models/wrappers/__init__.py | 1 + .../models/wrappers/batch_ensemble.py | 139 ++++++++++++++++-- 7 files changed, 224 insertions(+), 33 deletions(-) diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index 0d2a94ee..d385b100 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ 
-40,6 +40,7 @@ model: norm: torch.nn.BatchNorm2d groups: 1 dropout_rate: 0 + repeat_training_inputs: true num_classes: 10 loss: CrossEntropyLoss is_ensemble: true diff --git a/tests/layers/test_batch.py b/tests/layers/test_batch.py index bb7ca6e9..75e54485 100644 --- a/tests/layers/test_batch.py +++ b/tests/layers/test_batch.py @@ -33,6 +33,17 @@ def test_linear_one_estimator_no_bias(self, feat_input: torch.Tensor): out = layer(feat_input) assert out.shape == torch.Size([4, 2]) + def test_convert_from_linear(self, feat_input: torch.Tensor): + linear = torch.nn.Linear(6, 3) + layer = BatchLinear.from_linear(linear, num_estimators=2) + assert layer.linear.weight.shape == torch.Size([3, 6]) + assert layer.linear.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(feat_input) + assert out.shape == torch.Size([4, 3]) + class TestBatchConv2d: """Testing the BatchConv2d layer class.""" @@ -47,3 +58,14 @@ def test_conv_two_estimators(self, img_input: torch.Tensor): layer = BatchConv2d(6, 2, num_estimators=2, kernel_size=1) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) + + def test_convert_from_conv2d(self, img_input: torch.Tensor): + conv = torch.nn.Conv2d(6, 3, 1) + layer = BatchConv2d.from_conv2d(conv, num_estimators=2) + assert layer.conv.weight.shape == torch.Size([3, 6, 1, 1]) + assert layer.conv.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(img_input) + assert out.shape == torch.Size([5, 3, 3, 3]) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 8230819e..8c1b675e 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -1,31 +1,82 @@ +import pytest import torch from torch import 
nn +from torch_uncertainty.layers import BatchConv2d, BatchLinear from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble +@pytest.fixture() +def img_input() -> torch.Tensor: + return torch.rand((5, 6, 3, 3)) + + # Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) class _DummyModel(nn.Module): def __init__(self, in_features, out_features): super().__init__() - self.fc = nn.Linear(in_features, out_features) - self.r_group = nn.Parameter(torch.randn(in_features)) - self.s_group = nn.Parameter(torch.randn(out_features)) - self.bias = nn.Parameter(torch.randn(out_features)) + self.conv = nn.Conv2d(in_features, out_features, 3) + self.fc = nn.Linear(out_features, out_features) def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class _DummyBEModel(nn.Module): + def __init__(self, in_features, out_features, num_estimators): + super().__init__() + self.conv = BatchConv2d(in_features, out_features, 3, num_estimators) + self.fc = BatchLinear(out_features, out_features, num_estimators=num_estimators) + + def forward(self, x): + x = self.conv(x) + x = x.flatten(1) return self.fc(x) class TestBatchEnsembleModel: - def test_forward_pass(self): - in_features = 10 - out_features = 5 + def test_convert_layers(self): + in_features = 6 + out_features = 4 num_estimators = 3 + model = _DummyModel(in_features, out_features) - wrapped_model = BatchEnsemble(model, num_estimators) + wrapped_model = BatchEnsemble(model, num_estimators, convert_layers=True) + assert wrapped_model.num_estimators == num_estimators + assert isinstance(wrapped_model.model.conv, BatchConv2d) + assert isinstance(wrapped_model.model.fc, BatchLinear) + + def test_forward_pass(self, img_input): + batch_size = img_input.size(0) + in_features = img_input.size(1) + out_features = 4 + num_estimators = 3 + model = _DummyBEModel(in_features, out_features, num_estimators) + # with repeat_training_inputs=False + 
wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=False) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (img_input.size(0), out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # with repeat_training_inputs=True + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=True) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) - # Test forward pass - x = torch.randn(2, in_features) # Batch size of 2 - logits = wrapped_model(x) - assert logits.shape == (2 * num_estimators, out_features) + def test_errors(self): + with pytest.raises(ValueError): + BatchEnsemble(_DummyBEModel(10, 5, 1), 0) + with pytest.raises(ValueError): + BatchEnsemble(_DummyModel(10, 5), 1) + with pytest.raises(ValueError): + BatchEnsemble(nn.Identity(), 2, convert_layers=True) diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 4b8964bc..a3faf1f4 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -5,9 +5,11 @@ STEP_UPDATE_MODEL, SWA, SWAG, + BatchEnsemble, CheckpointEnsemble, MCDropout, StochasticModel, + batch_ensemble, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index aca916ee..b31ba133 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch import nn -from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear from torch_uncertainty.layers.bayesian import BayesConv2d, 
BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear @@ -129,22 +128,22 @@ def batchensemble_lenet( norm: type[nn.Module] = nn.BatchNorm2d, groups: int = 1, dropout_rate: float = 0.0, + repeat_training_inputs: bool = False, ) -> _LeNet: - model = _lenet( - stochastic=False, + model = lenet( in_channels=in_channels, num_classes=num_classes, - linear_layer=BatchLinear, - conv2d_layer=BatchConv2d, - layer_args={ - "num_estimators": num_estimators, - }, activation=activation, norm=norm, groups=groups, dropout_rate=dropout_rate, ) - return BatchEnsemble(model, num_estimators) + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=True, + ) def packed_lenet( diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index 75f37e66..fb4ff50c 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .batch_ensemble import BatchEnsemble, batch_ensemble from .checkpoint_ensemble import ( CheckpointEnsemble, ) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index e306dcb7..9d6213f1 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -1,6 +1,9 @@ import torch +from einops import repeat from torch import nn +from torch_uncertainty.layers import BatchConv2d, BatchLinear + class BatchEnsemble(nn.Module): """Wrap a BatchEnsemble model to ensure correct batch replication. @@ -12,28 +15,140 @@ class BatchEnsemble(nn.Module): This wrapper automatically **duplicates the input batch** along the first axis, ensuring that each estimator receives the correct data format. 
- **Usage Example:** - ```python - model = lenet(in_channels=1, num_classes=10) - wrapped_model = BatchEnsembleWrapper(model, num_estimators=5) - logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times - ``` - Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. + repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. + If `True`, the input batch is repeated during both training and evaluation. If `False`, + the input batch is repeated only during evaluation. Default is `False`. + convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. + If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their + BatchEnsemble counterparts. Default is `False`. + + Raises: + ValueError: If neither `BatchLinear` nor `BatchConv2d` layers are found in the model at the + end of initialization. + ValueError: If `num_estimators` is less than or equal to 0. + ValueError: If `convert_layers=True` and neither `nn.Linear` nor `nn.Conv2d` layers are + found in the model. + + Warning: + If `convert_layers==True`, the wrapper will attempt to convert all `nn.Linear` and `nn.Conv2d` + layers in the model to their BatchEnsemble counterparts. If the model contains other types of + layers, the conversion won't happen for these layers. If don't have any `nn.Linear` or `nn.Conv2d` + layers in the model, the wrapper will raise an error during conversion. - Note: - This wrapper assumes that the model uses **BatchEnsemble layers** (see `torchensemble.layers.batch_ensemble`). + Warning: + If `repeat_training_inputs==True` and you want to use one of the `torch_uncertainty.routines` + for training, be sure to set `format_batch_fn=RepeatTarget(num_repeats=num_estimators)` when + initializing the routine. + + Example: + >>> model = nn.Sequential( + ... nn.Linear(10, 5), + ... nn.ReLU(), + ... nn.Linear(5, 2) + ... 
) + >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) + >>> model + BatchEnsemble( + (model): Sequential( + (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) + (1): ReLU() + (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) + ) + ) """ - def __init__(self, model: nn.Module, num_estimators: int): + def __init__( + self, + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, + ) -> None: super().__init__() self.model = model self.num_estimators = num_estimators + self.repeat_training_inputs = repeat_training_inputs + + if convert_layers: + self._convert_layers() + + filtered_modules = [ + module + for module in self.model.modules() + if isinstance(module, BatchLinear | BatchConv2d) + ] + _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: """Repeats the input batch and passes it through the model.""" - repeat_shape = [self.num_estimators] + [1] * (x.dim() - 1) - x = x.repeat(repeat_shape) + if not self.training or self.repeat_training_inputs: + x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators) return self.model(x) + + def _convert_layers(self) -> None: + """Converts the model's layers to BatchEnsemble layers.""" + no_valid_layers = True + for name, layer in self.model.named_modules(): + if isinstance(layer, nn.Linear): + setattr( + self.model, + name, + BatchLinear.from_linear(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + elif isinstance(layer, nn.Conv2d): + setattr( + self.model, + name, + BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + if no_valid_layers: + raise ValueError( + "No valid layers found in the model. " + "Please use `nn.Linear` or `nn.Conv2d` layers to apply BatchEnsemble." 
+ ) + + +def _batch_ensemble_checks(filtered_modules, num_estimators): + """Validate that BatchEnsemble layers exist and `num_estimators` is positive.""" + if len(filtered_modules) == 0: + raise ValueError( + "No BatchEnsemble layers found in the model. " + "Please use `BatchLinear` or `BatchConv2d` layers in your model " + "or set `convert_layers=True` when initializing the wrapper." + ) + if num_estimators <= 0: + raise ValueError("`num_estimators` must be greater than 0.") + + +def batch_ensemble( + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, +) -> BatchEnsemble: + """BatchEnsemble wrapper for a model. + + Args: + model (nn.Module): model to wrap + num_estimators (int): number of ensemble members + repeat_training_inputs (bool, optional): whether to repeat the input batch during training. + If `True`, the input batch is repeated during both training and evaluation. If `False`, + the input batch is repeated only during evaluation. Default is `False`. + convert_layers (bool, optional): whether to convert the model's layers to BatchEnsemble layers. + If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their + BatchEnsemble counterparts. Default is `False`.
+ + Returns: + BatchEnsemble: BatchEnsemble wrapper for the model + """ + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=convert_layers, + ) From 8b29026dd21bbfeac5efc3403ca715e72e610851 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:06:44 +0100 Subject: [PATCH 15/17] :books: Add BatchEnsemble wrapper utilities in the API Reference --- docs/source/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 16f1f1c2..0165ad49 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -201,6 +201,7 @@ Functions :toctree: generated/ :nosignatures: + batch_ensemble deep_ensembles mc_dropout @@ -212,6 +213,7 @@ Classes :nosignatures: :template: class.rst + BatchEnsemble CheckpointEnsemble EMA MCDropout From f2ae628ba8481910291368917f258edd29c63fdb Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:12:42 +0100 Subject: [PATCH 16/17] :book: `BatchEnsemble` docstring update --- torch_uncertainty/models/wrappers/batch_ensemble.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 9d6213f1..d67230e5 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -83,7 +83,9 @@ def __init__( _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Repeats the input batch and passes it through the model.""" + """Repeat the input if `self.training == False` or `repeat_training_inputs==True` and pass + it through the model. + """ if not self.training or self.repeat_training_inputs: x = repeat(x, "b ... 
-> (m b) ...", m=self.num_estimators) return self.model(x) From c737d07b24243191cfabccdba0b0ed41eee4043f Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:17:22 +0100 Subject: [PATCH 17/17] :book: `BatchEnsemble` docstring update --- .../models/wrappers/batch_ensemble.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index d67230e5..1e9bc46f 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -26,21 +26,21 @@ class BatchEnsemble(nn.Module): BatchEnsemble counterparts. Default is `False`. Raises: - ValueError: If neither `BatchLinear` nor `BatchConv2d` layers are found in the model at the + ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the end of initialization. - ValueError: If `num_estimators` is less than or equal to 0. - ValueError: If `convert_layers=True` and neither `nn.Linear` nor `nn.Conv2d` layers are + ValueError: If ``num_estimators`` is less than or equal to ``0``. + ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are found in the model. Warning: - If `convert_layers==True`, the wrapper will attempt to convert all `nn.Linear` and `nn.Conv2d` + If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` layers in the model to their BatchEnsemble counterparts. If the model contains other types of - layers, the conversion won't happen for these layers. If don't have any `nn.Linear` or `nn.Conv2d` + layers, the conversion won't happen for these layers. If you don't have any ``nn.Linear`` or ``nn.Conv2d`` layers in the model, the wrapper will raise an error during conversion.
Warning: - If `repeat_training_inputs==True` and you want to use one of the `torch_uncertainty.routines` - for training, be sure to set `format_batch_fn=RepeatTarget(num_repeats=num_estimators)` when + If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` + for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when initializing the routine. Example: @@ -83,8 +83,8 @@ def __init__( _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Repeat the input if `self.training == False` or `repeat_training_inputs==True` and pass - it through the model. + """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and + pass it through the model. """ if not self.training or self.repeat_training_inputs: x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)