diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py
index bb726902..4ec27105 100644
--- a/auto_tutorials_source/tutorial_mc_batch_norm.py
+++ b/auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -26,6 +26,7 @@
 from pathlib import Path

 from torch import nn
+from torch.utils.data import DataLoader

 from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
@@ -84,15 +85,17 @@
 # We can now wrap the model in a MCBatchNorm to add stochasticity to the
 # predictions. We specify that the BatchNorm layers are to be converted to
 # MCBatchNorm layers, and that we want to use 8 stochastic estimators.
-# The amount of stochasticity is controlled by the ``mc_batch_size`` argument.
-# The larger the ``mc_batch_size``, the more stochastic the predictions will be.
-# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here
+# The amount of stochasticity is controlled by the ``batch_size`` parameter
+# of the DataLoader used to train the model.
+# The larger the ``batch_size``, the more stochastic the predictions will be.
+# The authors suggest 32 as a good value for ``batch_size`` but we use 16 here
 # to highlight the effect of stochasticity on the predictions.
 routine.model = MCBatchNorm(
-    routine.model, num_estimators=8, convert=True, mc_batch_size=16
+    routine.model, num_estimators=8, convert=True
 )
-routine.model.fit(datamodule.train)
+mc_batch_norm_dl = DataLoader(datamodule.train, batch_size=16, shuffle=True)
+routine.model.fit(dataloader=mc_batch_norm_dl)
 routine = routine.eval()  # To avoid prints

 # %%
diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py
index ceaaa036..a04da64b 100644
--- a/auto_tutorials_source/tutorial_scaler.py
+++ b/auto_tutorials_source/tutorial_scaler.py
@@ -84,6 +84,7 @@
     dataset, [1000, 1000, len(dataset) - 2000]
 )
 test_dataloader = DataLoader(test_dataset, batch_size=32)
+calibration_dataloader = DataLoader(cal_dataset, batch_size=32)

 # Initialize the ECE
 ece = CalibrationError(task="multiclass", num_classes=100)
@@ -114,7 +115,7 @@

 # Fit the scaler on the calibration dataset
 scaled_model = TemperatureScaler(model=model)
-scaled_model.fit(calibration_set=cal_dataset)
+scaled_model.fit(dataloader=calibration_dataloader)

 # %%
 # 6. Iterating Again to Compute the Improved ECE
diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py
index 8b6249ea..f2fdda7b 100644
--- a/tests/post_processing/test_laplace.py
+++ b/tests/post_processing/test_laplace.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torch.utils.data import TensorDataset
+from torch.utils.data import DataLoader, TensorDataset

 from tests._dummies.model import dummy_model
 from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing
@@ -20,12 +20,12 @@ class TestLaplace:
     """Testing the LaplaceApprox class."""

     def test_training(self):
-        ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10))
+        dl = DataLoader(TensorDataset(torch.randn(16, 1), torch.randn(16, 10)), batch_size=5)
         la = LaplaceApprox(
             task="classification",
             model=dummy_model(1, 10),
         )
-        la.fit(ds)
+        la.fit(dl)
         la(torch.randn(1, 1))
         la = LaplaceApprox(task="classification")
         la.set_model(dummy_model(1, 10))
diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py
index bbe987ca..b2277d94 100644
--- a/tests/post_processing/test_mc_batch_norm.py
+++ b/tests/post_processing/test_mc_batch_norm.py
@@ -4,6 +4,7 @@
 import torch
 import torchvision.transforms as T
 from torch import nn
+from torch.utils.data import DataLoader

 from tests._dummies.dataset import DummyClassificationDataset
 from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
@@ -17,14 +18,13 @@ class TestMCBatchNorm:
     def test_main(self):
         """Test initialization."""
         mc_model = lenet(1, 1, norm=partial(MCBatchNorm2d, num_estimators=2))
-        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False, mc_batch_size=1)
+        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False)
         model = lenet(1, 1, norm=nn.BatchNorm2d)
         stoch_model = MCBatchNorm(
             nn.Sequential(model),
             num_estimators=2,
             convert=True,
-            mc_batch_size=1,
         )
         dataset = DummyClassificationDataset(
             "./",
@@ -34,13 +34,14 @@ def test_main(self):
             num_images=2,
             transform=T.ToTensor(),
         )
-        stoch_model.fit(dataset=dataset)
+        dl = DataLoader(dataset, batch_size=1, shuffle=True)
+        stoch_model.fit(dataloader=dl)
         stoch_model.train()
         stoch_model(torch.randn(1, 1, 20, 20))
         stoch_model.eval()
         stoch_model(torch.randn(1, 1, 20, 20))

-        stoch_model = MCBatchNorm(num_estimators=2, convert=False, mc_batch_size=1)
+        stoch_model = MCBatchNorm(num_estimators=2, convert=False)
         stoch_model.set_model(mc_model)

     def test_errors(self):
@@ -48,14 +49,12 @@
         model = nn.Identity()
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=0, convert=True)
-        with pytest.raises(ValueError, match="mc_batch_size must be a positive integer"):
-            MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1)
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=1, convert=False)
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=1, convert=True)
         model = lenet(1, 1, norm=nn.BatchNorm2d)
-        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True, mc_batch_size=1)
+        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True)
         dataset = DummyClassificationDataset(
             "./",
             num_channels=1,
@@ -64,9 +63,10 @@ def test_errors(self):
             num_images=2,
             transform=T.ToTensor(),
         )
+        dl = DataLoader(dataset, batch_size=2, shuffle=True)
         stoch_model.eval()
         with pytest.raises(RuntimeError):
             stoch_model(torch.randn(1, 1, 20, 20))
         with pytest.raises(ValueError):
-            stoch_model.fit(dataset=dataset)
+            stoch_model.fit(dataloader=dl)
diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py
index fabe77f3..b499efc5 100644
--- a/tests/post_processing/test_scalers.py
+++ b/tests/post_processing/test_scalers.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from torch import nn, softmax
+from torch.utils.data import DataLoader

 from torch_uncertainty.post_processing import (
     MatrixScaler,
@@ -26,10 +27,11 @@ def test_fit_biased(self):
         labels = torch.as_tensor([0.5, 0.5]).repeat(10, 1)

         calibration_set = list(zip(inputs, labels, strict=True))
+        dl = DataLoader(calibration_set, batch_size=10)

         scaler = TemperatureScaler(model=nn.Identity(), init_val=2, lr=1, max_iter=10)
         assert scaler.temperature[0] == 2.0
-        scaler.fit(calibration_set)
+        scaler.fit(dl)
         assert scaler.temperature[0] > 10  # best is +inf
         assert (
             torch.sum(
@@ -39,7 +41,7 @@ def test_fit_biased(self):
             ** 2
             < 0.001
         )
-        scaler.fit_predict(calibration_set, progress=False)
+        scaler.fit_predict(dl, progress=False)

     def test_errors(self):
         with pytest.raises(ValueError):
diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py
index 0ec24375..b79e7f36 100644
--- a/torch_uncertainty/post_processing/abnn.py
+++ b/torch_uncertainty/post_processing/abnn.py
@@ -2,7 +2,7 @@

 import torch
 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d
 from torch_uncertainty.models import deep_ensembles
@@ -25,7 +25,6 @@ def __init__(
         device: torch.device | str,
         max_epochs: int = 5,
         use_original_model: bool = True,
-        batch_size: int = 128,
         precision: str = "32",
         model: nn.Module | None = None,
     ):
@@ -45,8 +44,6 @@ def __init__(
                 to 5.
             use_original_model (bool, optional): Use original model during evaluation.
                 Defaults to True.
-            batch_size (int, optional): Batch size for the training of ABNN.
-                Defaults to 128.
             precision (str, optional): Machine precision for training & eval.
                 Defaults to "32".
             model (nn.Module | None, optional): Model to use. Defaults to None.
@@ -63,7 +60,6 @@ def __init__(
             num_models=num_models,
             num_samples=num_samples,
             base_lr=base_lr,
-            batch_size=batch_size,
         )
         self.num_classes = num_classes
         self.alpha = alpha
@@ -74,7 +70,6 @@ def __init__(
         self.use_original_model = use_original_model

         self.max_epochs = max_epochs
-        self.batch_size = batch_size
         self.precision = precision
         self.device = device

@@ -88,10 +83,9 @@ def __init__(
             weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1
             self.weights.append(weight)

-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         if self.model is None:
             raise ValueError("Model must be set before fitting.")
-        dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

         source_model = copy.deepcopy(self.model)
         _replace_bn_layers(source_model, self.alpha)
@@ -119,7 +113,7 @@ def fit(self, dataset: Dataset) -> None:
             logger=None,
             enable_model_summary=False,
         )
-        trainer.fit(model=baseline, train_dataloaders=dl)
+        trainer.fit(model=baseline, train_dataloaders=dataloader)

         final_models = (
             [copy.deepcopy(source_model) for _ in range(self.num_samples)]
diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py
index 5b4ae9a6..7afe050c 100644
--- a/torch_uncertainty/post_processing/abstract.py
+++ b/torch_uncertainty/post_processing/abstract.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod

 from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader


 class PostProcessing(ABC, nn.Module):
@@ -14,7 +14,7 @@ def set_model(self, model: nn.Module) -> None:
         self.model = model

     @abstractmethod
-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         pass

     @abstractmethod
diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py
index ed5d8e36..54c5c703 100644
--- a/torch_uncertainty/post_processing/calibration/scaler.py
+++ b/torch_uncertainty/post_processing/calibration/scaler.py
@@ -3,7 +3,7 @@

 import torch
 from torch import Tensor, nn, optim
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 from tqdm import tqdm

 from torch_uncertainty.post_processing import PostProcessing
@@ -47,20 +47,14 @@ def __init__(

     def fit(
         self,
-        calibration_set: Dataset,
-        batch_size: int = 32,
-        shuffle: bool = False,
-        drop_last: bool = False,
+        dataloader: DataLoader,
         save_logits: bool = False,
         progress: bool = True,
     ) -> None:
         """Fit the temperature parameters to the calibration data.

         Args:
-            calibration_set (Dataset): Calibration dataset.
-            batch_size (int, optional): Batch size for the calibration dataset. Defaults to 32.
-            shuffle (bool, optional): Whether to shuffle the calibration dataset. Defaults to False.
-            drop_last (bool, optional): Whether to drop the last batch if it's smaller than batch_size. Defaults to False.
+            dataloader (DataLoader): Dataloader with the calibration data.
             save_logits (bool, optional): Whether to save the logits and labels. Defaults to False.
             progress (bool, optional): Whether to show a progress bar.
@@ -73,9 +67,7 @@ def fit(
         all_logits = []
         all_labels = []

-        calibration_dl = DataLoader(
-            calibration_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
-        )
+        calibration_dl = dataloader
         with torch.no_grad():
             for inputs, labels in tqdm(calibration_dl, disable=not progress):
                 logits = self.model(inputs.to(self.device))
@@ -119,10 +111,11 @@ def _scale(self, logits: Tensor) -> Tensor:

     def fit_predict(
         self,
-        calibration_set: Dataset,
+        # calibration_set: Dataset,
+        dataloader: DataLoader,
         progress: bool = True,
     ) -> Tensor:
-        self.fit(calibration_set, save_logits=True, progress=progress)
+        self.fit(dataloader, save_logits=True, progress=progress)
         return self(self.logits)

     @property
diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py
index 7a918bd5..eda08d87 100644
--- a/torch_uncertainty/post_processing/laplace.py
+++ b/torch_uncertainty/post_processing/laplace.py
@@ -2,7 +2,7 @@
 from typing import Literal

 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from .abstract import PostProcessing
@@ -23,7 +23,6 @@ def __init__(
         hessian_struct="kron",
         pred_type: Literal["glm", "nn"] = "glm",
         link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit",
-        batch_size: int = 256,
         optimize_prior_precision: bool = True,
     ) -> None:
         """Laplace approximation for uncertainty estimation.
@@ -42,8 +41,6 @@ def __init__(
             link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional):
                 how to approximate the classification link function for the `'glm'`.
                 See the Laplace library for more details. Defaults to "probit".
-            batch_size (int, optional): batch size for the Laplace approximation.
-                Defaults to 256.
             optimize_prior_precision (bool, optional): whether to optimize the prior
                 precision. Defaults to True.

@@ -63,7 +60,6 @@ def __init__(
         self.task = task
         self.weight_subset = weight_subset
         self.hessian_struct = hessian_struct
-        self.batch_size = batch_size
         self.optimize_prior_precision = optimize_prior_precision

         if model is not None:
@@ -78,9 +74,8 @@ def set_model(self, model: nn.Module) -> None:
             hessian_structure=self.hessian_struct,
         )

-    def fit(self, dataset: Dataset) -> None:
-        dl = DataLoader(dataset, batch_size=self.batch_size)
-        self.la.fit(train_loader=dl)
+    def fit(self, dataloader: DataLoader) -> None:
+        self.la.fit(train_loader=dataloader)
         if self.optimize_prior_precision:
             self.la.optimize_prior_precision(method="marglik")
diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index 33dbf35d..d5d467b6 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -3,7 +3,7 @@

 import torch
 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
 from torch_uncertainty.post_processing import PostProcessing
@@ -19,7 +19,6 @@ def __init__(
         model: nn.Module | None = None,
         num_estimators: int = 16,
         convert: bool = True,
-        mc_batch_size: int = 32,
         device: Literal["cpu", "cuda"] | torch.device | None = None,
     ) -> None:
         """Monte Carlo Batch Normalization wrapper.
@@ -28,7 +27,6 @@
             model (nn.Module): model to be converted.
             num_estimators (int): number of estimators.
             convert (bool): whether to convert the model.
-            mc_batch_size (int, optional): Monte Carlo batch size. Defaults to 32.
             device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
                 Defaults to None.
@@ -40,7 +38,6 @@ def __init__(
             batch normalized deep networks. In ICML 2018.
         """
         super().__init__()
-        self.mc_batch_size = mc_batch_size
         self.convert = convert
         self.num_estimators = num_estimators
         self.device = device
@@ -49,7 +46,7 @@ def __init__(
             self._setup_model(model)

     def _setup_model(self, model):
-        _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert)
+        _mcbn_checks(model, self.num_estimators, self.convert)
         self.model = deepcopy(model)  # TODO: Is it necessary?
         self.model = self.model.eval()
         if self.convert:
@@ -61,22 +58,28 @@ def set_model(self, model: nn.Module) -> None:
         self.model = model
         self._setup_model(model)

-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         """Fit the model on the dataset.

         Args:
-            dataset (Dataset): dataset to be used for fitting.
+            dataloader (DataLoader): DataLoader with the training dataset.

         Note:
            This method is used to populate the MC BatchNorm layers.
            Use the training dataset.
+
+        Warning:
+            The ``batch_size`` of the DataLoader should be carefully chosen as it
+            will have an impact on the statistics of the MC BatchNorm layers.
+
+        Raises:
+            ValueError: If there are fewer batches than the number of estimators.
         """
-        self.dl = DataLoader(dataset, batch_size=self.mc_batch_size, shuffle=True)
         self.counter = 0
         self.reset_counters()
         self.set_accumulate(True)
         self.eval()
-        for x, _ in self.dl:
+        for x, _ in dataloader:
             self.model(x.to(self.device))
             self.raise_counters()
             if self.counter == self.num_estimators:
@@ -162,10 +165,8 @@ def has_mcbn(model: nn.Module) -> bool:
     return any(isinstance(module, MCBatchNorm2d) for module in model.modules())


-def _mcbn_checks(model, num_estimators, mc_batch_size, convert):
+def _mcbn_checks(model, num_estimators, convert):
     if num_estimators < 1 or not isinstance(num_estimators, int):
         raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.")
-    if mc_batch_size < 1 or not isinstance(mc_batch_size, int):
-        raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.")
     if not convert and not has_mcbn(model):
         raise ValueError("model does not contain any MCBatchNorm2d nor is not to be converted.")
diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index e5469b5c..0d0faeb9 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -381,13 +381,13 @@ def on_test_start(self) -> None:
         the storage lists for logit plotting and update the batchnorms if needed.
         """
         if self.post_processing is not None:
-            calibration_dataset = (
-                self.trainer.datamodule.val_dataloader().dataset
+            calibration_dataloader = (
+                self.trainer.datamodule.val_dataloader()
                 if self.calibration_set == "val"
-                else self.trainer.datamodule.test_dataloader()[0].dataset
+                else self.trainer.datamodule.test_dataloader()[0]
             )
             with torch.inference_mode(False):
                 self.post_processing.fit(calibration_dataloader)

         if self.eval_ood and self.log_plots and isinstance(self.logger, Logger):
             self.id_logit_storage = []
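
Taken together, the patch moves batching out of the post-processing classes: every fit() now consumes a user-built DataLoader instead of a Dataset plus an internal batch size (mc_batch_size / batch_size). A minimal sketch of the new calling convention follows; the toy model and tensors are placeholders and are not part of the diff, only the fit(dataloader=...) calls mirror the API change above.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from torch_uncertainty.post_processing import MCBatchNorm, TemperatureScaler

# Placeholder classifier and calibration data, assumed purely for illustration.
classifier = nn.Linear(8, 4)
calibration_set = TensorDataset(torch.randn(64, 8), torch.randint(0, 4, (64,)))

# Temperature scaling: the caller now owns batching and passes the DataLoader.
scaler = TemperatureScaler(model=classifier)
scaler.fit(dataloader=DataLoader(calibration_set, batch_size=32))

# MC BatchNorm: the DataLoader's batch_size takes over the role of the removed
# mc_batch_size argument, and each batch populates one stochastic estimator,
# so the loader must yield at least num_estimators batches.
conv_net = nn.Sequential(
    nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.Flatten(), nn.Linear(4 * 6 * 6, 4)
)
images = TensorDataset(torch.randn(32, 1, 8, 8), torch.randint(0, 4, (32,)))
mcbn = MCBatchNorm(conv_net, num_estimators=8, convert=True)
mcbn.fit(dataloader=DataLoader(images, batch_size=4, shuffle=True))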