diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index c404a106..a870c00e 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -65,7 +65,7 @@ jobs: if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - python3 -m pip install .[image,dev,docs] + python3 -m pip install .[all] - name: Check style & format if: steps.changed-files-specific.outputs.only_changed != 'true' diff --git a/README.md b/README.md index 00d0bc2a..a66c6b85 100644 --- a/README.md +++ b/README.md @@ -18,21 +18,21 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant :books: Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). :books: -TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **WACV 2024** and **ECCV 2024**. +TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **[WACV](https://wacv2024.thecvf.com/) 2024**, **[HAICON](https://haicon24.de/) 2024** and **[ECCV](https://eccv.ecva.net/) 2024**. --- This package provides a multi-level API, including: -- easy-to-use ⚡️ lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. +- easy-to-use :zap: lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (work in progress 🚧). +- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR (:construction: work in progress :construction:). - **layers**, **models**, **metrics**, & **losses** available for use in your networks - scikit-learn style post-processing methods such as Temperature Scaling. Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. -## ⚙️ Installation +## :gear: Installation TorchUncertainty requires Python 3.10 or greater. Install the desired PyTorch version in your environment. Then, install the package from PyPI: @@ -51,7 +51,6 @@ We make a quickstart available at [torch-uncertainty.github.io/quickstart](https TorchUncertainty currently supports **classification**, **probabilistic** and pointwise **regression**, **segmentation** and **pixelwise regression** (such as monocular depth estimation). It includes the official codes of the following papers: -- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) - *LP-BNN: Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification* - [IEEE TPAMI](https://arxiv.org/abs/2012.02818) - *Packed-Ensembles for Efficient Uncertainty Estimation* - [ICLR 2023](https://arxiv.org/abs/2210.09184) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) - *MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks* - [BMVC 2022](https://arxiv.org/abs/2203.01437) @@ -60,17 +59,16 @@ We also provide the following methods: ### Baselines -To date, the following deep learning baselines have been implemented: +To date, the following deep learning baselines have been implemented. **Click on the methods for tutorials**: -- Deep Ensembles -- MC-Dropout - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) -- BatchEnsemble -- Masksembles -- MIMO -- Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) -- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) +- [Deep Ensembles](https://torch-uncertainty.github.io/auto_tutorials/tutorial_from_de_to_pe.html), BatchEnsemble, Masksembles, & MIMO +- [MC-Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) +- [Packed-Ensembles](https://torch-uncertainty.github.io/auto_tutorials/tutorial_from_de_to_pe.html) (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) +- [Variational Bayesian Neural Networks](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) +- Checkpoint Ensembles & Snapshot Ensembles +- Stochastic Weight Averaging & Stochastic Weight Averaging Gaussian - Regression with Beta Gaussian NLL Loss -- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) +- [Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) & [Regression](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html) ### Augmentation methods @@ -82,16 +80,18 @@ The following data augmentation methods have been implemented: To date, the following post-processing methods have been implemented: -- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html) -- Monte Carlo Batch Normalization - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html) +- [Temperature](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html), Vector, & Matrix scaling +- [Monte Carlo Batch Normalization](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html) +- Laplace approximation using the [Laplace library](https://github.com/aleximmer/Laplace) ## Tutorials -Our documentation contains the following tutorials: +Check out our tutorials at [torch-uncertainty.github.io/auto_tutorials](https://torch-uncertainty.github.io/auto_tutorials/index.html). + +## :telescope: Projects using TorchUncertainty + +The following projects use TorchUncertainty: + +- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) -- [From a Standard Classifier to a Packed-Ensemble](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) -- [Training a Bayesian Neural Network in 3 minutes](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html) -- [Improve Top-label Calibration with Temperature Scaling](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html) -- [Deep Evidential Regression on a Toy Example](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html) -- [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) -- [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) +**If you are using TorchUncertainty in your project, please let us know, we will add your project to this list!** diff --git a/auto_tutorials_source/tutorial_bayesian.py b/auto_tutorials_source/tutorial_bayesian.py index 04f1202e..68628d2a 100644 --- a/auto_tutorials_source/tutorial_bayesian.py +++ b/auto_tutorials_source/tutorial_bayesian.py @@ -55,12 +55,12 @@ # We will use the Adam optimizer with the default learning rate of 0.001. -def optim_lenet(model: nn.Module) -> dict: +def optim_lenet(model: nn.Module): optimizer = optim.Adam( model.parameters(), lr=1e-3, ) - return {"optimizer": optimizer} + return optimizer # %% @@ -75,7 +75,7 @@ def optim_lenet(model: nn.Module) -> dict: trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False) # model @@ -105,6 +105,7 @@ def optim_lenet(model: nn.Module) -> dict: num_classes=datamodule.num_classes, loss=loss, optim_recipe=optim_lenet(model), + is_ensemble=True ) # %% @@ -125,8 +126,10 @@ def optim_lenet(model: nn.Module) -> dict: # 6. Testing the Model # ~~~~~~~~~~~~~~~~~~~~ # -# Now that the model is trained, let's test it on MNIST - +# Now that the model is trained, let's test it on MNIST. +# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble +# and to the batch. As for TorchUncertainty 0.2.0, the ensemble dimension is merged with the batch dimension +# in this order (num_estimator x batch, classes). import matplotlib.pyplot as plt import numpy as np import torch @@ -148,14 +151,23 @@ def imshow(img): imshow(torchvision.utils.make_grid(images[:4, ...])) print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -logits = model(images) +# Put the model in eval mode to use several samples +model = model.eval() +logits = model(images).reshape(16, 128, 10) # num_estimators, batch_size, num_classes + +# We apply the softmax on the classes and average over the estimators probs = torch.nn.functional.softmax(logits, dim=-1) +avg_probs = probs.mean(dim=0) +var_probs = probs.std(dim=0) -_, predicted = torch.max(probs, 1) +_, predicted = torch.max(avg_probs, 1) print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4))) - +print("Std. dev. of the scores over the posterior samples", " ".join(f"{var_probs[j][predicted[j]]:.3}" for j in range(4))) # %% +# Here, we show the variance of the top prediction. This is a non-standard but intuitive way to show the diversity of the predictions +# of the ensemble. Ideally, the variance should be high when the average top prediction is incorrect. +# # References # ---------- # diff --git a/auto_tutorials_source/tutorial_corruptions.py b/auto_tutorials_source/tutorial_corruption.py similarity index 95% rename from auto_tutorials_source/tutorial_corruptions.py rename to auto_tutorials_source/tutorial_corruption.py index d20e4f19..9e4e7a10 100644 --- a/auto_tutorials_source/tutorial_corruptions.py +++ b/auto_tutorials_source/tutorial_corruption.py @@ -1,6 +1,6 @@ """ -Image Corruptions -================= +Corrupting Images with TorchUncertainty to Benchmark Robustness +=============================================================== This tutorial shows the impact of the different corruptions available in the TorchUncertainty library. These corruptions were first proposed in the paper diff --git a/auto_tutorials_source/tutorial_der_cubic.py b/auto_tutorials_source/tutorial_der_cubic.py index b77a0a4d..96d72375 100644 --- a/auto_tutorials_source/tutorial_der_cubic.py +++ b/auto_tutorials_source/tutorial_der_cubic.py @@ -29,7 +29,6 @@ We also need to define an optimizer using torch.optim and the neural network utils within torch.nn. """ -# %% import torch from lightning.pytorch import Trainer from lightning import LightningDataModule @@ -49,15 +48,13 @@ def optim_regression( model: nn.Module, learning_rate: float = 5e-4, -) -> dict: +): optimizer = optim.Adam( model.parameters(), lr=learning_rate, weight_decay=0, ) - return { - "optimizer": optimizer, - } + return optimizer # %% @@ -69,7 +66,7 @@ def optim_regression( # Please note that this MLP finishes with a NormalInverseGammaLayer that interpret the outputs of the model # as the parameters of a Normal Inverse Gamma distribution. -trainer = Trainer(accelerator="cpu", max_epochs=50)#, enable_progress_bar=False) +trainer = Trainer(accelerator="cpu", max_epochs=50) #, enable_progress_bar=False) # dataset train_ds = Cubic(num_samples=1000) diff --git a/auto_tutorials_source/tutorial_from_de_to_pe.py b/auto_tutorials_source/tutorial_from_de_to_pe.py new file mode 100644 index 00000000..2290a024 --- /dev/null +++ b/auto_tutorials_source/tutorial_from_de_to_pe.py @@ -0,0 +1,416 @@ +"""Improved Ensemble parameter-efficiency with Packed-Ensembles +============================================================ + +*This tutorial is adapted from a notebook part of a lecture given at the [Helmholtz AI Conference](https://haicon24.de/) by Sebastian Starke, Peter Steinbach, Gianni Franchi, and Olivier Laurent.* + +In this notebook will work on the MNIST dataset that was introduced by Corinna Cortes, Christopher J.C. Burges, and later modified by Yann LeCun in the foundational paper: + +- [Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE.](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) + +The MNIST dataset consists of 70 000 images of handwritten digits from 0 to 9. The images are grayscale and 28x28-pixel sized. The task is to classify the images into their respective digits. The dataset can be automatically downloaded using the `torchvision` library. + +In this notebook, we will train a model and an ensemble on this task and evaluate their performance. The performance will consist in the following metrics: +- Accuracy: the proportion of correctly classified images, +- Brier score: a measure of the quality of the predicted probabilities, +- Calibration error: a measure of the calibration of the predicted probabilities, +- Negative Log-Likelihood: the value of the loss on the test set. + +Throughout this notebook, we abstract the training and evaluation process using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) +and [TorchUncertainty](https://torch-uncertainty.github.io/). + +Similarly to keras for tensorflow, PyTorch Lightning is a high-level interface for PyTorch that simplifies the training and evaluation process using a Trainer. +TorchUncertainty is partly built on top of PyTorch Lightning and provides tools to train and evaluate models with uncertainty quantification. + +TorchUncertainty includes datamodules that handle the data loading and preprocessing. We don't use them here for tutorial purposes. +""" +# 1. Download, instantiate and visualize the datasets +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The dataset is automatically downloaded using torchvision. We then visualize a few images to see a bit what we are working with. + +# Create the transforms for the images +import torch +import torchvision.transforms as T + +# We set the number of epochs to some low value for the sake of time +max_epochs = 2 + +train_transform = T.Compose( + [ + T.ToTensor(), + # We perform random cropping as data augmentation + T.RandomCrop(28, padding=4), + # As for the MNIST1d dataset, we normalize the data + T.Normalize((0.1307,), (0.3081,)), + ] +) +test_transform = T.Compose( + [ + T.Grayscale(num_output_channels=1), + T.ToTensor(), + T.CenterCrop(28), + T.Normalize((0.1307,), (0.3081,)), + ] +) + +# Download and instantiate the dataset +from torch.utils.data import Subset +from torchvision.datasets import MNIST, FashionMNIST + +train_data = MNIST( + root="./data/", download=True, train=True, transform=train_transform +) +test_data = MNIST(root="./data/", train=False, transform=test_transform) +# We only take the first 10k images to have the same number of samples as the test set using torch Subsets +ood_data = Subset( + FashionMNIST(root="./data/", download=True, transform=test_transform), + indices=range(10000), +) + +# Create the corresponding dataloaders +from torch.utils.data import DataLoader + +train_dl = DataLoader(train_data, batch_size=32, shuffle=True) +test_dl = DataLoader(test_data, batch_size=32, shuffle=False) +ood_dl = DataLoader(ood_data, batch_size=32, shuffle=False) + +# %% +# You could replace all this cell by simply loading the MNIST datamodule from TorchUncertainty. +# Now, let's visualize a few images from the dataset. For this task, we use the viz_data dataset that applies no transformation to the images. + +# Datasets without transformation to visualize the unchanged data +viz_data = MNIST(root="./data/", train=False) +ood_viz_data = FashionMNIST(root="./data/", download=True) + +print("In distribution data:") +viz_data[0][0] +# %% +print("Out of distribution data:") +ood_viz_data[0][0] + +# %% +# 2. Create & train the model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We will create a simple convolutional neural network (CNN): the LeNet model (also introduced by LeCun). +import torch.nn as nn +import torch.nn.functional as F + + +class LeNet(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + ) -> None: + super().__init__() + self.conv1 = nn.Conv2d(in_channels, 6, (5, 5)) + self.conv2 = nn.Conv2d(6, 16, (5, 5)) + self.pooling = nn.AdaptiveAvgPool2d((4, 4)) + self.fc1 = nn.Linear(256, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = torch.flatten(out, 1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + return self.fc3(out) # No softmax in the model! + + +# Instantiate the model, the images are in grayscale so the number of channels is 1 +model = LeNet(in_channels=1, num_classes=10) + +# %% +# We now need to define the optimization recipe: +# - the optimizer, here the standard stochastic gradient descent (SGD) with a learning rate of 0.05 +# - the scheduler, here cosine annealing. + + +def optim_recipe(model, lr_mult: float = 1.0): + optimizer = torch.optim.SGD(model.parameters(), lr=0.05 * lr_mult) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + return {"optimizer": optimizer, "scheduler": scheduler} + + +# %% +# To train the model, we use [TorchUncertainty](https://torch-uncertainty.github.io/), a library that we have developed to ease +# the training and evaluation of models with uncertainty. +# +# **Note:** To train supervised classification models we most often use the cross-entropy loss. +# With weight-decay, minimizing this loss amounts to finding a Maximum a posteriori (MAP) estimate of the model parameters. +# This means that the model is trained to predict the most likely class for each input. + + +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer + +# Create the trainer that will handle the training +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +# The routine is a wrapper of the model that contains the training logic with the metrics, etc +routine = ClassificationRoutine( + num_classes=10, + model=model, + loss=nn.CrossEntropyLoss(), + optim_recipe=optim_recipe(model), + eval_ood=True, +) + +# In practice, avoid performing the validation on the test set (if you do model selection) +trainer.fit(routine, train_dataloaders=train_dl, val_dataloaders=test_dl) + +# %% +# Evaluate the trained model on the test set - pay attention to the cls/Acc metric +perf = trainer.test(routine, dataloaders=[test_dl, ood_dl]) + +# %% +# This table provides a lot of information: +# +# **OOD Detection: Binary Classification MNIST vs. FashionMNIST** +# - AUPR/AUROC/FPR95: Measures the quality of the OOD detection. The higher the better for AUPR and AUROC, the lower the better for FPR95. +# +# **Calibration: Reliability of the Predictions** +# - ECE: Expected Calibration Error. The lower the better. +# - aECE: Adaptive Expected Calibration Error. The lower the better. (~More precise version of the ECE) +# +# **Classification Performance** +# - Accuracy: The ratio of correctly classified images. The higher the better. +# - Brier: The quality of the predicted probabilities (Mean Squared Error of the predictions vs. ground-truth). The lower the better. +# - Negative Log-Likelihood: The value of the loss on the test set. The lower the better. +# +# **Selective Classification & Grouping Loss** +# - We talk about these points later in the "To go further" section. +# +# 3. Training an ensemble of models with TorchUncertainty +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# You have two options here, you can either train the ensemble directly if you have enough memory, +# otherwise, you can train independent models and do the ensembling during the evaluation (sometimes called inference). +# +# In this case, we will do it sequentially. In this tutorial, you have the choice between training multiple models, +# which will take time if you have no GPU, or downloading the pre-trained models that we have prepared for you. +# +# Training the ensemble +# +# To train the ensemble, you will have to use the "deep_ensembles" function from TorchUncertainty, which will +# replicate and change the initialization of your networks to ensure diversity. + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.transforms import RepeatTarget + +# Create the ensemble model +ensemble = deep_ensembles( + LeNet(in_channels=1, num_classes=10), + num_estimators=2, + task="classification", + reset_model_parameters=True, +) + +trainer = TUTrainer(accelerator="cpu", max_epochs=1) +ens_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=ensemble, + loss=nn.CrossEntropyLoss(), # The loss for the training + format_batch_fn=RepeatTarget( + 2 + ), # How to handle the targets when comparing the predictions + optim_recipe=optim_recipe( + ensemble, 2.0 + ), # The optimization scheme with the optimizer and the scheduler as a dictionnary + eval_ood=True, # We want to evaluate the OOD-related metrics +) +trainer.fit(ens_routine, train_dataloaders=train_dl, val_dataloaders=test_dl) +ens_perf = trainer.test(ens_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# The results are not comparable since we only trained the ensemble for one epoch to reduce GitHub's cpu usage. +# Feel free to run the notebook on your machine for a longer duration. +# +# We need to multiply the learning rate by 2 to account for the fact that we have 4 models +# in the ensemble and that we average the loss over all the predictions. +# +# #### Downloading the pre-trained models +# +# We have put the pre-trained models on Hugging Face that you can download with the utility function +# "hf_hub_download" imported just below. These models are trained for 75 epochs and are therefore not +# comparable to the all the other models trained in this notebook. The pretrained models can be seen +# [here](https://huggingface.co/ENSTA-U2IS/tutorial-models) and TorchUncertainty's are [here](https://huggingface.co/torch-uncertainty). + +from torch_uncertainty.utils.hub import hf_hub_download + +all_models = [] +for i in range(8): + hf_hub_download( + repo_id="ENSTA-U2IS/tutorial-models", + filename=f"version_{i}.ckpt", + local_dir="./models/", + ) + model = LeNet(in_channels=1, num_classes=10) + state_dict = torch.load(f"./models/version_{i}.ckpt", map_location="cpu")[ + "state_dict" + ] + state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} + model.load_state_dict(state_dict) + all_models.append(model) + +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.transforms import RepeatTarget + +ensemble = deep_ensembles( + all_models, + num_estimators=None, + task="classification", + reset_model_parameters=True, +) + +ens_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=ensemble, + loss=nn.CrossEntropyLoss(), # The loss for the training + format_batch_fn=RepeatTarget( + 8 + ), # How to handle the targets when comparing the predictions + optim_recipe=None, # No optim recipe as the model is already trained + eval_ood=True, # We want to evaluate the OOD-related metrics +) + +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +ens_perf = trainer.test(ens_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# 4. From Deep Ensembles to Packed-Ensembles +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In the paper [Packed-Ensembles for Efficient Uncertainty Quantification](https://arxiv.org/abs/2210.09184) +# published at the International Conference on Learning Representations (ICLR) in 2023, we introduced a +# modification of Deep Ensembles to make it more computationally-efficient. The idea is to pack the ensemble +# members into a single model, which allows us to train the ensemble in a single forward pass. +# This modification is particularly useful when the ensemble size is large, as it is often the case in practice. +# +# We will need to update the model and replace the layers with their Packed equivalents. You can find the +# documentation of the Packed-Linear layer [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html), +# and the Packed-Conv2D, [here](https://torch-uncertainty.github.io/generated/torch_uncertainty.layers.PackedLinear.html). + +import torch +import torch.nn as nn +from einops import rearrange + +from torch_uncertainty.layers import PackedConv2d, PackedLinear + + +class PackedLeNet(nn.Module): + def __init__( + self, + in_channels: int, + num_classes: int, + alpha: int, + num_estimators: int, + ) -> None: + super().__init__() + self.num_estimators = num_estimators + self.conv1 = PackedConv2d( + in_channels, + 6, + (5, 5), + alpha=alpha, + num_estimators=num_estimators, + first=True, + ) + self.conv2 = PackedConv2d( + 6, + 16, + (5, 5), + alpha=alpha, + num_estimators=num_estimators, + ) + self.pooling = nn.AdaptiveAvgPool2d((4, 4)) + self.fc1 = PackedLinear( + 256, 120, alpha=alpha, num_estimators=num_estimators + ) + self.fc2 = PackedLinear( + 120, 84, alpha=alpha, num_estimators=num_estimators + ) + self.fc3 = PackedLinear( + 84, + num_classes, + alpha=alpha, + num_estimators=num_estimators, + last=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = F.relu(self.conv1(x)) + out = F.max_pool2d(out, 2) + out = F.relu(self.conv2(out)) + out = F.max_pool2d(out, 2) + out = rearrange( + out, "e (m c) h w -> (m e) c h w", m=self.num_estimators + ) + out = torch.flatten(out, 1) + out = F.relu(self.fc1(out)) + out = F.relu(self.fc2(out)) + return self.fc3(out) # Again, no softmax in the model + + +# Instantiate the model, the images are in grayscale so the number of channels is 1 +packed_model = PackedLeNet( + in_channels=1, num_classes=10, alpha=2, num_estimators=4 +) + +# Create the trainer that will handle the training +trainer = TUTrainer(accelerator="cpu", max_epochs=max_epochs) + +# The routine is a wrapper of the model that contains the training logic with the metrics, etc +packed_routine = ClassificationRoutine( + is_ensemble=True, + num_classes=10, + model=packed_model, + loss=nn.CrossEntropyLoss(), + format_batch_fn=RepeatTarget(4), + optim_recipe=optim_recipe(packed_model, 4.0), + eval_ood=True, +) + +# In practice, avoid performing the validation on the test set +trainer.fit(packed_routine, train_dataloaders=train_dl, val_dataloaders=test_dl) + +packed_perf = trainer.test(packed_routine, dataloaders=[test_dl, ood_dl]) + +# %% +# The training time should be approximately similar to the one of the single model that you trained before. However, please note that we are working with very small models, hence completely underusing your GPU. As such, the training time is not representative of what you would observe with larger models. +# +# You can read more on Packed-Ensembles in the [paper](https://arxiv.org/abs/2210.09184) or the [Medium](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873) post. +# +# To Go Further & More Concepts of Uncertainty in ML +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# **Question 1:** Have a look at the models in the "lightning_logs". If you are on your own machine, try to visualize the learning curves with `tensorboard --logdir lightning_logs`. +# +# **Question 2:** Add a cell below and try to find the errors made by packed-ensembles on the test set. Visualize the errors and their labels and look at the predictions of the different sub-models. Are they similar? Can you think of uncertainty scores that could help you identify these errors? +# +# Selective Classification +# ^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Selective classification or "prediction with rejection" is a paradigm in uncertainty-aware machine learning where the model can decide not to make a prediction if the confidence score given by the model is below some pre-computed threshold. This can be useful in real-world applications where the cost of making a wrong prediction is high. +# +# In constrast to calibration, the values of the confidence scores are not important, only the order of the scores. *Ideally, the best model will order all the correct predictions first, and all the incorrect predictions last.* In this case, there will be a threshold so that all the predictions above the threshold are correct, and all the predictions below the threshold are incorrect. +# +# In TorchUncertainty, we look at 3 different metrics for selective classification: +# - **AURC**: The area under the Risk (% of errors) vs. Coverage (% of classified samples) curve. This curve expresses how the risk of the model evolves as we increase the coverage (the proportion of predictions that are above the selection threshold). This metric will be minimized by a model able to perfectly separate the correct and incorrect predictions. +# +# The following metrics are computed at a fixed risk and coverage level and that have practical interests. The idea of these metrics is that you can set the selection threshold to achieve a certain level of risk and coverage, as required by the technical constraints of your application: +# - **Coverage at 5% Risk**: The proportion of predictions that are above the selection threshold when it is set for the risk to egal 5%. Set the risk threshold to your application constraints. The higher the better. +# - **Risk at 80% Coverage**: The proportion of errors when the coverage is set to 80%. Set the coverage threshold to your application constraints. The lower the better. +# +# Grouping Loss +# ^^^^^^^^^^^^^ +# +# The grouping loss is a measure of uncertainty orthogonal to calibration. Have a look at [this paper](https://arxiv.org/abs/2210.16315) to learn about it. Check out their small library [GLest](https://github.com/aperezlebel/glest). TorchUncertainty includes a wrapper of the library to compute the grouping loss with eval_grouping_loss parameter. diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 12781e2b..a8bed883 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -22,7 +22,6 @@ We also need import the neural network utils within `torch.nn`. """ -# %% from pathlib import Path from lightning import Trainer @@ -44,7 +43,7 @@ trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root, batch_size=128) @@ -76,7 +75,7 @@ # `trainer.test`. trainer.fit(model=routine, datamodule=datamodule) -trainer.test(model=routine, datamodule=datamodule); +perf = trainer.test(model=routine, datamodule=datamodule) # %% # 5. Wrapping the Model in a MCBatchNorm @@ -93,7 +92,7 @@ routine.model, num_estimators=8, convert=True, mc_batch_size=16 ) routine.model.fit(datamodule.train) -routine.eval(); +routine = routine.eval() # To avoid prints # %% # 6. Testing the Model @@ -102,6 +101,9 @@ # .eval() to enable Monte Carlo batch normalization at inference. # In this tutorial, we plot the most uncertain images, i.e. the images for which # the variance of the predictions is the highest. +# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble +# and to the batch. As for TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension +# in this order (num_estimator x batch, classes). import matplotlib.pyplot as plt import numpy as np @@ -121,7 +123,7 @@ def imshow(img): images, labels = next(dataiter) routine.eval() -logits = routine(images).reshape(8, 128, 10) +logits = routine(images).reshape(8, 128, 10) # num_estimators, batch_size, num_classes probs = torch.nn.functional.softmax(logits, dim=-1) most_uncertain = sorted(probs.var(0).sum(-1).topk(4).indices) diff --git a/auto_tutorials_source/tutorial_mc_dropout.py b/auto_tutorials_source/tutorial_mc_dropout.py index 59dd4241..e19eee61 100644 --- a/auto_tutorials_source/tutorial_mc_dropout.py +++ b/auto_tutorials_source/tutorial_mc_dropout.py @@ -22,8 +22,8 @@ - the Trainer from Lightning - the datamodule handling dataloaders: MNISTDataModule from torch_uncertainty.datamodules - the model: LeNet, which lies in torch_uncertainty.models -- the MC Dropout wrapper: mc_dropout, which lies in torch_uncertainty.models -- the classification training routine in the torch_uncertainty.routines +- the MC Dropout wrapper: mc_dropout, from torch_uncertainty.models.wrappers +- the classification training & evaluation routine in the torch_uncertainty.routines - an optimization recipe in the torch_uncertainty.optim_recipes module. We also need import the neural network utils within `torch.nn`. @@ -51,20 +51,19 @@ # dataloaders and transforms. We create the model using the # blueprint from torch_uncertainty.models and we wrap it into mc_dropout. # -# It is important to specify the arguments,``num_estimators`` and the ``dropout_rate`` -# to use Monte Carlo dropout. +# It is important to add a ``dropout_rate`` argument in your model to use Monte Carlo dropout. trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False) # datamodule -root = Path("") / "data" +root = Path("data") datamodule = MNISTDataModule(root=root, batch_size=128) model = lenet( in_channels=datamodule.num_channels, num_classes=datamodule.num_classes, - dropout_rate=0.4, + dropout_rate=0.5, ) mc_model = mc_dropout(model, num_estimators=16, last_layer=False) @@ -75,16 +74,14 @@ # This is a classification problem, and we use CrossEntropyLoss as the likelihood. # We define the training routine using the classification training routine from # torch_uncertainty.routines.classification. We provide the number of classes -# and channels, the optimizer wrapper, the dropout rate, and the number of -# forward passes to perform through the network, as well as all the default -# arguments. +# and channels, the optimizer wrapper, and the dropout rate. routine = ClassificationRoutine( num_classes=datamodule.num_classes, model=mc_model, loss=nn.CrossEntropyLoss(), optim_recipe=optim_cifar10_resnet18(mc_model), - num_estimators=16, + is_ensemble=True, ) # %% @@ -118,8 +115,8 @@ def imshow(img): images, labels = next(dataiter) # print images -imshow(torchvision.utils.make_grid(images[:4, ...])) -print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) +imshow(torchvision.utils.make_grid(images[:6, ...], padding=0)) +print("Ground truth labels: ", " ".join(f"{labels[j]}" for j in range(6))) routine.eval() logits = routine(images).reshape(16, 128, 10) @@ -127,7 +124,7 @@ def imshow(img): probs = torch.nn.functional.softmax(logits, dim=-1) -for j in range(4): +for j in range(6): values, predicted = torch.max(probs[:, j], 1) print( f"Predicted digits for the image {j+1}: ", @@ -135,5 +132,5 @@ def imshow(img): ) # %% -# We see that there is some disagreement between the samples of the dropout +# Most of the time, we see that there is some disagreement between the samples of the dropout # approximation of the posterior distribution. diff --git a/auto_tutorials_source/tutorial_pe_cifar10.py b/auto_tutorials_source/tutorial_pe_cifar10.py index 57b6b51f..d3a233cd 100644 --- a/auto_tutorials_source/tutorial_pe_cifar10.py +++ b/auto_tutorials_source/tutorial_pe_cifar10.py @@ -45,6 +45,7 @@ import torch import torchvision import torchvision.transforms as transforms +from torch.utils.data import DataLoader torch.set_num_threads(1) @@ -69,14 +70,14 @@ trainset = torchvision.datasets.CIFAR10( root="./data", train=True, download=True, transform=transform ) -trainloader = torch.utils.data.DataLoader( +trainloader = DataLoader( trainset, batch_size=batch_size, shuffle=True, num_workers=2 ) testset = torchvision.datasets.CIFAR10( root="./data", train=False, download=True, transform=transform ) -testloader = torch.utils.data.DataLoader( +testloader = DataLoader( testset, batch_size=batch_size, shuffle=False, num_workers=2 ) diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index 75f1953c..fdbfc469 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -25,7 +25,6 @@ If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag. """ - from torch_uncertainty.datamodules import CIFAR100DataModule from torch_uncertainty.metrics import CalibrationError from torch_uncertainty.models.resnet import resnet @@ -114,7 +113,7 @@ # Fit the scaler on the calibration dataset scaled_model = TemperatureScaler(model=model) -scaled_model = scaled_model.fit(calibration_set=cal_dataset) +scaled_model.fit(calibration_set=cal_dataset) # %% # 6. Iterating Again to Compute the Improved ECE diff --git a/docs/source/api.rst b/docs/source/api.rst index d4f99acf..ed5e07ce 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -22,25 +22,25 @@ Classification ClassificationRoutine -Regression -^^^^^^^^^^ +Segmentation +^^^^^^^^^^^^ .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst - RegressionRoutine + SegmentationRoutine -Segmentation -^^^^^^^^^^^^ +Regression +^^^^^^^^^^ .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst - SegmentationRoutine + RegressionRoutine Pixelwise Regression ^^^^^^^^^^^^^^^^^^^^ @@ -153,24 +153,35 @@ Models .. currentmodule:: torch_uncertainty.models -Deep Ensembles -^^^^^^^^^^^^^^ +Wrappers +^^^^^^^^ + + + +Functions +""""""""" .. autosummary:: :toctree: generated/ :nosignatures: - :template: class.rst deep_ensembles + mc_dropout -Monte Carlo Dropout +Classes +""""""" .. autosummary:: :toctree: generated/ :nosignatures: :template: class.rst - mc_dropout + CheckpointEnsemble + EMA + MCDropout + StochasticModel + SWA + SWAG Metrics ------- @@ -242,6 +253,17 @@ Post-Processing Methods .. currentmodule:: torch_uncertainty.post_processing +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + MCBatchNorm + LaplaceApprox + +Scaling Methods +^^^^^^^^^^^^^^^ + .. autosummary:: :toctree: generated/ :nosignatures: @@ -250,7 +272,6 @@ Post-Processing Methods TemperatureScaler VectorScaler MatrixScaler - MCBatchNorm Datamodules ----------- @@ -301,3 +322,74 @@ Segmentation CamVidDataModule CityscapesDataModule MUADDataModule + +Datasets +-------- + +.. currentmodule:: torch_uncertainty.datasets + +Classification +^^^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.classification + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + MNISTC + NotMNIST + CIFAR10C + CIFAR100C + CIFAR10H + CIFAR10N + CIFAR100N + ImageNetA + ImageNetC + ImageNetO + ImageNetR + TinyImageNet + TinyImageNetC + OpenImageO + +Regression +^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.regression + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + UCIRegression + +Segmentation +^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets.segmentation + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + CamVid + Cityscapes + +Others & Cross-Categories +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: torch_uncertainty.datasets + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + Fractals + FrostImages + KITTIDepth + MUAD + NYUv2 diff --git a/docs/source/index.rst b/docs/source/index.rst index b0af32c5..0c3c7994 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,7 +32,7 @@ To install TorchUncertainty with contribution in mind, check the ----- Official Implementations -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^ TorchUncertainty also houses multiple official implementations of papers from major conferences & journals. @@ -56,14 +56,6 @@ TorchUncertainty also houses multiple official implementations of papers from ma * Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, David Filliat* * Paper: `BMVC 2022 `_. -Packed-Ensembles -^^^^^^^^^^^^^^^^ - -**Packed-Ensembles for Efficient Uncertainty Estimation** - -* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi* -* Paper: `here `_. - .. toctree:: :maxdepth: 2 :caption: Contents: diff --git a/docs/source/references.rst b/docs/source/references.rst index bd4467c9..89829c16 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -15,7 +15,7 @@ For Deep Evidential Classification, consider citing: **Evidential Deep Learning to Quantify Classification Uncertainty** -* Authors: *Murat Sensoy, Lance Kaplan, Melih Kandemir* +* Authors: *Murat Sensoy, Lance Kaplan, and Melih Kandemir* * Paper: `NeurIPS 2018 `__. @@ -26,7 +26,7 @@ For Beta NLL in Deep Regression, consider citing: **On the Pitfalls of Heteroscedastic Uncertainty Estimation with Probabilistic Neural Networks** -* Authors: *Maximilian Seitzer, Arash Tavakoli, Dimitrije Antic, Georg Martius* +* Authors: *Maximilian Seitzer, Arash Tavakoli, Dimitrije Antic, and Georg Martius* * Paper: `ICLR 2022 `__. @@ -37,14 +37,14 @@ For Deep Evidential Regression, consider citing: **Deep Evidential Regression** -* Authors: *Alexander Amini, Wilko Schwarting, Ava Soleimany, Daniela Rus* +* Authors: *Alexander Amini, Wilko Schwarting, Ava Soleimany, and Daniela Rus* * Paper: `NeurIPS 2020 `__. -Bayesian Neural Networks -^^^^^^^^^^^^^^^^^^^^^^^^ +Variational Inference Bayesian Neural Networks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -For Bayesian Neural Networks, consider citing: +For Variational Inference Bayesian Neural Networks, consider citing: **Weight Uncertainty in Neural Networks** @@ -73,6 +73,46 @@ For Monte-Carlo Dropout, consider citing: * Authors: *Yarin Gal and Zoubin Ghahramani* * Paper: `ICML 2016 `__. +Stochastic Weight Averaging +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For Stochastic Weight Averaging, consider citing: + +**Averaging Weights Leads to Wider Optima and Better Generalization** + +* Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson* +* Paper: `UAI 2018 `__. + +Stochastic Weight Averaging Gaussian +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For Stochastic Weight Averaging Gaussian, consider citing: + +**A simple baseline for Bayesian uncertainty in deep learning** + +* Authors: *Wesley Maddox, Timur Garipov, Pavel Izmailov, Dmitry Vetrov, Andrew Gordon Wilson* +* Paper: `NeurIPS 2019 `__. + + +CheckpointEnsemble +^^^^^^^^^^^^^^^^^^ + +For CheckpointEnsemble, consider citing: + +**Checkpoint Ensembles: Ensemble Methods from a Single Training Process** + +* Authors: *Hugh Chen, Scott Lundberg, and Su-In Lee* +* Paper: `ArXiv `__. + +SnapshotEnsemble +^^^^^^^^^^^^^^^^ + +For SnapshotEnsemble, consider citing: + +**Snapshot Ensembles: Train 1, get M for free** + +* Authors: *Gao Huang, Yixuan Li, Geoff Pleiss, Zhuang Liu, John E. Hopcroft, and Kilian Q. Weinberger* +* Paper: `ICLR 2017 `__. BatchEnsemble ^^^^^^^^^^^^^ @@ -147,7 +187,7 @@ For RegMixup, consider citing: **RegMixup: Mixup as a Regularizer Can Surprisingly Improve Accuracy and Out Distribution Robustness** -* Authors: *Francesco Pinto, Harry Yang, Ser-Nam Lim, Philip H.S. Torr, Puneet K. Dokania* +* Authors: *Francesco Pinto, Harry Yang, Ser-Nam Lim, Philip H.S. Torr, and Puneet K. Dokania* * Paper: `NeurIPS 2022 `__. MixupIO @@ -193,6 +233,16 @@ For Monte-Carlo Batch Normalization, consider citing: * Authors: *Mathias Teye, Hossein Azizpour, and Kevin Smith* * Paper: `ICML 2018 `__. +Laplace Approximation +^^^^^^^^^^^^^^^^^^^^^ + +For Laplace Approximation, consider citing: + +**Laplace Redux - Effortless Bayesian Deep Learning** + +* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig* +* Paper: `NeurIPS 2021 `__. + Metrics ------- @@ -215,7 +265,7 @@ For the adaptive calibration error, consider citing: **Measuring Calibration in Deep Learning** -* Authors: Jeremy Nixon, Mike Dusenberry, Ghassen Jerfel, Timothy Nguyen, Jeremiah Liu, Linchuan Zhang, Dustin Tran +* Authors: *Jeremy Nixon, Mike Dusenberry, Ghassen Jerfel, Timothy Nguyen, Jeremiah Liu, Linchuan Zhang, and Dustin Tran* * Paper: `CVPRW 2019 `__. Area Under the Risk-Coverage curve @@ -225,7 +275,7 @@ For the area under the risk-coverage curve, consider citing: **Selective classification for deep neural networks** -* Authors: Yonatan Geifman, Ran El-Yaniv +* Authors: *Yonatan Geifman and Ran El-Yaniv* * Paper: `NeurIPS 2017 `__. Grouping Loss @@ -295,7 +345,7 @@ CIFAR-10 N / CIFAR-100 N **Learning with Noisy Labels Revisited: A Study Using Real-World Human Annotations** -* Authors: *Jiaheng Wei, Zhaowei Zhu, Hao Cheng, Tongliang Liu, Gang Niu, Yang Liu* +* Authors: *Jiaheng Wei, Zhaowei Zhu, Hao Cheng, Tongliang Liu, Gang Niu, and Yang Liu* * Paper: `ICLR 2022 `__. SVHN @@ -360,7 +410,7 @@ MUAD **MUAD: Multiple Uncertainties for Autonomous Driving Dataset** -* Authors: Gianni Franchi, Xuanlong Yu, Andrei Bursuc, et al.* +* Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, et al.* * Paper: `BMVC 2022 __` Architectures diff --git a/experiments/classification/mnist/bayesian_lenet.py b/experiments/classification/mnist/bayesian_lenet.py deleted file mode 100644 index 05a7c17e..00000000 --- a/experiments/classification/mnist/bayesian_lenet.py +++ /dev/null @@ -1,62 +0,0 @@ -from functools import partial -from pathlib import Path - -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args -from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.losses import ELBOLoss -from torch_uncertainty.models.lenet import bayesian_lenet -from torch_uncertainty.routines.classification import ClassificationSingle - - -def optim_lenet(model: nn.Module) -> dict: - """Optimization recipe for LeNet. - - Uses Adam default hyperparameters. - - Args: - model (nn.Module): LeNet model. - """ - optimizer = optim.Adam( - model.parameters(), - lr=1e-3, - ) - return {"optimizer": optimizer} - - -if __name__ == "__main__": - args = init_args(datamodule=MNISTDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - net_name = "bayesian-lenet-mnist" - - # datamodule - args.root = str(root / "data") - dm = MNISTDataModule(**vars(args)) - - # model - model = bayesian_lenet(dm.num_channels, dm.num_classes) - - # Here, the loss is a bit more complicated - # hyperparameters are from blitz. - loss = partial( - ELBOLoss, - inner_loss=nn.CrossEntropyLoss(), - kl_weight=1 / 50000, - num_samples=3, - ) - - baseline = ClassificationSingle( - model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=loss, - optim_recipe=optim_lenet, - **vars(args), - ) - - cli_main(baseline, dm, "logs/", net_name, args) diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml new file mode 100644 index 00000000..55f6b3c6 --- /dev/null +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -0,0 +1,68 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.StochasticModel + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + num_samples: 16 + num_classes: 10 + loss: + class_path: torch_uncertainty.losses.ELBOLoss + init_args: + kl_weight: 0.00002 + inner_loss: torch.nn.CrossEntropyLoss + num_samples: 3 +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml new file mode 100644 index 00000000..0c7989ab --- /dev/null +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -0,0 +1,59 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml new file mode 100644 index 00000000..c5398a87 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -0,0 +1,73 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.CheckpointEnsemble + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + save_schedule: + - 20 + - 25 + - 30 + - 35 + - 40 + - 45 + - 50 + - 55 + - 60 + - 65 + - 70 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml new file mode 100644 index 00000000..363461c6 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -0,0 +1,61 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_ema + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.EMA + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + momentum: 0.99 + num_classes: 10 + loss: CrossEntropyLoss +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml new file mode 100644 index 00000000..fa3eb77d --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -0,0 +1,64 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_swa + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.SWA + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + cycle_start: 19 + cycle_length: 5 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch_uncertainty.optim_recipes.FullSWALR + init_args: + milestone: 20 + swa_lr: 0.01 + anneal_epochs: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml new file mode 100644 index 00000000..292b49f0 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -0,0 +1,64 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 75 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_swag + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + model: + class_path: torch_uncertainty.models.wrappers.SWAG + init_args: + model: + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + last_layer_dropout: false + layer_args: {} + cycle_start: 10 + cycle_length: 5 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true +data: + root: ./data + batch_size: 128 +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch_uncertainty.optim_recipes.FullSWALR + init_args: + milestone: 10 + swa_lr: 0.01 + anneal_epochs: 5 diff --git a/experiments/classification/mnist/lenet.py b/experiments/classification/mnist/lenet.py index 450f72c2..c93610f4 100644 --- a/experiments/classification/mnist/lenet.py +++ b/experiments/classification/mnist/lenet.py @@ -1,52 +1,26 @@ -from pathlib import Path +import torch +from lightning.pytorch.cli import LightningArgumentParser -from torch import nn, optim - -from torch_uncertainty import cli_main, init_args from torch_uncertainty.datamodules import MNISTDataModule -from torch_uncertainty.models.lenet import lenet -from torch_uncertainty.routines.classification import ClassificationSingle +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TULightningCLI -def optim_lenet(model: nn.Module) -> dict: - """Optimization recipe for LeNet. +class MNISTCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) - Uses Adam default hyperparameters. - Args: - model (nn.Module): LeNet model. - """ - return { - "optimizer": optim.Adam( - model.parameters(), - ) - } +def cli_main() -> MNISTCLI: + return MNISTCLI(ClassificationRoutine, MNISTDataModule) if __name__ == "__main__": - args = init_args(datamodule=MNISTDataModule) - if args.root == "./data/": - root = Path(__file__).parent.absolute().parents[2] - else: - root = Path(args.root) - - if args.exp_name == "": - args.exp_name = "std-lenet-mnist" - - # datamodule - args.root = str(root / "data") - dm = MNISTDataModule(**vars(args)) - - # model - model = lenet(dm.num_channels, dm.num_classes) - - baseline = ClassificationSingle( - model=model, - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss(), - optim_recipe=optim_lenet, - **vars(args), - ) - - cli_main(baseline, dm, args.exp_dir, args.exp_name, args) + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml index 89de3232..3c20e048 100644 --- a/experiments/depth/kitti/configs/bts.yaml +++ b/experiments/depth/kitti/configs/bts.yaml @@ -37,7 +37,7 @@ data: crop_size: - 352 - 704 - inference_size: + eval_size: - 352 - 1216 num_workers: 4 diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml index 8a9d0957..48f9d7db 100644 --- a/experiments/depth/nyu/configs/bts.yaml +++ b/experiments/depth/nyu/configs/bts.yaml @@ -37,7 +37,7 @@ data: crop_size: - 416 - 544 - inference_size: + eval_size: - 480 - 640 num_workers: 8 diff --git a/experiments/readme.md b/experiments/readme.md index 0035c5a7..8f8996c6 100644 --- a/experiments/readme.md +++ b/experiments/readme.md @@ -14,6 +14,6 @@ Torch-Uncertainty proposes various benchmarks to evaluate uncertainty quantifica *Work in progress* -## Monocular Depth Estimation +## Pixel Regression *Work in progress* diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index 7cbb001b..0dfac0a0 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -9,7 +9,6 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data batch_size: 16 diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index babefa1c..51cc2a1e 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -29,12 +29,11 @@ model: style: v3+ output_stride: 16 separable: false - num_estimators: 1 data: root: ./data/Cityscapes batch_size: 8 crop_size: 768 - inference_size: + eval_size: - 1024 - 2048 num_workers: 8 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index 0ae0c212..145a96eb 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -25,12 +25,11 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data/Cityscapes batch_size: 8 crop_size: 1024 - inference_size: + eval_size: - 1024 - 2048 num_workers: 8 diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml index b2abf11e..a0c110e0 100644 --- a/experiments/segmentation/muad/configs/segformer.yaml +++ b/experiments/segmentation/muad/configs/segformer.yaml @@ -10,12 +10,11 @@ model: loss: CrossEntropyLoss version: std arch: 0 - num_estimators: 1 data: root: ./data batch_size: 8 crop_size: 1024 - inference_size: + eval_size: - 1024 - 2048 num_workers: 30 diff --git a/pyproject.toml b/pyproject.toml index 0b11a230..31db7c62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,13 +40,14 @@ dependencies = [ "scipy", "huggingface-hub", "scikit-learn", - "matplotlib", + "matplotlib==3.5.2", + "numpy<2", "opencv-python", "glest==0.0.1a0", ] [project.optional-dependencies] -image = ["scikit-image", "h5py",] +image = ["scikit-image", "h5py"] tabular = ["pandas"] dev = [ "torch_uncertainty[image]", @@ -63,7 +64,10 @@ docs = [ "sphinx-design", "sphinx-codeautolink", ] -all = ["torch_uncertainty[dev,docs,image,tabular]"] +all = [ + "torch_uncertainty[dev,docs,image,tabular]", + "laplace-torch" +] [project.urls] homepage = "https://torch-uncertainty.github.io/" diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index d942a5ae..4f2df70d 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -1,19 +1,19 @@ # ruff: noqa: F401 from .baseline import ( DummyClassificationBaseline, - DummyDepthBaseline, + DummyPixelRegressionBaseline, DummyRegressionBaseline, DummySegmentationBaseline, ) from .datamodule import ( DummyClassificationDataModule, - DummyDepthDataModule, + DummyPixelRegressionDataModule, DummyRegressionDataModule, DummySegmentationDataModule, ) from .dataset import ( + DummPixelRegressionDataset, DummyClassificationDataset, - DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index c43b444c..535cd567 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -7,7 +7,9 @@ NormalInverseGammaLayer, NormalLayer, ) -from torch_uncertainty.models.deep_ensembles import deep_ensembles +from torch_uncertainty.models import EMA, SWA, deep_ensembles +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.routines import ( ClassificationRoutine, PixelRegressionRoutine, @@ -24,11 +26,10 @@ def __new__( cls, num_classes: int, in_channels: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", - optim_recipe=None, + optim_recipe=optim_cifar10_resnet18, with_feats: bool = True, - with_linear: bool = True, ood_criterion: str = "msp", eval_ood: bool = False, eval_grouping_loss: bool = False, @@ -41,14 +42,31 @@ def __new__( kernel_tau_std: float = 0.5, mixup_alpha: float = 0, cutmix_alpha: float = 0, + no_mixup_params: bool = False, + ema: bool = False, + swa: bool = False, ) -> ClassificationRoutine: model = dummy_model( in_channels=in_channels, num_classes=num_classes, with_feats=with_feats, - with_linear=with_linear, ) - + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) + if not no_mixup_params: + mixup_params = { + "mixup_alpha": mixup_alpha, + "cutmix_alpha": cutmix_alpha, + "mixtype": mixtype, + "mixmode": mixmode, + "dist_sim": dist_sim, + "kernel_tau_max": kernel_tau_max, + "kernel_tau_std": kernel_tau_std, + } + else: + mixup_params = None if baseline_type == "single": return ClassificationRoutine( num_classes=num_classes, @@ -57,18 +75,12 @@ def __new__( format_batch_fn=nn.Identity(), log_plots=True, optim_recipe=optim_recipe(model), - num_estimators=1, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + is_ensemble=False, + mixup_params=mixup_params, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - calibration_set="val" if calibrate else None, + post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, ) # baseline_type == "ensemble": @@ -83,11 +95,11 @@ def __new__( optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), log_plots=True, - num_estimators=2, + is_ensemble=True, ood_criterion=ood_criterion, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, - calibration_set="val" if calibrate else None, + post_processing=TemperatureScaler() if calibrate else None, save_in_csv=save_in_csv, ) @@ -98,10 +110,12 @@ def __new__( probabilistic: bool, in_features: int, output_dim: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", optim_recipe=None, dist_type: str = "normal", + ema: bool = False, + swa: bool = False, ) -> RegressionRoutine: if probabilistic: if dist_type == "normal": @@ -122,13 +136,17 @@ def __new__( num_classes=num_classes, last_layer=last_layer, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) + if baseline_type == "single": return RegressionRoutine( probabilistic=probabilistic, output_dim=output_dim, model=model, loss=loss, - num_estimators=1, optim_recipe=optim_recipe(model), ) # baseline_type == "ensemble": @@ -142,7 +160,7 @@ def __new__( output_dim=output_dim, model=model, loss=loss, - num_estimators=2, + is_ensemble=True, optim_recipe=optim_recipe(model), format_batch_fn=RepeatTarget(2), ) @@ -154,17 +172,23 @@ def __new__( in_channels: int, num_classes: int, image_size: int, - loss: type[nn.Module], + loss: nn.Module, baseline_type: str = "single", optim_recipe=None, metric_subsampling_rate: float = 1, log_plots: bool = False, + ema: bool = False, + swa: bool = False, ) -> SegmentationRoutine: model = dummy_segmentation_model( in_channels=in_channels, num_classes=num_classes, image_size=image_size, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=2) if baseline_type == "single": return SegmentationRoutine( @@ -172,7 +196,6 @@ def __new__( model=model, loss=loss, format_batch_fn=None, - num_estimators=1, optim_recipe=optim_recipe(model), metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, @@ -188,37 +211,57 @@ def __new__( model=model, loss=loss, format_batch_fn=RepeatTarget(2), - num_estimators=2, optim_recipe=optim_recipe(model), metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, ) -class DummyDepthBaseline: +class DummyPixelRegressionBaseline: def __new__( cls, + probabilistic: bool, in_channels: int, output_dim: int, image_size: int, - loss: type[nn.Module], + loss: nn.Module, + dist_type: str = "normal", baseline_type: str = "single", optim_recipe=None, + ema: bool = False, + swa: bool = False, ) -> PixelRegressionRoutine: + if probabilistic: + if dist_type == "normal": + last_layer = NormalLayer(output_dim) + num_classes = output_dim * 2 + elif dist_type == "laplace": + last_layer = LaplaceLayer(output_dim) + num_classes = output_dim * 2 + else: # dist_type == "nig" + last_layer = NormalInverseGammaLayer(output_dim) + num_classes = output_dim * 4 + else: + last_layer = nn.Identity() + num_classes = output_dim + model = dummy_segmentation_model( - num_classes=output_dim, + num_classes=num_classes, in_channels=in_channels, image_size=image_size, + last_layer=last_layer, ) + if ema: + model = EMA(model, momentum=0.99) + if swa: + model = SWA(model, cycle_start=0, cycle_length=1) if baseline_type == "single": return PixelRegressionRoutine( + probabilistic=probabilistic, output_dim=output_dim, - probabilistic=False, model=model, loss=loss, - format_batch_fn=None, - num_estimators=1, optim_recipe=optim_recipe(model), ) @@ -226,14 +269,14 @@ def __new__( model = deep_ensembles( [model, copy.deepcopy(model)], task="pixel_regression", - probabilistic=False, + probabilistic=probabilistic, ) return PixelRegressionRoutine( + probabilistic=probabilistic, output_dim=output_dim, - probabilistic=False, model=model, loss=loss, format_batch_fn=RepeatTarget(2), - num_estimators=2, + is_ensemble=True, optim_recipe=optim_recipe(model), ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 51c769dd..9dc59dfd 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -7,17 +7,17 @@ from torch.utils.data import DataLoader from torchvision import tv_tensors -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from .dataset import ( + DummPixelRegressionDataset, DummyClassificationDataset, - DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) -class DummyClassificationDataModule(AbstractDataModule): +class DummyClassificationDataModule(TUDataModule): num_channels = 1 image_size: int = 4 training_task = "classification" @@ -104,7 +104,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyRegressionDataModule(AbstractDataModule): +class DummyRegressionDataModule(TUDataModule): in_features = 4 training_task = "regression" @@ -160,7 +160,7 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: return [self._data_loader(self.test)] -class DummySegmentationDataModule(AbstractDataModule): +class DummySegmentationDataModule(TUDataModule): num_channels = 3 training_task = "segmentation" @@ -249,7 +249,7 @@ def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) -class DummyDepthDataModule(AbstractDataModule): +class DummyPixelRegressionDataModule(TUDataModule): num_channels = 3 training_task = "pixel_regression" @@ -278,7 +278,7 @@ def __init__( self.num_images = num_images self.image_size = image_size - self.dataset = DummyDepthDataset + self.dataset = DummPixelRegressionDataset self.train_transform = T.ToDtype( dtype={ diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 1ab0c66b..662e4f9f 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -214,7 +214,7 @@ def __len__(self) -> int: return len(self.data) -class DummyDepthDataset(Dataset): +class DummPixelRegressionDataset(Dataset): def __init__( self, root: Path, diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index 2e29e2b5..c6c2d64d 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -12,23 +12,16 @@ def __init__( in_channels: int, num_classes: int, dropout_rate: float, - with_linear: bool, last_layer: nn.Module, ) -> None: super().__init__() self.in_channels = in_channels self.dropout_rate = dropout_rate - if with_linear: - self.linear = nn.Linear( - 1, - num_classes, - ) - else: - self.out = nn.Linear( - 1, - num_classes, - ) + self.linear = nn.Linear( + 1, + num_classes, + ) self.last_layer = last_layer self.dropout = nn.Dropout(p=dropout_rate) @@ -60,6 +53,7 @@ def __init__( num_classes: int, dropout_rate: float, image_size: int, + last_layer: nn.Module, ) -> None: super().__init__() self.dropout_rate = dropout_rate @@ -70,18 +64,21 @@ def __init__( in_channels, num_classes, kernel_size=3, padding=1 ) self.dropout = nn.Dropout(p=dropout_rate) + self.last_layer = last_layer def forward(self, x: Tensor) -> Tensor: - return self.dropout( - self.conv( - torch.ones( - ( - x.shape[0], - self.in_channels, - self.image_size, - self.image_size, - ), - dtype=torch.float32, + return self.last_layer( + self.dropout( + self.conv( + torch.ones( + ( + x.shape[0], + self.in_channels, + self.image_size, + self.image_size, + ), + dtype=torch.float32, + ) ) ) ) @@ -92,8 +89,7 @@ def dummy_model( num_classes: int, dropout_rate: float = 0.0, with_feats: bool = True, - with_linear: bool = True, - last_layer=None, + last_layer: nn.Module | None = None, ) -> _Dummy: """Dummy model for testing purposes. @@ -103,9 +99,7 @@ def dummy_model( num_estimators (int): Number of estimators in the ensemble. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. with_feats (bool, optional): Whether to include features. Defaults to True. - with_linear (bool, optional): Whether to include a linear layer. - Defaults to True. - last_layer ([type], optional): Last layer of the model. Defaults to None. + last_layer (nn.Module, optional): Last layer of the model. Defaults to None. Returns: _Dummy: Dummy model. @@ -117,14 +111,12 @@ def dummy_model( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - with_linear=with_linear, last_layer=last_layer, ) return _Dummy( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, - with_linear=with_linear, last_layer=last_layer, ) @@ -134,6 +126,7 @@ def dummy_segmentation_model( num_classes: int, image_size: int, dropout_rate: float = 0.0, + last_layer: nn.Module | None = None, ) -> nn.Module: """Dummy segmentation model for testing purposes. @@ -142,13 +135,17 @@ def dummy_segmentation_model( num_classes (int): Number of output classes. image_size (int): Size of the input image. dropout_rate (float, optional): Dropout rate. Defaults to 0.0. + last_layer (nn.Module, optional): Last layer of the model. Defaults to None. Returns: nn.Module: Dummy segmentation model. """ + if last_layer is None: + last_layer = nn.Identity() return _DummySegmentation( in_channels=in_channels, num_classes=num_classes, dropout_rate=dropout_rate, image_size=image_size, + last_layer=last_layer, ) diff --git a/tests/baselines/test_deep_ensembles.py b/tests/baselines/test_deep_ensembles.py index fbb7a512..cd8642cf 100644 --- a/tests/baselines/test_deep_ensembles.py +++ b/tests/baselines/test_deep_ensembles.py @@ -9,7 +9,9 @@ class TestDeepEnsembles: """Testing the Deep Ensembles baseline class.""" def test_failure(self): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Models must not be an empty list." + ): DeepEnsemblesBaseline( log_path=".", checkpoint_ids=[], diff --git a/tests/datamodules/classification/test_cifar10_datamodule.py b/tests/datamodules/classification/test_cifar10.py similarity index 100% rename from tests/datamodules/classification/test_cifar10_datamodule.py rename to tests/datamodules/classification/test_cifar10.py diff --git a/tests/datamodules/classification/test_cifar100_datamodule.py b/tests/datamodules/classification/test_cifar100.py similarity index 100% rename from tests/datamodules/classification/test_cifar100_datamodule.py rename to tests/datamodules/classification/test_cifar100.py diff --git a/tests/datamodules/classification/test_imagenet_datamodule.py b/tests/datamodules/classification/test_imagenet.py similarity index 100% rename from tests/datamodules/classification/test_imagenet_datamodule.py rename to tests/datamodules/classification/test_imagenet.py diff --git a/tests/datamodules/classification/test_mnist_datamodule.py b/tests/datamodules/classification/test_mnist.py similarity index 79% rename from tests/datamodules/classification/test_mnist_datamodule.py rename to tests/datamodules/classification/test_mnist.py index 1707409a..f52c9abf 100644 --- a/tests/datamodules/classification/test_mnist_datamodule.py +++ b/tests/datamodules/classification/test_mnist.py @@ -12,14 +12,22 @@ class TestMNISTDataModule: def test_mnist_cutout(self): dm = MNISTDataModule( - root="./data/", batch_size=128, cutout=16, val_split=0.1 + root="./data/", + batch_size=128, + cutout=16, + val_split=0.1, + eval_ood=True, ) assert dm.dataset == MNIST assert isinstance(dm.train_transform.transforms[0], Cutout) dm = MNISTDataModule( - root="./data/", batch_size=128, ood_ds="not", cutout=0, val_split=0 + root="./data/", + batch_size=128, + ood_ds="notMNIST", + cutout=0, + val_split=0, ) assert isinstance(dm.train_transform.transforms[0], nn.Identity) @@ -42,6 +50,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() dm.setup() diff --git a/tests/datamodules/classification/test_tiny_imagenet_datamodule.py b/tests/datamodules/classification/test_tiny_imagenet.py similarity index 100% rename from tests/datamodules/classification/test_tiny_imagenet_datamodule.py rename to tests/datamodules/classification/test_tiny_imagenet.py diff --git a/tests/datamodules/classification/test_uci_regression_datamodule.py b/tests/datamodules/classification/test_uci_regression.py similarity index 100% rename from tests/datamodules/classification/test_uci_regression_datamodule.py rename to tests/datamodules/classification/test_uci_regression.py diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 7b0f5e66..0f0bd64f 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -4,16 +4,17 @@ from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.datamodules.abstract import ( - AbstractDataModule, CrossValDataModule, + TUDataModule, ) -class TestAbstractDataModule: - """Testing the AbstractDataModule class.""" +class TestTUDataModule: + """Testing the TUDataModule class.""" def test_errors(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) with pytest.raises(NotImplementedError): dm.setup() dm._get_train_data() @@ -24,7 +25,8 @@ class TestCrossValDataModule: """Testing the CrossValDataModule class.""" def test_cv_main(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds @@ -46,7 +48,8 @@ def test_cv_main(self): cv_dm.test_dataloader() def test_errors(self): - dm = AbstractDataModule("root", 128, 0.0, 4, True, True) + TUDataModule.__abstractmethods__ = set() + dm = TUDataModule("root", 128, 0.0, 4, True, True) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds diff --git a/tests/datamodules/test_depth.py b/tests/datamodules/test_depth.py index bee19975..4305be05 100644 --- a/tests/datamodules/test_depth.py +++ b/tests/datamodules/test_depth.py @@ -1,6 +1,6 @@ import pytest -from tests._dummies.dataset import DummyDepthDataset +from tests._dummies.dataset import DummPixelRegressionDataset from torch_uncertainty.datamodules.depth import ( KITTIDataModule, MUADDataModule, @@ -19,7 +19,7 @@ def test_muad_main(self): assert dm.dataset == MUAD - dm.dataset = DummyDepthDataset + dm.dataset = DummPixelRegressionDataset dm.prepare_data() dm.setup() @@ -51,7 +51,7 @@ def test_nyu_main(self): assert dm.dataset == NYUv2 - dm.dataset = DummyDepthDataset + dm.dataset = DummPixelRegressionDataset dm.prepare_data() dm.setup() diff --git a/tests/layers/test_distributions.py b/tests/layers/test_distributions.py index b1fbf4bd..63a52f27 100644 --- a/tests/layers/test_distributions.py +++ b/tests/layers/test_distributions.py @@ -3,10 +3,16 @@ from torch_uncertainty.layers.distributions import ( LaplaceLayer, NormalLayer, + TUDist, ) class TestDistributions: + def test(self): + TUDist.__abstractmethods__ = set() + dist = TUDist(dim=1) + dist.forward(None) + def test_errors(self): with pytest.raises(ValueError): NormalLayer(-1, 1) diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index ee8ab224..3ad5e3f3 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -15,6 +15,7 @@ def test_plot_binary(self) -> None: torch.as_tensor([0, 0, 1, 1, 1]), ) fig, ax = metric.plot() + metric.plot(ax=ax) assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) assert ax.get_xlabel() == "Top-class Confidence (%)" diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index a0c6446a..8519ffdf 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -20,7 +20,16 @@ def test_main(self): packed_lenet(1, 1) bayesian_lenet(1, 1) - bayesian_lenet(1, 1, 1, 1, 1, 0, 1) + bayesian_lenet( + in_channels=1, + num_classes=1, + num_samples=1, + prior_sigma_1=1, + prior_sigma_2=1, + prior_pi=0, + mu_init=1, + sigma_init=1, + ) def test_errors(self): with pytest.raises(ValueError): diff --git a/tests/models/wrappers/__init__.py b/tests/models/wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/wrappers/test_checkpoint_ensemble.py b/tests/models/wrappers/test_checkpoint_ensemble.py new file mode 100644 index 00000000..b159160c --- /dev/null +++ b/tests/models/wrappers/test_checkpoint_ensemble.py @@ -0,0 +1,26 @@ +import torch + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import CheckpointEnsemble + + +class TestCheckpointEnsemble: + """Testing the CheckpointEnsemble class.""" + + def test_training(self): + ens = CheckpointEnsemble(dummy_model(1, 10)) + ens.eval() + ens(torch.randn(1, 1)) + + ens.train() + ens(torch.randn(1, 1)) + ens.update_wrapper(0) + ens.eval() + ens(torch.randn(1, 1)) + + ens = CheckpointEnsemble(dummy_model(1, 10), use_final_checkpoint=False) + ens.train() + ens(torch.randn(1, 1)) + ens.update_wrapper(0) + ens.eval() + ens(torch.randn(1, 1)) diff --git a/tests/models/test_deep_ensembles.py b/tests/models/wrappers/test_deep_ensembles.py similarity index 100% rename from tests/models/test_deep_ensembles.py rename to tests/models/wrappers/test_deep_ensembles.py diff --git a/tests/models/wrappers/test_ema.py b/tests/models/wrappers/test_ema.py new file mode 100644 index 00000000..3be66e3e --- /dev/null +++ b/tests/models/wrappers/test_ema.py @@ -0,0 +1,21 @@ +import pytest +import torch +from torch import nn + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import EMA + + +class TestEMA: + """Testing the EMA class.""" + + def test_training(self): + ema = EMA(dummy_model(1, 10), momentum=0.99) + ema.eval() + ema(torch.randn(1, 1)) + ema.train() + ema.update_wrapper(0) + + def test_failures(self): + with pytest.raises(ValueError, match="must be in the range"): + EMA(nn.Module(), momentum=-1) diff --git a/tests/models/test_mc_dropout.py b/tests/models/wrappers/test_mc_dropout.py similarity index 76% rename from tests/models/test_mc_dropout.py rename to tests/models/wrappers/test_mc_dropout.py index b0cd9327..23a70c6a 100644 --- a/tests/models/test_mc_dropout.py +++ b/tests/models/wrappers/test_mc_dropout.py @@ -2,7 +2,7 @@ import torch from tests._dummies.model import dummy_model -from torch_uncertainty.models.mc_dropout import _MCDropout, mc_dropout +from torch_uncertainty.models import MCDropout, mc_dropout class TestMCDropout: @@ -27,14 +27,23 @@ def test_mc_dropout_eval(self): assert not dropout_model.training dropout_model(torch.rand(1, 10)) + dropout_model = mc_dropout(model, num_estimators=5, on_batch=False) + dropout_model.eval() + assert not dropout_model.training + dropout_model(torch.rand(1, 10)) + def test_mc_dropout_errors(self): model = dummy_model(10, 5, 0.1) with pytest.raises(ValueError): - _MCDropout(model=model, num_estimators=-1, last_layer=True) + MCDropout( + model=model, num_estimators=-1, last_layer=True, on_batch=True + ) with pytest.raises(ValueError): - _MCDropout(model=model, num_estimators=0, last_layer=False) + MCDropout( + model=model, num_estimators=0, last_layer=False, on_batch=False + ) dropout_model = mc_dropout(model, 5) with pytest.raises(TypeError): diff --git a/tests/models/test_stochastic_model.py b/tests/models/wrappers/test_stochastic.py similarity index 70% rename from tests/models/test_stochastic_model.py rename to tests/models/wrappers/test_stochastic.py index b7ab5a75..dc8a814e 100644 --- a/tests/models/test_stochastic_model.py +++ b/tests/models/wrappers/test_stochastic.py @@ -1,10 +1,10 @@ +import torch from torch import nn from torch_uncertainty.layers import BayesConv2d, BayesLinear -from torch_uncertainty.models.utils import stochastic_model +from torch_uncertainty.models import StochasticModel -@stochastic_model class DummyModelLinear(nn.Module): """Dummy model for testing purposes.""" @@ -16,7 +16,6 @@ def forward(self, x): return self.layer(x) -@stochastic_model class DummyModelConv(nn.Module): """Dummy conv model for testing purposes.""" @@ -28,7 +27,6 @@ def forward(self, x): return self.layer(x) -@stochastic_model class DummyModelMix(nn.Module): """Dummy mix model for testing purposes.""" @@ -47,27 +45,31 @@ class TestStochasticModel: """Testing the StochasticModel decorator.""" def test_main(self): - model = DummyModelLinear() + model = StochasticModel(DummyModelLinear(), 2) model.freeze() - assert model.layer.frozen + model(torch.randn(1, 1)) + assert model.core_model.layer.frozen model.unfreeze() - assert not model.layer.frozen + assert not model.core_model.layer.frozen + model.eval() + model(torch.randn(1, 1)) - model = DummyModelConv() + model = StochasticModel(DummyModelConv(), 2) model.freeze() - assert model.layer.frozen + assert model.core_model.layer.frozen model.unfreeze() - assert not model.layer.frozen + assert not model.core_model.layer.frozen def test_mix(self): - model = DummyModelMix() + model = StochasticModel(DummyModelMix(), 2) model.freeze() - assert model.layer.frozen + assert model.core_model.layer.frozen model.unfreeze() - assert not model.layer.frozen + assert not model.core_model.layer.frozen state = model.sample()[0] keys = state.keys() + print(list(keys)) assert list(keys) == [ "layer.weight", "layer2.weight", diff --git a/tests/models/wrappers/test_swa.py b/tests/models/wrappers/test_swa.py new file mode 100644 index 00000000..f590473b --- /dev/null +++ b/tests/models/wrappers/test_swa.py @@ -0,0 +1,135 @@ +import pytest +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset + +from tests._dummies.model import dummy_model +from torch_uncertainty.models import SWA, SWAG + + +class TestSWA: + """Testing the SWA class.""" + + def test_training(self): + dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) + swa = SWA(dummy_model(1, 10), cycle_start=1, cycle_length=1) + swa.eval() + swa(torch.randn(1, 1)) + + swa.train() + swa(torch.randn(1, 1)) + swa.update_wrapper(0) + swa.bn_update(dl, "cpu") + + swa.update_wrapper(1) + swa.bn_update(dl, "cpu") + + swa.eval() + swa(torch.randn(1, 1)) + + def test_failures(self): + with pytest.raises( + ValueError, match="`cycle_start` must be non-negative." + ): + SWA(nn.Module(), cycle_start=-1, cycle_length=1) + with pytest.raises( + ValueError, match="`cycle_length` must be strictly positive." + ): + SWA(nn.Module(), cycle_start=1, cycle_length=0) + + +class TestSWAG: + """Testing the SWAG class.""" + + def test_training(self): + dl = DataLoader(TensorDataset(torch.randn(1, 1)), batch_size=1) + swag = SWAG( + dummy_model(1, 10), + cycle_start=1, + cycle_length=1, + max_num_models=3, + num_estimators=2, + ) + assert swag.num_avgd_models == 0 + swag.eval() + swag(torch.randn(1, 1)) + + swag.train() + swag(torch.randn(1, 1)) + swag.update_wrapper(0) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (0, 10) + swag.bn_update(dl, "cpu") + swag(torch.randn(1, 1)) + + swag.update_wrapper(1) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (0, 10) + assert swag.num_avgd_models == 0 + swag.bn_update(dl, "cpu") + + swag.update_wrapper(2) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (1, 10) + swag.bn_update(dl, "cpu") + swag(torch.randn(1, 1)) + swag.update_wrapper(3) + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (2, 10) + swag.update_wrapper(4) + assert swag.num_avgd_models == 3 + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (3, 10) + swag.update_wrapper(5) + assert swag.num_avgd_models == 4 + assert swag.swag_stats[ + "model.swag_stats.linear.weight_covariance_sqrt" + ].shape == (3, 10) + swag.eval() + swag(torch.randn(1, 1)) + + swag = SWAG( + dummy_model(1, 10), + cycle_start=1, + cycle_length=1, + diag_covariance=True, + ) + swag.train() + swag.update_wrapper(2) + swag.sample(1, True, False, seed=1) + + def test_state_dict(self): + mod = dummy_model(1, 10) + swag = SWAG(mod, cycle_start=1, cycle_length=1, num_estimators=3) + print(swag.state_dict()) + swag.load_state_dict(swag.state_dict()) + + def test_failures(self): + with pytest.raises( + NotImplementedError, match="Raise an issue if you need this feature" + ): + swag = SWAG(nn.Module(), scale=1, cycle_start=1, cycle_length=1) + swag.sample(scale=1, block=True) + with pytest.raises(ValueError, match="`scale` must be non-negative."): + SWAG(nn.Module(), scale=-1, cycle_start=1, cycle_length=1) + with pytest.raises( + ValueError, match="`max_num_models` must be non-negative." + ): + SWAG(nn.Module(), max_num_models=-1, cycle_start=1, cycle_length=1) + with pytest.raises( + ValueError, match="`var_clamp` must be non-negative. " + ): + SWAG(nn.Module(), var_clamp=-1, cycle_start=1, cycle_length=1) + swag = SWAG( + nn.Module(), cycle_start=1, cycle_length=1, diag_covariance=True + ) + with pytest.raises( + ValueError, + match="Cannot sample full rank from diagonal covariance matrix.", + ): + swag.sample(scale=1, diag_covariance=False) diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py new file mode 100644 index 00000000..6b798d6b --- /dev/null +++ b/tests/post_processing/test_laplace.py @@ -0,0 +1,31 @@ +import torch +from torch import nn +from torch.utils.data import TensorDataset + +from tests._dummies.model import dummy_model +from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing + + +class TestPostProcessing: + """Testing the PostProcessing class.""" + + def test_errors(self): + PostProcessing.__abstractmethods__ = set() + pp = PostProcessing(nn.Identity()) + pp.fit(None) + pp.forward(None) + + +class TestLaplace: + """Testing the LaplaceApprox class.""" + + def test_training(self): + ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10)) + la = LaplaceApprox( + task="classification", + model=dummy_model(1, 10, last_layer=nn.Linear(10, 10)), + ) + la.fit(ds) + la(torch.randn(1, 1)) + la = LaplaceApprox(task="classification") + la.set_model(dummy_model(1, 10)) diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index fa72caee..0d1acdca 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -42,11 +42,20 @@ def test_main(self): stoch_model.eval() stoch_model(torch.randn(1, 1, 20, 20)) + stoch_model = MCBatchNorm( + num_estimators=2, convert=False, mc_batch_size=1 + ) + stoch_model.set_model(mc_model) + def test_errors(self): """Test errors.""" model = nn.Identity() with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=0, convert=True) + with pytest.raises( + ValueError, match="mc_batch_size must be a positive integer" + ): + MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=False) with pytest.raises(ValueError): diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 22f6cae6..999a9663 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -9,8 +9,8 @@ dummy_model, ) from torch_uncertainty.losses import DECLoss, ELBOLoss -from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.transforms import RepeatTarget from torch_uncertainty.utils import TUTrainer @@ -30,9 +30,9 @@ def test_one_estimator_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="msp", + ema=True, ) trainer.fit(model, dm) @@ -53,9 +53,9 @@ def test_two_estimators_binary(self): in_channels=dm.num_channels, num_classes=dm.num_classes, loss=nn.BCEWithLogitsLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="logit", + swa=True, ) trainer.fit(model, dm) @@ -77,10 +77,10 @@ def test_one_estimator_two_classes(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, + no_mixup_params=True, ) trainer.fit(model, dm) @@ -102,7 +102,6 @@ def test_one_estimator_two_classes_timm(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -130,7 +129,6 @@ def test_one_estimator_two_classes_mixup(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -157,7 +155,6 @@ def test_one_estimator_two_classes_mixup_io(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -184,7 +181,6 @@ def test_one_estimator_two_classes_regmixup(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -211,7 +207,6 @@ def test_one_estimator_two_classes_kernel_warping_emb(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -238,7 +233,6 @@ def test_one_estimator_two_classes_kernel_warping_inp(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="entropy", eval_ood=True, @@ -266,7 +260,6 @@ def test_one_estimator_two_classes_calibrated_with_ood(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss(), - optim_recipe=optim_cifar10_resnet18, baseline_type="single", ood_criterion="energy", eval_ood=True, @@ -293,7 +286,6 @@ def test_two_estimators_two_classes_mi(self): num_classes=dm.num_classes, in_channels=dm.num_channels, loss=DECLoss(1, 1e-2), - optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="mi", eval_ood=True, @@ -327,7 +319,6 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): loss=ELBOLoss( None, nn.CrossEntropyLoss(), kl_weight=1.0, num_samples=4 ), - optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ood_criterion="vr", eval_ood=True, @@ -340,11 +331,6 @@ def test_two_estimator_two_classes_elbo_vr_logs(self): model(dm.get_test_set()[0][0]) def test_classification_failures(self): - # num_estimators - with pytest.raises(ValueError): - ClassificationRoutine( - num_classes=10, model=nn.Module(), loss=None, num_estimators=-1 - ) # num_classes with pytest.raises(ValueError): ClassificationRoutine(num_classes=0, model=nn.Module(), loss=None) @@ -354,7 +340,7 @@ def test_classification_failures(self): num_classes=10, model=nn.Module(), loss=None, - num_estimators=1, + is_ensemble=False, ood_criterion="mi", ) with pytest.raises(ValueError): @@ -366,8 +352,12 @@ def test_classification_failures(self): ) with pytest.raises(ValueError): + mixup_params = {"cutmix_alpha": -1} ClassificationRoutine( - num_classes=10, model=nn.Module(), loss=None, cutmix_alpha=-1 + num_classes=10, + model=nn.Module(), + loss=None, + mixup_params=mixup_params, ) with pytest.raises( @@ -393,18 +383,36 @@ def test_classification_failures(self): num_classes=10, model=nn.Module(), loss=None, - num_estimators=2, + is_ensemble=True, eval_grouping_loss=True, ) - model = dummy_model(1, 1, 0, with_feats=False, with_linear=True) + model = dummy_model(1, 1, 0, with_feats=False) with pytest.raises(ValueError): ClassificationRoutine( num_classes=10, model=model, loss=None, eval_grouping_loss=True ) - model = dummy_model(1, 1, 0, with_feats=True, with_linear=False) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Mixup is not supported for ensembles at training time", + ): ClassificationRoutine( - num_classes=10, model=model, loss=None, eval_grouping_loss=True + num_classes=10, + model=nn.Module(), + loss=None, + mixup_params={"mixtype": "mixup"}, + format_batch_fn=RepeatTarget(2), + ) + + with pytest.raises( + ValueError, + match="Ensembles and post-processing methods cannot be used together. Raise an issue if needed.", + ): + ClassificationRoutine( + num_classes=10, + model=nn.Module(), + loss=None, + is_ensemble=True, + post_processing=nn.Module(), ) diff --git a/tests/routines/test_depth.py b/tests/routines/test_depth.py deleted file mode 100644 index e404ca80..00000000 --- a/tests/routines/test_depth.py +++ /dev/null @@ -1,75 +0,0 @@ -from pathlib import Path - -import pytest -from torch import nn - -from tests._dummies import ( - DummyDepthBaseline, - DummyDepthDataModule, -) -from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 -from torch_uncertainty.routines import PixelRegressionRoutine -from torch_uncertainty.utils import TUTrainer - - -class TestDepth: - def test_one_estimator_two_classes(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) - - root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=2) - - model = DummyDepthBaseline( - in_channels=dm.num_channels, - output_dim=dm.output_dim, - image_size=dm.image_size, - loss=nn.MSELoss(), - baseline_type="single", - optim_recipe=optim_cifar10_resnet18, - ) - - trainer.fit(model, dm) - trainer.validate(model, dm) - trainer.test(model, dm) - model(dm.get_test_set()[0][0]) - - def test_two_estimators_one_class(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) - - root = Path(__file__).parent.absolute().parents[0] / "data" - dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=1) - - model = DummyDepthBaseline( - in_channels=dm.num_channels, - output_dim=dm.output_dim, - image_size=dm.image_size, - loss=nn.MSELoss(), - baseline_type="ensemble", - optim_recipe=optim_cifar10_resnet18, - ) - - trainer.fit(model, dm) - trainer.validate(model, dm) - trainer.test(model, dm) - model(dm.get_test_set()[0][0]) - - def test_depth_errors(self): - with pytest.raises( - ValueError, match="num_estimators must be positive, got" - ): - PixelRegressionRoutine( - model=nn.Identity(), - output_dim=2, - loss=nn.MSELoss(), - num_estimators=0, - probabilistic=False, - ) - - with pytest.raises(ValueError, match="output_dim must be positive"): - PixelRegressionRoutine( - model=nn.Identity(), - output_dim=0, - loss=nn.MSELoss(), - num_estimators=1, - probabilistic=False, - ) diff --git a/tests/routines/test_pixel_regression.py b/tests/routines/test_pixel_regression.py new file mode 100644 index 00000000..56e2058d --- /dev/null +++ b/tests/routines/test_pixel_regression.py @@ -0,0 +1,131 @@ +from pathlib import Path + +import pytest +import torch +from torch import nn + +from tests._dummies import ( + DummyPixelRegressionBaseline, + DummyPixelRegressionDataModule, +) +from torch_uncertainty.losses import DistributionNLLLoss +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines.pixel_regression import ( + PixelRegressionRoutine, + colorize, +) +from torch_uncertainty.utils import TUTrainer + + +class TestPixelRegression: + def test_one_estimator_two_classes(self): + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + logger=None, + enable_checkpointing=False, + ) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyPixelRegressionDataModule( + root=root, batch_size=5, output_dim=3 + ) + + model = DummyPixelRegressionBaseline( + probabilistic=False, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=nn.MSELoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + ema=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + trainer = TUTrainer( + accelerator="cpu", + max_epochs=1, + logger=None, + enable_checkpointing=False, + ) + model = DummyPixelRegressionBaseline( + probabilistic=True, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=DistributionNLLLoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + swa=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_one_class(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyPixelRegressionDataModule( + root=root, batch_size=4, output_dim=1 + ) + + model = DummyPixelRegressionBaseline( + probabilistic=False, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=nn.MSELoss(), + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True, logger=None) + model = DummyPixelRegressionBaseline( + probabilistic=True, + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=DistributionNLLLoss(), + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + colorize(torch.ones((10, 10)), 0, 1) + colorize(torch.ones((10, 10)), 0, 0) + + def test_depth_errors(self): + with pytest.raises(ValueError, match="output_dim must be positive"): + PixelRegressionRoutine( + probabilistic=False, + model=nn.Identity(), + output_dim=0, + loss=nn.MSELoss(), + ) + + with pytest.raises(ValueError, match="num_image_plot must be positive"): + PixelRegressionRoutine( + probabilistic=False, + model=nn.Identity(), + output_dim=1, + loss=nn.MSELoss(), + num_image_plot=0, + log_plots=True, + ) diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 2c7eb469..7c03ab1e 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -26,6 +26,7 @@ def test_one_estimator_one_output(self): loss=DistributionNLLLoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", + ema=True, ) trainer.fit(model, dm) @@ -33,13 +34,15 @@ def test_one_estimator_one_output(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", + swa=True, ) trainer.fit(model, dm) @@ -63,18 +66,21 @@ def test_one_estimator_two_outputs(self): dist_type="laplace", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="single", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) @@ -94,18 +100,21 @@ def test_two_estimators_one_output(self): dist_type="nig", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=1, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) trainer.fit(model, dm) + trainer.validate(model, dm) trainer.test(model, dm) model(dm.get_test_set()[0][0]) @@ -128,11 +137,12 @@ def test_two_estimators_two_outputs(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) model = DummyRegressionBaseline( probabilistic=False, in_features=dm.in_features, output_dim=2, - loss=DistributionNLLLoss(), + loss=nn.MSELoss(), optim_recipe=optim_cifar10_resnet18, baseline_type="ensemble", ) @@ -142,12 +152,10 @@ def test_two_estimators_two_outputs(self): model(dm.get_test_set()[0][0]) def test_regression_failures(self): - with pytest.raises(ValueError): - RegressionRoutine( - True, 1, nn.Identity(), nn.MSELoss, num_estimators=0 - ) - - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="output_dim must be positive"): RegressionRoutine( - True, 0, nn.Identity(), nn.MSELoss, num_estimators=1 + probabilistic=True, + output_dim=0, + model=nn.Identity(), + loss=nn.MSELoss(), ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index cb2e41a3..7168e607 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -27,6 +27,33 @@ def test_one_estimator_two_classes(self): baseline_type="single", optim_recipe=optim_cifar10_resnet18, log_plots=True, + ema=True, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + trainer = TUTrainer( + accelerator="cpu", + max_epochs=2, + logger=None, + enable_checkpointing=False, + ) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) + + model = DummySegmentationBaseline( + in_channels=dm.num_channels, + num_classes=dm.num_classes, + image_size=dm.image_size, + loss=nn.CrossEntropyLoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + log_plots=True, + swa=True, ) trainer.fit(model, dm) @@ -35,7 +62,12 @@ def test_one_estimator_two_classes(self): model(dm.get_test_set()[0][0]) def test_two_estimators_two_classes(self): - trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + trainer = TUTrainer( + accelerator="cpu", + max_epochs=2, + logger=None, + enable_checkpointing=False, + ) root = Path(__file__).parent.absolute().parents[0] / "data" dm = DummySegmentationDataModule(root=root, batch_size=4, num_classes=2) @@ -47,6 +79,7 @@ def test_two_estimators_two_classes(self): loss=nn.CrossEntropyLoss(), baseline_type="ensemble", optim_recipe=optim_cifar10_resnet18, + swa=True, ) trainer.fit(model, dm) @@ -55,16 +88,6 @@ def test_two_estimators_two_classes(self): model(dm.get_test_set()[0][0]) def test_segmentation_errors(self): - with pytest.raises( - ValueError, match="num_estimators must be positive, got" - ): - SegmentationRoutine( - model=nn.Identity(), - num_classes=2, - loss=nn.CrossEntropyLoss(), - num_estimators=0, - ) - with pytest.raises( ValueError, match="num_classes must be at least 2, got" ): diff --git a/tests/test_optim_recipes.py b/tests/test_optim_recipes.py index b71ac43f..b6d15863 100644 --- a/tests/test_optim_recipes.py +++ b/tests/test_optim_recipes.py @@ -2,9 +2,17 @@ import pytest import torch -from torch_uncertainty.optim_recipes import ( - get_procedure, -) +from torch_uncertainty.optim_recipes import FullSWALR, get_procedure, optim_abnn + + +class TestFullSWALR: + def test_full_swa_lr(self): + FullSWALR( + torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), lr=1e-3), + swa_lr=1, + milestone=12, + anneal_epochs=5, + ) class TestOptProcedures: @@ -15,6 +23,7 @@ def test_optim_cifar10(self): get_procedure("resnet50", "cifar10", "packed")(model) get_procedure("wideresnet28x10", "cifar10", "batched")(model) get_procedure("vgg16", "cifar10", "standard")(model) + optim_abnn(model, lr=0.1) def test_optim_cifar100(self): model = torch.nn.Linear(1, 1) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index fd6bc6c1..e6188322 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -24,10 +24,10 @@ def __init__( eval_ood: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ - "msp", "logits", "energy", "entropy", "mi", "VR" + "msp", "logit", "energy", "entropy", "mi", "vr" ] = "msp", log_plots: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", ) -> None: log_path = Path(log_path) @@ -45,14 +45,13 @@ def __init__( optim_recipe=None, ).eval() models.append(trained_model.model) - de = deep_ensembles(models=models) - super().__init__( + super().__init__( # coverage: ignore num_classes=num_classes, model=de, loss=None, - num_estimators=de.num_estimators, + is_ensemble=de.num_estimators > 1, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index ff051b48..d184cda0 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -14,10 +14,17 @@ from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget +ENSEMBLE_METHODS = [ + "packed", + "batched", + "lpbnn", + "masked", + "mc-dropout", + "mimo", +] + class ResNetBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["packed", "batched", "lpbnn", "masked", "mc-dropout", "mimo"] versions = { "std": resnet, "packed": packed_resnet, @@ -47,13 +54,7 @@ def __init__( style: str = "imagenet", num_estimators: int = 1, dropout_rate: float = 0.0, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, last_layer_dropout: bool = False, groups: int = 1, scale: float | None = None, @@ -66,7 +67,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, num_calibration_bins: int = 15, @@ -108,17 +109,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"``, ``"masked"`` or ``"mc-dropout"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. scale (float, optional): Expansion factor affecting the width of @@ -184,7 +178,7 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -226,15 +220,9 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index fc4f5256..7375a082 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -10,10 +10,10 @@ from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget +ENSEMBLE_METHODS = ["mc-dropout", "packed"] + class VGGBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["mc-dropout", "packed"] versions = { "std": vgg, "mc-dropout": vgg, @@ -32,13 +32,7 @@ def __init__( num_estimators: int = 1, dropout_rate: float = 0.0, last_layer_dropout: bool = False, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, groups: int = 1, alpha: int | None = None, gamma: int = 1, @@ -47,7 +41,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, ) -> None: @@ -79,17 +73,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -147,7 +134,7 @@ def __init__( "num_estimators": num_estimators, } - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -175,15 +162,9 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, ood_criterion=ood_criterion, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index ffda0d48..78abe960 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -15,10 +15,10 @@ ) from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget +ENSEMBLE_METHODS = ["packed", "batched", "masked", "mimo", "mc-dropout"] + class WideResNetBaseline(ClassificationRoutine): - single = ["std"] - ensemble = ["packed", "batched", "masked", "mimo", "mc-dropout"] versions = { "std": [wideresnet28x10], "mc-dropout": [wideresnet28x10], @@ -39,13 +39,7 @@ def __init__( style: str = "imagenet", num_estimators: int = 1, dropout_rate: float = 0.0, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, groups: int = 1, last_layer_dropout: bool = False, scale: float | None = None, @@ -58,7 +52,7 @@ def __init__( ] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, + calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_grouping_loss: bool = False, ) -> None: @@ -89,17 +83,10 @@ def __init__( Only used if :attr:`version` is either ``"packed"``, ``"batched"`` or ``"masked"`` Defaults to ``None``. dropout_rate (float, optional): Dropout rate. Defaults to ``0.0``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel - tau. Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults - to ``0``. - cutmix_alpha (float, optional): Alpha parameter for CutMix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixtype, + mixmode, dist_sim, kernel_tau_max, kernel_tau_std, + mixup_alpha, and cutmix_alpha. If None, no augmentations. + Defaults to ``None``. last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. @@ -155,7 +142,7 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version: {version}") - if version in self.ensemble: + if version in ENSEMBLE_METHODS: params |= { "num_estimators": num_estimators, } @@ -197,15 +184,9 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, - mixtype=mixtype, - mixmode=mixmode, - dist_sim=dist_sim, - kernel_tau_max=kernel_tau_max, - kernel_tau_std=kernel_tau_std, - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, + mixup_params=mixup_params, eval_ood=eval_ood, eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 02e3c658..34cfdc21 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -13,10 +13,10 @@ ) from torch_uncertainty.transforms.batch import RepeatTarget +ENSEMBLE_METHODS = ["packed"] + class MLPBaseline(RegressionRoutine): - single = ["std"] - ensemble = ["packed"] versions = {"std": mlp, "packed": packed_mlp} def __init__( @@ -82,7 +82,7 @@ def __init__( output_dim=output_dim, model=model, loss=loss, - num_estimators=num_estimators, + is_ensemble=version in ENSEMBLE_METHODS, format_batch_fn=format_batch_fn, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py index 01575f1f..f3e3982c 100644 --- a/torch_uncertainty/baselines/segmentation/deeplab.py +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -28,7 +28,6 @@ def __init__( style: Literal["v3", "v3+"], output_stride: int, separable: bool, - num_estimators: int = 1, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, num_calibration_bins: int = 15, @@ -52,7 +51,6 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, format_batch_fn=format_batch_fn, metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 97d98a3b..c2a46013 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -21,7 +21,6 @@ def __init__( loss: nn.Module, version: Literal["std"], arch: int, - num_estimators: int = 1, ) -> None: r"""SegFormer backbone baseline for segmentation providing support for various versions and architectures. @@ -63,7 +62,6 @@ def __init__( num_classes=num_classes, model=model, loss=loss, - num_estimators=num_estimators, format_batch_fn=format_batch_fn, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 1da19ced..16308d72 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from pathlib import Path from typing import Literal @@ -8,7 +9,7 @@ from torch.utils.data.sampler import SubsetRandomSampler -class AbstractDataModule(LightningDataModule): +class TUDataModule(ABC, LightningDataModule): training_task: str train: Dataset val: Dataset @@ -47,8 +48,9 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers + @abstractmethod def setup(self, stage: Literal["fit", "test"] | None = None) -> None: - raise NotImplementedError + pass def get_train_set(self) -> Dataset: """Get the training set.""" @@ -148,13 +150,13 @@ def make_cross_val_splits( return cv_dm -class CrossValDataModule(AbstractDataModule): +class CrossValDataModule(TUDataModule): def __init__( self, root: str | Path, train_idx: ArrayLike, val_idx: ArrayLike, - datamodule: AbstractDataModule, + datamodule: TUDataModule, batch_size: int, val_split: float, num_workers: int, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 45452115..1e5eda4a 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -9,14 +9,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR10C, CIFAR10H from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR10DataModule(AbstractDataModule): +class CIFAR10DataModule(TUDataModule): num_classes = 10 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index bc5a3691..373430bd 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -10,14 +10,14 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset from torch_uncertainty.datasets.classification import CIFAR100C from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class CIFAR100DataModule(AbstractDataModule): +class CIFAR100DataModule(TUDataModule): num_classes = 100 num_channels = 3 input_shape = (3, 32, 32) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d215a79f..6d35303c 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import ( ImageNetA, ImageNetO, @@ -23,7 +23,7 @@ ) -class ImageNetDataModule(AbstractDataModule): +class ImageNetDataModule(TUDataModule): num_classes = 1000 num_channels = 3 test_datasets = ["r", "o", "a"] diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 77a6f4f5..b411f502 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -6,25 +6,25 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST from torch_uncertainty.transforms import Cutout from torch_uncertainty.utils import create_train_val_split -class MNISTDataModule(AbstractDataModule): +class MNISTDataModule(TUDataModule): num_classes = 10 num_channels = 1 input_shape = (1, 28, 28) training_task = "classification" - ood_datasets = ["fashion", "not"] + ood_datasets = ["fashion", "notMNIST"] def __init__( self, root: str | Path, batch_size: int, eval_ood: bool = False, - ood_ds: Literal["fashion", "not"] = "fashion", + ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, num_workers: int = 1, cutout: int | None = None, @@ -39,7 +39,7 @@ def __init__( eval_ood (bool): Whether to evaluate on out-of-distribution data. batch_size (int): Number of samples per batch. ood_ds (str): Which out-of-distribution dataset to use. Defaults to - ``"fashion"``; `fashion` stands for FashionMNIST and `not` for + ``"fashion"``; `fashion` stands for FashionMNIST and `notMNIST` for notMNIST. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. @@ -71,11 +71,11 @@ def __init__( if ood_ds == "fashion": self.ood_dataset = FashionMNIST - elif ood_ds == "not": + elif ood_ds == "notMNIST": self.ood_dataset = NotMNIST else: raise ValueError( - f"`ood_ds` should be `fashion` or `not`. Got {ood_ds}." + f"`ood_ds` should be in {self.ood_datasets}. Got {ood_ds}." ) main_transform = Cutout(cutout) if cutout else nn.Identity() @@ -95,6 +95,15 @@ def __init__( T.Normalize((0.1307,), (0.3081,)), ] ) + if self.eval_ood: # NotMNIST has 3 channels + self.ood_transform = T.Compose( + [ + T.Grayscale(num_output_channels=1), + T.ToTensor(), + T.CenterCrop(28), + T.Normalize((0.1307,), (0.3081,)), + ] + ) def prepare_data(self) -> None: # coverage: ignore """Download the datasets.""" @@ -140,7 +149,7 @@ def setup(self, stage: Literal["fit", "test"] | None = None) -> None: self.ood = self.ood_dataset( self.root, download=False, - transform=self.test_transform, + transform=self.ood_transform, ) def test_dataloader(self) -> list[DataLoader]: diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 25c62f31..49506d48 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -9,7 +9,7 @@ from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet from torch_uncertainty.utils import ( create_train_val_split, @@ -17,7 +17,7 @@ ) -class TinyImageNetDataModule(AbstractDataModule): +class TinyImageNetDataModule(TUDataModule): num_classes = 200 num_channels = 3 training_task = "classification" diff --git a/torch_uncertainty/datamodules/depth/base.py b/torch_uncertainty/datamodules/depth/base.py index 47e8cf73..34c69c89 100644 --- a/torch_uncertainty/datamodules/depth/base.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -7,12 +7,12 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class DepthDataModule(AbstractDataModule): +class DepthDataModule(TUDataModule): def __init__( self, dataset: type[VisionDataset], @@ -21,7 +21,7 @@ def __init__( min_depth: float, max_depth: float, crop_size: _size_2_t, - inference_size: _size_2_t, + eval_size: _size_2_t, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -42,7 +42,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -66,7 +66,7 @@ def __init__( self.min_depth = min_depth self.max_depth = max_depth self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -91,7 +91,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize(size=self.inference_size), + v2.Resize(size=self.eval_size), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py index 55f30296..c5035893 100644 --- a/torch_uncertainty/datamodules/depth/kitti.py +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -15,7 +15,7 @@ def __init__( min_depth: float = 1e-3, max_depth: float = 80.0, crop_size: _size_2_t = (352, 704), - inference_size: _size_2_t = (375, 1242), + eval_size: _size_2_t = (375, 1242), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(375, 1242)``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py index 5ca8643b..cf4f6cde 100644 --- a/torch_uncertainty/datamodules/depth/muad.py +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -16,7 +16,7 @@ def __init__( min_depth: float, max_depth: float, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py index c421c044..ec925ffa 100644 --- a/torch_uncertainty/datamodules/depth/nyu.py +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -15,7 +15,7 @@ def __init__( min_depth: float = 1e-3, max_depth: float = 10.0, crop_size: _size_2_t = (416, 544), - inference_size: _size_2_t = (480, 640), + eval_size: _size_2_t = (480, 640), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -36,7 +36,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(416, 544)``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -58,7 +58,7 @@ def __init__( min_depth=min_depth, max_depth=max_depth, crop_size=crop_size, - inference_size=inference_size, + eval_size=eval_size, val_split=val_split, num_workers=num_workers, pin_memory=pin_memory, diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 4a4aee65..84f99ac7 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -4,11 +4,11 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.segmentation import CamVid -class CamVidDataModule(AbstractDataModule): +class CamVidDataModule(TUDataModule): def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/segmentation/cityscapes.py b/torch_uncertainty/datamodules/segmentation/cityscapes.py index f35bd65d..baee3d4b 100644 --- a/torch_uncertainty/datamodules/segmentation/cityscapes.py +++ b/torch_uncertainty/datamodules/segmentation/cityscapes.py @@ -6,19 +6,19 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets.segmentation import Cityscapes from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class CityscapesDataModule(AbstractDataModule): +class CityscapesDataModule(TUDataModule): def __init__( self, root: str | Path, batch_size: int, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -35,7 +35,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -84,7 +84,7 @@ def __init__( v2.Compose([ v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), + v2.Resize(size=eval_size, antialias=True), v2.ToDtype({ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -109,7 +109,7 @@ def __init__( self.dataset = Cityscapes self.mode = "fine" self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -138,7 +138,7 @@ def __init__( self.test_transform = v2.Compose( [ v2.ToImage(), - v2.Resize(size=self.inference_size, antialias=True), + v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py index c126b05e..9ba10ee4 100644 --- a/torch_uncertainty/datamodules/segmentation/muad.py +++ b/torch_uncertainty/datamodules/segmentation/muad.py @@ -6,19 +6,19 @@ from torchvision import tv_tensors from torchvision.transforms import v2 -from torch_uncertainty.datamodules.abstract import AbstractDataModule +from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class MUADDataModule(AbstractDataModule): +class MUADDataModule(TUDataModule): def __init__( self, root: str | Path, batch_size: int, crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + eval_size: _size_2_t = (1024, 2048), val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -35,7 +35,7 @@ def __init__( :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. - inference_size (sequence or int, optional): Desired input image and + eval_size (sequence or int, optional): Desired input image and segmentation mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to @@ -84,7 +84,7 @@ def __init__( v2.Compose([ v2.ToImage(), - v2.Resize(size=inference_size, antialias=True), + v2.Resize(size=eval_size, antialias=True), v2.ToDtype({ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -108,7 +108,7 @@ def __init__( self.dataset = MUAD self.crop_size = _pair(crop_size) - self.inference_size = _pair(inference_size) + self.eval_size = _pair(eval_size) self.train_transform = v2.Compose( [ @@ -135,7 +135,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize(size=self.inference_size, antialias=True), + v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 66571959..a5cbe8af 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -6,10 +6,10 @@ from torch_uncertainty.datasets.regression import UCIRegression -from .abstract import AbstractDataModule +from .abstract import TUDataModule -class UCIDataModule(AbstractDataModule): +class UCIDataModule(TUDataModule): training_task = "regression" def __init__( diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index d6a02df0..5acc7735 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset +from .fractals import Fractals from .frost import FrostImages from .kitti import KITTIDepth from .muad import MUAD diff --git a/torch_uncertainty/datasets/classification/__init__.py b/torch_uncertainty/datasets/classification/__init__.py index 07f228ea..9bce03a1 100644 --- a/torch_uncertainty/datasets/classification/__init__.py +++ b/torch_uncertainty/datasets/classification/__init__.py @@ -1,8 +1,8 @@ # ruff: noqa: F401 from .cifar import CIFAR10C, CIFAR10H, CIFAR10N, CIFAR100C, CIFAR100N -from .fractals import Fractals from .imagenet import ( ImageNetA, + ImageNetC, ImageNetO, ImageNetR, TinyImageNet, diff --git a/torch_uncertainty/datasets/classification/imagenet/__init__.py b/torch_uncertainty/datasets/classification/imagenet/__init__.py index 9abfd040..f5971ff5 100644 --- a/torch_uncertainty/datasets/classification/imagenet/__init__.py +++ b/torch_uncertainty/datasets/classification/imagenet/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .imagenet_a import ImageNetA +from .imagenet_c import ImageNetC from .imagenet_o import ImageNetO from .imagenet_r import ImageNetR from .tiny_imagenet import TinyImageNet diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 891bfb9a..7d69d0f9 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -26,9 +26,9 @@ class ImageNetVariation(ImageFolder): downloaded, it is not downloaded again. Defaults to False. """ - url: str - filename: str - tgz_md5: str + url: str | list[str] + filename: str | list[str] + tgz_md5: str | list[str] dataset_name: str root_appendix: str diff --git a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py index 10e3df58..c95e1188 100644 --- a/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py +++ b/torch_uncertainty/datasets/classification/imagenet/imagenet_c.py @@ -1,7 +1,6 @@ from .base import ImageNetVariation -# todo, build or download class ImageNetC(ImageNetVariation): """The corrupted ImageNet-C dataset. diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index 71d2bc6a..9bd27f8c 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -66,7 +66,7 @@ def __init__( ) super().__init__( - self.root + f"/notMNIST_{subset}", + self.root / f"notMNIST_{subset}", transform=transform, target_transform=target_transform, ) @@ -97,4 +97,4 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: Args: index (int): The index of the sample to get. """ - return super().__getitem__(index)[0] + return super().__getitem__(index) diff --git a/torch_uncertainty/datasets/classification/fractals.py b/torch_uncertainty/datasets/fractals.py similarity index 100% rename from torch_uncertainty/datasets/classification/fractals.py rename to torch_uncertainty/datasets/fractals.py diff --git a/torch_uncertainty/layers/bayesian/abnn.py b/torch_uncertainty/layers/bayesian/abnn.py new file mode 100644 index 00000000..dcf75122 --- /dev/null +++ b/torch_uncertainty/layers/bayesian/abnn.py @@ -0,0 +1,67 @@ +import torch +from torch import Tensor, nn +from torch.nn import functional as F + + +class BatchNormAdapter2d(nn.Module): + def __init__( + self, + num_features: int, + alpha: float = 0.1, + momentum: float = 0.1, + eps: float = 1e-5, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + self.weight = nn.Parameter( + torch.ones(num_features, device=device, dtype=dtype), + requires_grad=True, + ) + self.bias = nn.Parameter( + torch.zeros(num_features, device=device, dtype=dtype), + requires_grad=True, + ) + + self.register_buffer( + "running_mean", + torch.zeros(num_features, device=device, dtype=dtype), + ) + self.register_buffer( + "running_var", + torch.zeros(num_features, device=device, dtype=dtype), + ) + self.register_buffer( + "num_batches_tracked", + torch.tensor(0, dtype=torch.long, device=device), + ) + self.alpha = alpha + self.momentum = momentum + self.eps = eps + self.frozen = False + + def forward(self, x: Tensor) -> Tensor: + if self.frozen: + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training, + self.momentum, + self.eps, + ) + out = F.batch_norm( + x, + self.running_mean, + self.running_var, + None, + None, + self.training, + self.momentum, + self.eps, + ) + return self.weight.unsqueeze(-1).unsqueeze(-1) * out * ( + torch.randn_like(x) * self.alpha + 1 + ) + self.bias.unsqueeze(-1).unsqueeze(-1) diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index d9fc4df4..3584ba77 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -11,7 +11,7 @@ ) from torch.nn.parameter import Parameter -from .sampler import PriorDistribution, TrainableDistribution +from .sampler import CenteredGaussianMixture, TrainableDistribution __all__ = ["BayesConv1d", "BayesConv2d", "BayesConv3d"] @@ -137,11 +137,11 @@ def __init__( self.register_parameter("bias_mu", None) self.register_parameter("bias_sigma", None) - self.weight_prior_dist = PriorDistribution( + self.weight_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) if bias: - self.bias_prior_dist = PriorDistribution( + self.bias_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) @@ -290,14 +290,14 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -323,9 +323,7 @@ def __init__( device=None, dtype=None, ) -> None: - """Bayesian Conv2d Layer with Mixture of Normals prior and Normal - posterior. - """ + """Bayesian Conv2d Layer with Gaussian Mixture prior and Normal posterior.""" factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _pair(kernel_size) stride_ = _pair(stride) @@ -389,14 +387,14 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) @@ -422,9 +420,7 @@ def __init__( device=None, dtype=None, ) -> None: - """Bayesian Conv3d Layer with Mixture of Normals prior and Normal - posterior. - """ + """Bayesian Conv3d Layer with Gaussian mixture prior and Normal posterior.""" factory_kwargs = {"device": device, "dtype": dtype} kernel_size_ = _triple(kernel_size) stride_ = _triple(stride) @@ -488,13 +484,13 @@ def forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = ( self.weight_sampler.log_posterior() + bias_lposterior ) - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return self._conv_forward(inputs, weight, bias) diff --git a/torch_uncertainty/layers/bayesian/bayes_linear.py b/torch_uncertainty/layers/bayesian/bayes_linear.py index 074f8554..2c9f15c4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_linear.py +++ b/torch_uncertainty/layers/bayesian/bayes_linear.py @@ -3,7 +3,7 @@ from torch import Tensor, nn from torch.nn import init -from .sampler import PriorDistribution, TrainableDistribution +from .sampler import CenteredGaussianMixture, TrainableDistribution class BayesLinear(nn.Module): @@ -91,11 +91,11 @@ def __init__( self.bias_mu, self.bias_sigma ) - self.weight_prior_dist = PriorDistribution( + self.weight_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) if bias: - self.bias_prior_dist = PriorDistribution( + self.bias_prior_dist = CenteredGaussianMixture( prior_sigma_1, prior_sigma_2, prior_pi ) @@ -122,12 +122,12 @@ def _forward(self, inputs: Tensor) -> Tensor: if self.bias_mu is not None: bias = self.bias_sampler.sample() bias_lposterior = self.bias_sampler.log_posterior() - bias_lprior = self.bias_prior_dist.log_prior(bias) + bias_lprior = self.bias_prior_dist.log_prob(bias) else: bias, bias_lposterior, bias_lprior = None, 0, 0 self.lvposterior = self.weight_sampler.log_posterior() + bias_lposterior - self.lprior = self.weight_prior_dist.log_prior(weight) + bias_lprior + self.lprior = self.weight_prior_dist.log_prob(weight) + bias_lprior return F.linear(inputs, weight, bias) diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index a512fad7..dd5e710a 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -18,9 +18,7 @@ def __init__( self.weight = None def sample(self) -> Tensor: - w_sample = torch.normal( - mean=0, std=1, size=self.mu.shape, device=self.mu.device - ) + w_sample = torch.randn(size=self.mu.shape, device=self.mu.device) self.sigma = torch.log1p(torch.exp(self.rho)).to(self.mu.device) self.weight = self.mu + self.sigma * w_sample return self.weight @@ -43,26 +41,27 @@ def log_posterior(self, weight: Tensor | None = None) -> Tensor: return -lposterior.sum() -class PriorDistribution(nn.Module): +class CenteredGaussianMixture(nn.Module): def __init__( self, sigma_1: float, sigma_2: float, pi: float, ) -> None: + """Create a mixture of two centered Gaussian distributions. + + Args: + sigma_1 (float): Standard deviation of the first Gaussian. + sigma_2 (float): Standard deviation of the second Gaussian. + pi (float): Mixing coefficient. + """ super().__init__() - self.pi = torch.tensor([pi, 1 - pi]) - self.mus = torch.zeros(2) - self.sigmas = torch.tensor([sigma_1, sigma_2]) + self.register_buffer("pi", torch.tensor([pi, 1 - pi])) + self.register_buffer("mus", torch.zeros(2)) + self.register_buffer("sigmas", torch.tensor([sigma_1, sigma_2])) - def log_prior(self, weight: Tensor) -> Tensor: - self.convert(weight.device) + def log_prob(self, weight: Tensor) -> Tensor: mix = distributions.Categorical(self.pi) normals = distributions.Normal(self.mus, self.sigmas) - self.distribution = distributions.MixtureSameFamily(mix, normals) - return self.distribution.log_prob(weight).sum() - - def convert(self, device) -> None: - self.pi = self.pi.to(device) - self.mus = self.mus.to(device) - self.sigmas = self.sigmas.to(device) + distribution = distributions.MixtureSameFamily(mix, normals) + return distribution.log_prob(weight).sum() diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 341cf5b9..4c7829c7 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -1,3 +1,5 @@ +from abc import ABC, abstractmethod + import torch.nn.functional as F from torch import Tensor, nn from torch.distributions import Distribution, Laplace, Normal @@ -5,18 +7,19 @@ from torch_uncertainty.utils.distributions import NormalInverseGamma -class _AbstractDist(nn.Module): +class TUDist(ABC, nn.Module): def __init__(self, dim: int) -> None: super().__init__() if dim < 1: raise ValueError(f"dim must be positive, got {dim}.") self.dim = dim + @abstractmethod def forward(self, x: Tensor) -> Distribution: - raise NotImplementedError + pass -class NormalLayer(_AbstractDist): +class NormalLayer(TUDist): """Normal distribution layer. Converts model outputs to Independent Normal distributions. @@ -46,7 +49,7 @@ def forward(self, x: Tensor) -> Normal: return Normal(loc, scale) -class LaplaceLayer(_AbstractDist): +class LaplaceLayer(TUDist): """Laplace distribution layer. Converts model outputs to Independent Laplace distributions. @@ -76,7 +79,7 @@ def forward(self, x: Tensor) -> Laplace: return Laplace(loc, scale) -class NormalInverseGammaLayer(_AbstractDist): +class NormalInverseGammaLayer(TUDist): """Normal-Inverse-Gamma distribution layer. Converts model outputs to Independent Normal-Inverse-Gamma distributions. diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 6f742b17..094b4834 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -6,12 +6,12 @@ def check_packed_parameters_consistency( - alpha: float, gamma: int, num_estimators: int + alpha: int, gamma: int, num_estimators: int ) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. Args: - alpha (float): The width multiplier of the layer. + alpha (int): The width multiplier of the layer. gamma (int): The number of groups in the ensemble. num_estimators (int): The number of estimators in the ensemble. """ @@ -49,7 +49,7 @@ def __init__( self, in_features: int, out_features: int, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, bias: bool = True, @@ -67,7 +67,7 @@ def __init__( Args: in_features (int): Number of input features of the linear layer. out_features (int): Number of channels produced by the linear layer. - alpha (float): The width multiplier of the linear layer. + alpha (int): The width multiplier of the linear layer. num_estimators (int): The number of estimators grouped in the layer. gamma (int, optional): Defaults to ``1``. bias (bool, optional): It ``True``, adds a learnable bias to the @@ -174,7 +174,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_1_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_1_t = 1, @@ -195,7 +195,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. @@ -302,7 +302,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_2_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_2_t = 1, @@ -323,7 +323,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. @@ -430,7 +430,7 @@ def __init__( in_channels: int, out_channels: int, kernel_size: _size_3_t, - alpha: float, + alpha: int, num_estimators: int, gamma: int = 1, stride: _size_3_t = 1, @@ -451,7 +451,7 @@ def __init__( in_channels (int): Number of channels in the input image. out_channels (int): Number of channels produced by the convolution. kernel_size (int or tuple): Size of the convolving kernel. - alpha (float): The channel multiplier of the convolutional layer. + alpha (int): The channel multiplier of the convolutional layer. num_estimators (int): Number of estimators in the ensemble. gamma (int, optional): Defaults to ``1``. stride (int or tuple, optional): Stride of the convolution. diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index b0d6e1b8..c82ab210 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -22,14 +22,24 @@ def __init__( super().__init__() self.reduction = reduction - def forward(self, dist: Distribution, targets: Tensor) -> Tensor: + def forward( + self, + dist: Distribution, + targets: Tensor, + padding_mask: Tensor | None = None, + ) -> Tensor: """Compute the NLL of the targets given predicted distributions. Args: dist (Distribution): The predicted distributions targets (Tensor): The target values + padding_mask (Tensor, optional): The padding mask. Defaults to None. + Sets the loss to 0 for padded values. """ loss = -dist.log_prob(targets) + if padding_mask is not None: + loss = loss.masked_fill(padding_mask, 0.0) + if self.reduction == "mean": loss = loss.mean() elif self.reduction == "sum": @@ -111,7 +121,8 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: for _ in range(self.num_samples): logits = self.model(inputs) aggregated_elbo += self.inner_loss(logits, targets) - aggregated_elbo += self.kl_weight * self._kl_div() + # TODO: This shouldn't be necessary + aggregated_elbo += self.kl_weight * self._kl_div().to(inputs.device) return aggregated_elbo / self.num_samples def set_model(self, model: nn.Module | None) -> None: diff --git a/torch_uncertainty/metrics/classification/calibration_error.py b/torch_uncertainty/metrics/classification/calibration_error.py index 5def4d3c..512323b3 100644 --- a/torch_uncertainty/metrics/classification/calibration_error.py +++ b/torch_uncertainty/metrics/classification/calibration_error.py @@ -60,7 +60,8 @@ def _ce_plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: ax.set_xlim(0, 100) ax.set_ylim(0, 100) ax.set_aspect("equal", "box") - fig.tight_layout() + if fig is not None: + fig.tight_layout() return fig, ax diff --git a/torch_uncertainty/metrics/classification/fpr95.py b/torch_uncertainty/metrics/classification/fpr95.py index 87a1b93a..f7e4a660 100644 --- a/torch_uncertainty/metrics/classification/fpr95.py +++ b/torch_uncertainty/metrics/classification/fpr95.py @@ -1,43 +1,11 @@ import numpy as np import torch -from numpy.typing import ArrayLike from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): - """Uses high precision for cumsum and checks that the final value matches - the sum. - - Args: - arr (ArrayLike): The array to be cumulatively summed as flat. - rtol (float, optional): Relative tolerance, see ``np.allclose``. - Defaults to 1e-05. - atol (float, optional): Absolute tolerance, see ``np.allclose``. - Defaults to 1e-08. - - Returns: - ArrayLike: The cumulatively summed array. - - Reference: - From https://github.com/hendrycks/anomaly-seg. - - TODO: Check if necessary. - """ - out = np.cumsum(arr, dtype=np.float64) - expected = np.sum(arr, dtype=np.float64) - if not np.allclose( - out[-1], expected, rtol=rtol, atol=atol - ): # coverage: ignore - raise RuntimeError( - "cumsum was found to be unstable: " - "its last element does not correspond to sum" - ) - return out - - class FPRx(Metric): is_differentiable: bool = False higher_is_better: bool = False @@ -53,6 +21,9 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: recall_level (float): The recall level at which to compute the FPR. pos_label (int): The positive label. kwargs: Additional arguments to pass to the metric class. + + Reference: + Inpired by https://github.com/hendrycks/anomaly-seg. """ super().__init__(**kwargs) @@ -82,13 +53,10 @@ def update(self, conf: Tensor, target: Tensor) -> None: self.targets.append(target) def compute(self) -> Tensor: - r"""Compute the actual False Positive Rate at x% Recall. + """Compute the actual False Positive Rate at x% Recall. Returns: Tensor: The value of the FPRx. - - Reference: - Inpired by https://github.com/hendrycks/anomaly-seg. """ conf = dim_zero_cat(self.conf).cpu().numpy() targets = dim_zero_cat(self.targets).cpu().numpy() @@ -120,7 +88,7 @@ def compute(self) -> Tensor: threshold_idxs = np.r_[distinct_value_indices, labels.shape[0] - 1] # accumulate the true positives with decreasing threshold - tps = stable_cumsum(labels)[threshold_idxs] + tps = np.cumsum(labels)[threshold_idxs] fps = 1 + threshold_idxs - tps # add one because of zero-based indexing thresholds = examples[threshold_idxs] diff --git a/torch_uncertainty/metrics/regression/nll.py b/torch_uncertainty/metrics/regression/nll.py index 9b2f9c3a..9db4dd31 100644 --- a/torch_uncertainty/metrics/regression/nll.py +++ b/torch_uncertainty/metrics/regression/nll.py @@ -5,17 +5,27 @@ class DistributionNLL(CategoricalNLL): - def update(self, dist: distributions.Distribution, target: Tensor) -> None: + def update( + self, + dist: distributions.Distribution, + target: Tensor, + padding_mask: Tensor | None = None, + ) -> None: """Update state with the predicted distributions and the targets. Args: dist (torch.distributions.Distribution): Predicted distributions. target (Tensor): Ground truth labels. + padding_mask (Tensor, optional): The padding mask. Defaults to None. + Sets the loss to 0 for padded values. """ + nlog_prob = -dist.log_prob(target) + if padding_mask is not None: + nlog_prob = nlog_prob.masked_fill(padding_mask, 0.0) if self.reduction is None or self.reduction == "none": - self.values.append(-dist.log_prob(target)) + self.values.append(nlog_prob) else: - self.values += -dist.log_prob(target).sum() + self.values += nlog_prob.sum() self.total += target.size(0) def compute(self) -> Tensor: diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 08dfc824..4b8964bc 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -1,3 +1,13 @@ # ruff: noqa: F401 -from .deep_ensembles import deep_ensembles -from .mc_dropout import mc_dropout +from .wrappers import ( + EMA, + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, + SWA, + SWAG, + CheckpointEnsemble, + MCDropout, + StochasticModel, + deep_ensembles, + mc_dropout, +) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index b18fa488..6804c6f9 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -8,7 +8,7 @@ from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear -from torch_uncertainty.models.utils import stochastic_model +from torch_uncertainty.models import StochasticModel __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ -39,7 +39,9 @@ def __init__( ): batchnorm = True else: - raise ValueError("norm must be nn.Identity or nn.BatchNorm2d") + raise ValueError( + f"norm must be nn.Identity or nn.BatchNorm2d. Got {norm}." + ) self.dropout_rate = dropout_rate self.last_layer_dropout = last_layer_dropout @@ -81,16 +83,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc3(out) -@stochastic_model -class _StochasticLeNet(_LeNet): - pass - - def _lenet( stochastic: bool, in_channels: int, num_classes: int, layer_args: dict, + num_samples: int = 16, linear_layer: type[nn.Module] = nn.Linear, conv2d_layer: type[nn.Module] = nn.Conv2d, activation: Callable = nn.ReLU, @@ -98,9 +96,8 @@ def _lenet( groups: int = 1, dropout_rate: float = 0.0, last_layer_dropout: bool = False, -) -> _LeNet | _StochasticLeNet: - model = _LeNet if not stochastic else _StochasticLeNet - return model( +) -> _LeNet | StochasticModel: + model = _LeNet( in_channels=in_channels, num_classes=num_classes, linear_layer=linear_layer, @@ -112,6 +109,9 @@ def _lenet( dropout_rate=dropout_rate, last_layer_dropout=last_layer_dropout, ) + if stochastic: + return StochasticModel(model, num_samples) + return model def lenet( @@ -170,6 +170,7 @@ def packed_lenet( def bayesian_lenet( in_channels: int, num_classes: int, + num_samples: int = 16, prior_sigma_1: float | None = None, prior_sigma_2: float | None = None, prior_pi: float | None = None, @@ -179,7 +180,7 @@ def bayesian_lenet( norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.0, -) -> _LeNet: +) -> StochasticModel: layers_args = {} if prior_sigma_1 is not None: layers_args["prior_sigma_1"] = prior_sigma_1 @@ -194,6 +195,7 @@ def bayesian_lenet( return _lenet( stochastic=True, + num_samples=num_samples, in_channels=in_channels, num_classes=num_classes, linear_layer=BayesLinear, diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index 1a50524f..d0fdee07 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -5,7 +5,7 @@ from torch_uncertainty.layers.bayesian import BayesLinear from torch_uncertainty.layers.packed import PackedLinear -from torch_uncertainty.models.utils import stochastic_model +from torch_uncertainty.models import StochasticModel __all__ = ["bayesian_mlp", "mlp", "packed_mlp"] @@ -84,29 +84,24 @@ def forward(self, x: Tensor) -> Tensor: return self.final_layer(self.layers[-1](x)) -@stochastic_model -class _StochasticMLP(_MLP): - pass - - def _mlp( stochastic: bool, in_features: int, num_outputs: int, hidden_dims: list[int], + num_samples: int = 16, layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, -) -> _MLP | _StochasticMLP: +) -> _MLP | StochasticModel: if layer_args is None: layer_args = {} if final_layer_args is None: final_layer_args = {} - model = _MLP if not stochastic else _StochasticMLP - return model( + model = _MLP( in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, @@ -117,6 +112,9 @@ def _mlp( final_layer_args=final_layer_args, dropout_rate=dropout_rate, ) + if stochastic: + return StochasticModel(model, num_samples) + return model def mlp( @@ -194,13 +192,15 @@ def bayesian_mlp( in_features: int, num_outputs: int, hidden_dims: list[int], + num_samples: int = 16, activation: Callable = F.relu, final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, -) -> _StochasticMLP: +) -> StochasticModel: return _mlp( stochastic=True, + num_samples=num_samples, in_features=in_features, num_outputs=num_outputs, hidden_dims=hidden_dims, diff --git a/torch_uncertainty/models/segmentation/segformer.py b/torch_uncertainty/models/segmentation/segformer.py index 763aea71..6c34dfcb 100644 --- a/torch_uncertainty/models/segmentation/segformer.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.layers import DropPath, to_2tuple, trunc_normal_ from torch import Tensor, nn diff --git a/torch_uncertainty/models/utils.py b/torch_uncertainty/models/utils.py index cb0bcc02..87fd65ee 100644 --- a/torch_uncertainty/models/utils.py +++ b/torch_uncertainty/models/utils.py @@ -1,62 +1,5 @@ from torch import Tensor, nn -from torch_uncertainty.layers.bayesian import bayesian_modules - - -def stochastic_model(model: nn.Module) -> nn.Module: - """Decorator for stochastic models. - - When applied to a model, it adds the `sample`, `freeze` and `unfreeze` - methods. Use `freeze` to obtain deterministic outputs. Use unfreeze to - obtain stochastic outputs. `sample` get samples of the estimated posterior - distribution. - - Args: - model (nn.Module): PyTorch model. - """ - - def sample(self, num_samples: int = 1) -> list[dict]: - sampled_models = [{}] * num_samples - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - for model in sampled_models: - weight, bias = module.sample() - model[module_name + ".weight"] = weight - if bias is not None: - model[module_name + ".bias"] = bias - else: - for model in sampled_models: - state = module.state_dict() - if not len(state): # no parameter - break - # TODO: fix this - model |= { - module_name + "." + key: val - for key, val in module.state_dict().items() - } - return sampled_models - - model.sample = sample - - def freeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - module.freeze() - - model.freeze = freeze - - def unfreeze(self) -> None: - for module_name in self._modules: - module = self._modules[module_name] - if isinstance(module, bayesian_modules): - module.unfreeze() - - model.unfreeze = unfreeze - - return model - class Backbone(nn.Module): def __init__(self, model: nn.Module, feat_names: list[str]) -> None: diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py new file mode 100644 index 00000000..75f37e66 --- /dev/null +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -0,0 +1,13 @@ +# ruff: noqa: F401 +from .checkpoint_ensemble import ( + CheckpointEnsemble, +) +from .deep_ensembles import deep_ensembles +from .ema import EMA +from .mc_dropout import MCDropout, mc_dropout +from .stochastic import StochasticModel +from .swa import SWA +from .swag import SWAG + +STEP_UPDATE_MODEL = (EMA,) +EPOCH_UPDATE_MODEL = (SWA, SWAG, CheckpointEnsemble) diff --git a/torch_uncertainty/models/wrappers/checkpoint_ensemble.py b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py new file mode 100644 index 00000000..f2d4d869 --- /dev/null +++ b/torch_uncertainty/models/wrappers/checkpoint_ensemble.py @@ -0,0 +1,73 @@ +import copy + +import torch +from torch import nn + + +class CheckpointEnsemble(nn.Module): + def __init__( + self, + model: nn.Module, + save_schedule: list[int] | None = None, + use_final_checkpoint: bool = True, + ) -> None: + """Ensemble of models at different points in the training trajectory. + + Args: + model (nn.Module): The model to train and ensemble. + save_schedule (list[int]): The epochs at which to save the model. + If save schedule is None, save the model at every epoch. + Defaults to None. + use_final_checkpoint (bool, optional): Whether to use the final + model as a checkpoint. Defaults to True. + + Reference: + Checkpoint Ensembles: Ensemble Methods from a Single Training Process. + Hugh Chen, Scott Lundberg, Su-In Lee. In ArXiv 2018. + """ + super().__init__() + self.core_model = model + self.save_schedule = save_schedule + self.use_final_checkpoint = use_final_checkpoint + self.num_estimators = int(use_final_checkpoint) + self.saved_models = [] + self.num_estimators = 1 + + @torch.no_grad() + def update_wrapper(self, epoch: int) -> None: + """Save the model at the given epoch if included in the schedule. + + Args: + epoch (int): The current epoch. + """ + if self.save_schedule is None or epoch in self.save_schedule: + self.saved_models.append(copy.deepcopy(self.core_model)) + self.num_estimators += 1 + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for evaluation. + + If the model is in evaluation mode, this method will return the + ensemble prediction. Otherwise, it will return the prediction of the + current model. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The model or ensemble output. + """ + if not len(self.saved_models): + return self.core_model.forward(x) + preds = torch.cat( + [model.forward(x) for model in self.saved_models], dim=0 + ) + if self.use_final_checkpoint: + model_forward = self.core_model.forward(x) + preds = torch.cat([model_forward, preds], dim=0) + return preds + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + return self.core_model.forward(x) + return self.eval_forward(x) diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/wrappers/deep_ensembles.py similarity index 92% rename from torch_uncertainty/models/deep_ensembles.py rename to torch_uncertainty/models/wrappers/deep_ensembles.py index 49640108..a72ae7c4 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/wrappers/deep_ensembles.py @@ -15,7 +15,7 @@ def __init__( ) -> None: """Create a classification deep ensembles from a list of models.""" super().__init__() - self.models = nn.ModuleList(models) + self.core_models = nn.ModuleList(models) self.num_estimators = len(models) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -29,7 +29,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: where :math:`B` is the batch size, :math:`N` is the number of estimators, and :math:`C` is the number of classes. """ - return torch.cat([model.forward(x) for model in self.models], dim=0) + return torch.cat( + [model.forward(x) for model in self.core_models], dim=0 + ) class _RegDeepEnsembles(_DeepEnsembles): @@ -52,7 +54,9 @@ def forward(self, x: torch.Tensor) -> Distribution: Distribution: """ if self.probabilistic: - return cat_dist([model.forward(x) for model in self.models], dim=0) + return cat_dist( + [model.forward(x) for model in self.core_models], dim=0 + ) return super().forward(x) @@ -92,6 +96,8 @@ def deep_ensembles( Simple and scalable predictive uncertainty estimation using deep ensembles. In NeurIPS, 2017. """ + if isinstance(models, list) and len(models) == 0: + raise ValueError("Models must not be an empty list.") if (isinstance(models, list) and len(models) == 1) or isinstance( models, nn.Module ): diff --git a/torch_uncertainty/models/wrappers/ema.py b/torch_uncertainty/models/wrappers/ema.py new file mode 100644 index 00000000..386fcca7 --- /dev/null +++ b/torch_uncertainty/models/wrappers/ema.py @@ -0,0 +1,53 @@ +import copy + +from torch import Tensor, nn + + +class EMA(nn.Module): + def __init__( + self, + model: nn.Module, + momentum: float, + ) -> None: + """Exponential Moving Average. + + Args: + model (nn.Module): The model to train and ensemble. + momentum (float): The momentum of the moving average. + """ + super().__init__() + _ema_checks(momentum) + self.core_model = model + self.ema_model = copy.deepcopy(model) + self.momentum = momentum + self.remainder = 1 - momentum + + def update_wrapper(self, epoch: int | None = None) -> None: + """Update the EMA model. + + Args: + epoch (int): The current epoch. For API consistency. + """ + for ema_param, param in zip( + self.ema_model.parameters(), + self.core_model.parameters(), + strict=False, + ): + ema_param.data = ( + ema_param.data * self.momentum + param.data * self.remainder + ) + + def eval_forward(self, x: Tensor) -> Tensor: + return self.ema_model.forward(x) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.core_model.forward(x) + return self.eval_forward(x) + + +def _ema_checks(momentum: float) -> None: + if momentum < 0.0 or momentum >= 1.0: + raise ValueError( + f"`momentum` must be in the range [0, 1). Got {momentum}." + ) diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py similarity index 57% rename from torch_uncertainty/models/mc_dropout.py rename to torch_uncertainty/models/wrappers/mc_dropout.py index 24a545b3..986d23a8 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/wrappers/mc_dropout.py @@ -1,9 +1,14 @@ +import torch from torch import Tensor, nn -class _MCDropout(nn.Module): +class MCDropout(nn.Module): def __init__( - self, model: nn.Module, num_estimators: int, last_layer: bool + self, + model: nn.Module, + num_estimators: int, + last_layer: bool, + on_batch: bool, ) -> None: """MC Dropout wrapper for a model containing nn.Dropout modules. @@ -11,6 +16,8 @@ def __init__( model (nn.Module): model to wrap num_estimators (int): number of estimators to use last_layer (bool): whether to apply dropout to the last layer only. + on_batch (bool): Increase the batch_size to perform MC-Dropout. + Otherwise in a for loop. Warning: Apply dropout using modules and not functional for this wrapper to @@ -26,24 +33,10 @@ def __init__( (i.e. after all the other dropout layers). """ super().__init__() + _dropout_checks(model, num_estimators) self.last_layer = last_layer - - if not hasattr(model, "dropout_rate"): - raise ValueError( - "`dropout_rate` must be set in the model to use MC Dropout." - ) - if model.dropout_rate <= 0.0: - raise ValueError( - "`dropout_rate` must be strictly positive to use MC Dropout." - ) - if num_estimators is None: - raise ValueError("`num_estimators` must be set to use MC Dropout.") - if num_estimators <= 0: - raise ValueError( - "`num_estimators` must be strictly positive to use MC Dropout." - ) - - self.model = model + self.on_batch = on_batch + self.core_model = model self.num_estimators = num_estimators self.filtered_modules = list( @@ -76,14 +69,23 @@ def forward( self, x: Tensor, ) -> Tensor: - if not self.training: + if self.training: + return self.core_model(x) + if self.on_batch: x = x.repeat(self.num_estimators, 1, 1, 1) - return self.model(x) + return self.core_model(x) + # Else, for loop + return torch.cat( + [self.core_model(x) for _ in range(self.num_estimators)], dim=0 + ) def mc_dropout( - model: nn.Module, num_estimators: int, last_layer: bool = False -) -> _MCDropout: + model: nn.Module, + num_estimators: int, + last_layer: bool = False, + on_batch: bool = True, +) -> MCDropout: """MC Dropout wrapper for a model. Args: @@ -91,7 +93,31 @@ def mc_dropout( num_estimators (int): number of estimators to use last_layer (bool, optional): whether to apply dropout to the last layer only. Defaults to False. + on_batch (bool): Increase the batch_size to perform MC-Dropout. + Otherwise in a for loop to reduce memory footprint. Defaults + to true. + """ - return _MCDropout( - model=model, num_estimators=num_estimators, last_layer=last_layer + return MCDropout( + model=model, + num_estimators=num_estimators, + last_layer=last_layer, + on_batch=on_batch, ) + + +def _dropout_checks(model: nn.Module, num_estimators: int) -> None: + if not hasattr(model, "dropout_rate"): + raise ValueError( + "`dropout_rate` must be set in the model to use MC Dropout." + ) + if model.dropout_rate <= 0.0: + raise ValueError( + "`dropout_rate` must be strictly positive to use MC Dropout." + ) + if num_estimators is None: + raise ValueError("`num_estimators` must be set to use MC Dropout.") + if num_estimators <= 0: + raise ValueError( + "`num_estimators` must be strictly positive to use MC Dropout." + ) diff --git a/torch_uncertainty/models/wrappers/stochastic.py b/torch_uncertainty/models/wrappers/stochastic.py new file mode 100644 index 00000000..7f298a87 --- /dev/null +++ b/torch_uncertainty/models/wrappers/stochastic.py @@ -0,0 +1,55 @@ +import torch +from torch import Tensor, nn + +from torch_uncertainty.layers.bayesian import bayesian_modules + + +class StochasticModel(nn.Module): + def __init__(self, model: nn.Module, num_samples: int) -> None: + super().__init__() + self.core_model = model + self.num_samples = num_samples + + def eval_forward(self, x: Tensor) -> Tensor: + return torch.cat( + [self.core_model.forward(x) for _ in range(self.num_samples)], dim=0 + ) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.core_model.forward(x) + return self.eval_forward(x) + + def sample(self, num_samples: int = 1) -> list[dict]: + sampled_models = [{}] * num_samples + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] + if isinstance(module, bayesian_modules): + for model in sampled_models: + weight, bias = module.sample() + model[module_name + ".weight"] = weight + if bias is not None: + model[module_name + ".bias"] = bias + else: + for model in sampled_models: + state = module.state_dict() + if not len(state): # no parameter + break + # TODO: fix this + model |= { + module_name + "." + key: val + for key, val in module.state_dict().items() + } + return sampled_models + + def freeze(self) -> None: + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] + if isinstance(module, bayesian_modules): + module.freeze() + + def unfreeze(self) -> None: + for module_name in self.core_model._modules: + module = self.core_model._modules[module_name] + if isinstance(module, bayesian_modules): + module.unfreeze() diff --git a/torch_uncertainty/models/wrappers/swa.py b/torch_uncertainty/models/wrappers/swa.py new file mode 100644 index 00000000..27fbb20e --- /dev/null +++ b/torch_uncertainty/models/wrappers/swa.py @@ -0,0 +1,90 @@ +import copy + +import torch +from torch import Tensor, nn +from torch.utils.data import DataLoader + + +class SWA(nn.Module): + num_avgd_models: Tensor + + def __init__( + self, + model: nn.Module, + cycle_start: int, + cycle_length: int, + ) -> None: + """Stochastic Weight Averaging. + + Update the SWA model every :attr:`cycle_length` epochs starting at + :attr:`cycle_start`. Uses the SWA model only at test time. Otherwise, + uses the base model for training. + + Args: + model (nn.Module): PyTorch model to be trained. + cycle_start (int): Epoch to start SWA. + cycle_length (int): Number of epochs between SWA updates. + + Reference: + Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. + (2018). Averaging Weights Leads to Wider Optima and Better Generalization. + In UAI 2018. + """ + super().__init__() + _swa_checks(cycle_start, cycle_length) + self.core_model = model + self.cycle_start = cycle_start + self.cycle_length = cycle_length + + self.register_buffer("num_avgd_models", torch.tensor(0, device="cpu")) + self.swa_model = None + self.need_bn_update = False + + @torch.no_grad() + def update_wrapper(self, epoch: int) -> None: + if ( + epoch >= self.cycle_start + and (epoch - self.cycle_start) % self.cycle_length == 0 + ): + if self.swa_model is None: + self.swa_model = copy.deepcopy(self.core_model) + self.num_avgd_models = torch.tensor(1) + else: + for swa_param, param in zip( + self.swa_model.parameters(), + self.core_model.parameters(), + strict=False, + ): + swa_param.data += (param.data - swa_param.data) / ( + self.num_avgd_models + 1 + ) + self.num_avgd_models += 1 + self.need_bn_update = True + + def eval_forward(self, x: Tensor) -> Tensor: + if self.swa_model is None: + return self.core_model.forward(x) + return self.swa_model.forward(x) + + def forward(self, x: Tensor) -> Tensor: + if self.training: + return self.core_model.forward(x) + return self.eval_forward(x) + + def bn_update(self, loader: DataLoader, device) -> None: + if self.need_bn_update and self.swa_model is not None: + torch.optim.swa_utils.update_bn( + loader, self.swa_model, device=device + ) + self.need_bn_update = False + + +def _swa_checks(cycle_start: int, cycle_length: int) -> None: + if cycle_start < 0: + raise ValueError( + f"`cycle_start` must be non-negative. Got {cycle_start}." + ) + if cycle_length <= 0: + raise ValueError( + f"`cycle_length` must be strictly positive. Got {cycle_length}." + ) diff --git a/torch_uncertainty/models/wrappers/swag.py b/torch_uncertainty/models/wrappers/swag.py new file mode 100644 index 00000000..fb12588c --- /dev/null +++ b/torch_uncertainty/models/wrappers/swag.py @@ -0,0 +1,262 @@ +import copy +from collections.abc import Mapping + +import torch +from torch import Tensor, nn +from torch.utils.data import DataLoader + +from .swa import SWA + + +class SWAG(SWA): + swag_stats: dict[str, Tensor] + prfx = "model.swag_stats." + + def __init__( + self, + model: nn.Module, + cycle_start: int, + cycle_length: int, + scale: float = 1.0, + diag_covariance: bool = False, + max_num_models: int = 20, + var_clamp: float = 1e-6, + num_estimators: int = 16, + ) -> None: + """Stochastic Weight Averaging Gaussian (SWAG). + + Update the SWAG posterior every `cycle_length` epochs starting at + `cycle_start`. Samples :attr:`num_estimators` models from the SWAG + posterior after each update. Uses the SWAG posterior estimation only + at test time. Otherwise, uses the base model for training. + + Args: + model (nn.Module): PyTorch model to be trained. + cycle_start (int): Begininning of the first SWAG averaging cycle. + cycle_length (int): Number of epochs between SWAG updates. The + first update occurs at :attr:`cycle_start`+:attr:`cycle_length`. + scale (float, optional): Scale of the Gaussian. Defaults to 1.0. + diag_covariance (bool, optional): Whether to use a diagonal + covariance. Defaults to False. + max_num_models (int, optional): Maximum number of models to store. + Defaults to 0. + var_clamp (float, optional): Minimum variance. Defaults to 1e-30. + num_estimators (int, optional): Number of posterior estimates to + use. Defaults to 16. + + Reference: + Maddox, W. J. et al. A simple baseline for bayesian uncertainty in + deep learning. In NeurIPS 2019. + + Note: + Originates from https://github.com/wjmaddox/swa_gaussian. + """ + super().__init__(model, cycle_start, cycle_length) + _swag_checks(scale, max_num_models, var_clamp) + + self.num_estimators = num_estimators + self.scale = scale + + self.diag_covariance = diag_covariance + self.max_num_models = max_num_models + self.var_clamp = var_clamp + + self.initialize_stats() + self.fit = False + self.samples = [] + + def eval_forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.fit: + return self.core_model.forward(x) + return torch.cat([mod.to(device=x.device)(x) for mod in self.samples]) + + def initialize_stats(self) -> None: + """Initialize the SWAG dictionary of statistics.""" + self.swag_stats = {} + for name_p, param in self.core_model.named_parameters(): + mean, squared_mean = ( + torch.zeros_like(param, device="cpu"), + torch.zeros_like(param, device="cpu"), + ) + self.swag_stats[self.prfx + name_p + "_mean"] = mean + self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean + + if not self.diag_covariance: + covariance_sqrt = torch.zeros((0, param.numel()), device="cpu") + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( + covariance_sqrt + ) + + @torch.no_grad() + def update_wrapper(self, epoch: int) -> None: + """Update the SWAG posterior. + + The update is performed if the epoch is greater than the cycle start + and the difference between the epoch and the cycle start is a multiple + of the cycle length. + + Args: + epoch (int): Current epoch. + """ + if not ( + epoch > self.cycle_start + and (epoch - self.cycle_start) % self.cycle_length == 0 + ): + return + + for name_p, param in self.core_model.named_parameters(): + mean = self.swag_stats[self.prfx + name_p + "_mean"] + squared_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] + new_param = param.data.detach().cpu() + + mean = mean * self.num_avgd_models / ( + self.num_avgd_models + 1 + ) + new_param / (self.num_avgd_models + 1) + squared_mean = squared_mean * self.num_avgd_models / ( + self.num_avgd_models + 1 + ) + new_param**2 / (self.num_avgd_models + 1) + + self.swag_stats[self.prfx + name_p + "_mean"] = mean + self.swag_stats[self.prfx + name_p + "_sq_mean"] = squared_mean + + if not self.diag_covariance: + covariance_sqrt = self.swag_stats[ + self.prfx + name_p + "_covariance_sqrt" + ] + dev = (new_param - mean).view(-1, 1).t() + covariance_sqrt = torch.cat((covariance_sqrt, dev), dim=0) + if self.num_avgd_models + 1 > self.max_num_models: + covariance_sqrt = covariance_sqrt[1:, :] + self.swag_stats[self.prfx + name_p + "_covariance_sqrt"] = ( + covariance_sqrt + ) + + self.num_avgd_models += 1 + + self.samples = [ + self.sample(self.scale, self.diag_covariance) + for _ in range(self.num_estimators) + ] + self.need_bn_update = True + self.fit = True + + def bn_update(self, loader: DataLoader, device) -> None: + """Update the bachnorm statistics of the current SWAG samples. + + Args: + loader (DataLoader): DataLoader to update the batchnorm statistics. + device (torch.device): Device to perform the update. + """ + if self.need_bn_update: + for mod in self.samples: + torch.optim.swa_utils.update_bn(loader, mod, device=device) + self.need_bn_update = False + + def sample( + self, + scale: float, + diag_covariance: bool | None = None, + block: bool = False, + seed: int | None = None, + ) -> nn.Module: + """Sample a model from the SWAG posterior. + + Args: + scale (float): Rescale coefficient of the Gaussian. + diag_covariance (bool, optional): Whether to use a diagonal + covariance. Defaults to None. + block (bool, optional): Whether to sample a block diagonal + covariance. Defaults to False. + seed (int, optional): Random seed. Defaults to None. + + Returns: + nn.Module: Sampled model. + """ + if seed is not None: + torch.manual_seed(seed) + + if diag_covariance is None: + diag_covariance = self.diag_covariance + if not diag_covariance and self.diag_covariance: + raise ValueError( + "Cannot sample full rank from diagonal covariance matrix." + ) + + if not block: + return self._fullrank_sample(scale, diag_covariance) + raise NotImplementedError("Raise an issue if you need this feature.") + + def _fullrank_sample( + self, scale: float, diagonal_covariance: bool + ) -> nn.Module: + new_sample = copy.deepcopy(self.core_model) + + for name_p, param in new_sample.named_parameters(): + mean = self.swag_stats[self.prfx + name_p + "_mean"] + sq_mean = self.swag_stats[self.prfx + name_p + "_sq_mean"] + + if not diagonal_covariance: + cov_mat_sqrt = self.swag_stats[ + self.prfx + name_p + "_covariance_sqrt" + ] + + var = torch.clamp(sq_mean - mean**2, self.var_clamp) + var_sample = var.sqrt() * torch.randn_like(var, requires_grad=False) + + if not diagonal_covariance: + cov_sample = cov_mat_sqrt.t() @ torch.randn( + (cov_mat_sqrt.size(0),) + ) + cov_sample /= (self.max_num_models - 1) ** 0.5 + var_sample += cov_sample.view_as(var_sample) + + sample = mean + scale**0.5 * var_sample + param.data = sample.to(device="cpu", dtype=param.dtype) + return new_sample + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination |= self.swag_stats + + def state_dict( + self, *args, destination=None, prefix="", keep_vars=False + ) -> dict[str, Tensor]: + return self.swag_stats | super().state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + def _load_swag_stats(self, state_dict): + self.swag_stats = { + k: v for k, v in state_dict.items() if k in self.swag_stats + } + for k in self.swag_stats: + del state_dict[k] + self.samples = [ + self.sample(self.scale, self.diag_covariance) + for _ in range(self.num_estimators) + ] + self.need_bn_update = True + self.fit = True + + def load_state_dict( + self, state_dict: Mapping, strict: bool = True, assign: bool = False + ): + self._load_swag_stats(state_dict) + return super().load_state_dict(state_dict, strict, assign) + + def compute_logdet(self, block=False): + raise NotImplementedError("Raise an issue if you need this feature.") + + def compute_logprob(self, vec=None, block=False, diag=False): + raise NotImplementedError("Raise an issue if you need this feature.") + + +def _swag_checks(scale: float, max_num_models: int, var_clamp: float) -> None: + if scale < 0: + raise ValueError(f"`scale` must be non-negative. Got {scale}.") + if max_num_models < 0: + raise ValueError( + f"`max_num_models` must be non-negative. Got {max_num_models}." + ) + if var_clamp < 0: + raise ValueError(f"`var_clamp` must be non-negative. Got {var_clamp}.") diff --git a/torch_uncertainty/optim_recipes.py b/torch_uncertainty/optim_recipes.py index a413b02c..8d648400 100644 --- a/torch_uncertainty/optim_recipes.py +++ b/torch_uncertainty/optim_recipes.py @@ -1,26 +1,35 @@ from collections.abc import Callable from functools import partial +from typing import Literal +import torch from timm.optim import Lamb from torch import nn, optim from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -__all__ = [ - "optim_cifar10_resnet18", - "optim_cifar10_resnet34", - "optim_cifar10_resnet50", - "optim_cifar10_vgg16", - "optim_cifar10_wideresnet", - "optim_cifar100_resnet18", - "optim_cifar100_resnet34", - "optim_cifar100_resnet50", - "optim_cifar100_vgg16", - "optim_imagenet_resnet50", - "optim_imagenet_resnet50_a3", - "optim_tinyimagenet_resnet34", - "optim_tinyimagenet_resnet50", -] + +def optim_abnn( + model: nn.Module, + lr: float, + momentum: float = 0.9, + weight_decay: float = 1e-4, + nesterov: bool = True, +) -> dict: + """ABNN finetuning recipe.""" + optimizer = optim.SGD( + model.parameters(), + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + nesterov=nesterov, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[1, 4], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} def optim_cifar10_resnet18( @@ -422,3 +431,43 @@ def get_procedure( procedure = partial(batch_ensemble_wrapper, optim_recipe=procedure) return procedure + + +class FullSWALR(torch.optim.lr_scheduler.SequentialLR): + def __init__( + self, + optimizer: Optimizer, + milestone: int, + swa_lr: float, + anneal_epochs: int, + optim_eta_min: float = 0, + anneal_strategy: Literal["cos", "linear"] = "cos", + ) -> None: + """Chains a Cosine scheduler and a SWA scheduler. + + This class is an example of a wrapper to enable training SWA and SWAG + models using the CLI. You may create your own class following this + example. + + Args: + optimizer (Optimizer): The optimizer to be used. + milestone (int): The epoch to start the SWA. + swa_lr (float): The learning rate to use for the SWA model. + anneal_epochs (int): The number of epochs to anneal the learning rate. + optim_eta_min (float): The minimum learning rate for the first optimizer. + anneal_strategy (Literal["cos", "linear"]): The strategy to anneal the learning rate. + """ + optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer=optimizer, T_max=milestone, eta_min=optim_eta_min + ) + swa_scheduler = torch.optim.swa_utils.SWALR( + optimizer, + swa_lr=swa_lr, + anneal_epochs=anneal_epochs, + anneal_strategy=anneal_strategy, + ) + super().__init__( + optimizer=optimizer, + schedulers=[optim_scheduler, swa_scheduler], + milestones=[milestone], + ) diff --git a/torch_uncertainty/post_processing/__init__.py b/torch_uncertainty/post_processing/__init__.py index edbdceef..bc5a59cf 100644 --- a/torch_uncertainty/post_processing/__init__.py +++ b/torch_uncertainty/post_processing/__init__.py @@ -1,3 +1,5 @@ # ruff: noqa: F401 +from .abstract import PostProcessing from .calibration import MatrixScaler, TemperatureScaler, VectorScaler +from .laplace import LaplaceApprox from .mc_batch_norm import MCBatchNorm diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py new file mode 100644 index 00000000..9dd4e79d --- /dev/null +++ b/torch_uncertainty/post_processing/abnn.py @@ -0,0 +1,212 @@ +import copy + +import torch +from torch import Tensor, nn +from torch.utils.data import DataLoader, Dataset + +from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d +from torch_uncertainty.models import deep_ensembles +from torch_uncertainty.optim_recipes import optim_abnn +from torch_uncertainty.routines import ClassificationRoutine +from torch_uncertainty.utils import TUTrainer + +from .abstract import PostProcessing + + +class ABNN(PostProcessing): + def __init__( + self, + num_classes: int, + random_prior: float, + alpha: float, + num_models: int, + num_samples: int, + base_lr: float, + device: torch.device | str, + max_epochs: int = 5, + use_original_model: bool = True, + batch_size: int = 128, + precision: str = "32", + model: nn.Module | None = None, + ): + """ABNN post-processing. + + Args: + num_classes (int): Number of classes of the inner model. + random_prior (float): Random prior specializing estimators on + certain classes. + alpha (float): Alpha value for ABNN to control the diversity of + the predictions. + num_models (int): Number of stochastic models. + num_samples (int): Number of samples per model. + base_lr (float): Base learning rate. + device (torch.device): Device to use. + max_epochs (int, optional): Number of training epochs. Defaults + to 5. + use_original_model (bool, optional): Use original model during + evaluation. Defaults to True. + batch_size (int, optional): Batch size for the training of ABNN. + Defaults to 128. + precision (str, optional): Machine precision for training & eval. + Defaults to "32". + model (nn.Module | None, optional): Model to use. Defaults to None. + + Reference: + + """ + super().__init__(model) + _abnn_checks( + num_classes=num_classes, + random_prior=random_prior, + alpha=alpha, + max_epochs=max_epochs, + num_models=num_models, + num_samples=num_samples, + base_lr=base_lr, + batch_size=batch_size, + ) + self.num_classes = num_classes + self.alpha = alpha + self.base_lr = base_lr + self.num_models = num_models + self.num_samples = num_samples + self.total_models = num_models + int(use_original_model) + self.use_original_model = use_original_model + self.max_epochs = max_epochs + + self.batch_size = batch_size + self.precision = precision + self.device = device + + self.final_model = None + + # Build random prior + num_rp_classes = int(num_classes**0.5) + self.weights = [] + for _ in range(num_models): + weight = torch.ones([num_classes]) + weight[torch.randperm(num_classes)[:num_rp_classes]] += ( + random_prior - 1 + ) + self.weights.append(weight) + + def fit(self, dataset: Dataset) -> None: + if self.model is None: + raise ValueError("Model must be set before fitting.") + dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + source_model = copy.deepcopy(self.model) + _replace_bn_layers(source_model, self.alpha) + + models = [copy.deepcopy(source_model) for _ in range(self.num_models)] + + baselines = [ + ClassificationRoutine( + num_classes=self.num_classes, + model=mod, + loss=nn.CrossEntropyLoss( + weight=self.weights[i].to(device=self.device) + ), + optim_recipe=optim_abnn(mod, lr=self.base_lr), + eval_ood=True, + ) + for i, mod in enumerate(models) + ] + + for baseline in baselines: + trainer = TUTrainer( + max_epochs=self.max_epochs, + accelerator=self.device, + enable_progress_bar=False, + precision=self.precision, + enable_checkpointing=False, + logger=None, + enable_model_summary=False, + ) + trainer.fit(model=baseline, train_dataloaders=dl) + + final_models = ( + [copy.deepcopy(source_model) for _ in range(self.num_samples)] + if self.use_original_model + else [] + ) + for baseline in baselines: + model = copy.deepcopy(source_model) + model.load_state_dict(baseline.model.state_dict()) + final_models.extend( + [copy.deepcopy(model) for _ in range(self.num_samples)] + ) + + self.final_model = deep_ensembles(final_models) + + def forward( + self, + x: Tensor, + ) -> Tensor: + if self.final_model is not None: + return self.final_model(x) + if self.model is not None: + return self.model(x) + raise ValueError("Model must be set before calling forward.") + + +def _abnn_checks( + num_classes, + random_prior, + alpha, + max_epochs, + num_models, + num_samples, + base_lr, + batch_size, +) -> None: + if random_prior < 0: + raise ValueError( + f"random_prior must be greater than 0. Got {random_prior}." + ) + if batch_size < 1: + raise ValueError( + f"batch_size must be greater than 0. Got {batch_size}." + ) + if max_epochs < 1: + raise ValueError(f"epoch must be greater than 0. Got {max_epochs}.") + if num_models < 1: + raise ValueError( + f"num_models must be greater than 0. Got {num_models}." + ) + if num_samples < 1: + raise ValueError( + f"num_samples must be greater than 0. Got {num_samples}." + ) + if alpha < 0: + raise ValueError(f"alpha must be greater than 0. Got {alpha}.") + if base_lr < 0: + raise ValueError(f"base_lr must be greater than 0. Got {base_lr}.") + if num_classes < 1: + raise ValueError( + f"num_classes must be greater than 0. Got {num_classes}." + ) + + +def _replace_bn_layers(model: nn.Module, alpha: float) -> None: + """Recursively replace batch normalization layers with ABNN layers. + + Args: + model (nn.Module): Model to replace batch normalization layers. + alpha (float): Alpha value for ABNN. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + _replace_bn_layers(module, alpha) + if isinstance(module, nn.BatchNorm2d) and module.track_running_stats: + num_channels = module.num_features + new_module = BatchNormAdapter2d(num_channels, alpha=alpha) + new_module.running_mean = module.running_mean + new_module.running_var = module.running_var + new_module.num_batches_tracked = module.num_batches_tracked + + new_module.weight.data = module.weight.data + new_module.bias.data = module.bias.data + setattr(model, name, new_module) + else: + _replace_bn_layers(module, alpha) diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py new file mode 100644 index 00000000..9c7908cc --- /dev/null +++ b/torch_uncertainty/post_processing/abstract.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from torch import Tensor, nn +from torch.utils.data import Dataset + + +class PostProcessing(ABC, nn.Module): + def __init__(self, model: nn.Module | None = None): + super().__init__() + self.model = model + self.trained = False + + def set_model(self, model: nn.Module) -> None: + self.model = model + + @abstractmethod + def fit(self, dataset: Dataset) -> None: + pass + + @abstractmethod + def forward( + self, + x: Tensor, + ) -> Tensor: + pass diff --git a/torch_uncertainty/post_processing/calibration/matrix_scaler.py b/torch_uncertainty/post_processing/calibration/matrix_scaler.py index 1899dcbe..a0b2c86e 100644 --- a/torch_uncertainty/post_processing/calibration/matrix_scaler.py +++ b/torch_uncertainty/post_processing/calibration/matrix_scaler.py @@ -9,8 +9,8 @@ class MatrixScaler(Scaler): def __init__( self, - model: nn.Module, num_classes: int, + model: nn.Module | None = None, init_w: float = 1, init_b: float = 0, lr: float = 0.1, diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index d87730b9..d3400dfe 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -1,21 +1,23 @@ from typing import Literal import torch -from torch import Tensor, device, nn, optim +from torch import Tensor, nn, optim from torch.utils.data import DataLoader, Dataset from tqdm import tqdm +from torch_uncertainty.post_processing import PostProcessing -class Scaler(nn.Module): + +class Scaler(PostProcessing): criterion = nn.CrossEntropyLoss() trained = False def __init__( self, - model: nn.Module, + model: nn.Module | None = None, lr: float = 0.1, max_iter: int = 100, - device: Literal["cpu", "cuda"] | device | None = None, + device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: """Virtual class for scaling post-processing for calibrated probabilities. @@ -31,8 +33,7 @@ def __init__( Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. On calibration of modern neural networks. In ICML 2017. """ - super().__init__() - self.model = model + super().__init__(model) self.device = device if lr <= 0: @@ -48,7 +49,7 @@ def fit( calibration_set: Dataset, save_logits: bool = False, progress: bool = True, - ) -> "Scaler": + ) -> None: """Fit the temperature parameters to the calibration data. Args: @@ -57,9 +58,6 @@ def fit( labels. Defaults to False. progress (bool, optional): Whether to show a progress bar. Defaults to True. - - Returns: - Scaler: Calibrated scaler. """ logits_list = [] labels_list = [] @@ -89,7 +87,6 @@ def calib_eval() -> float: if save_logits: self.logits = logits self.labels = labels - return self @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index cfd50084..f334cbab 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -1,7 +1,7 @@ from typing import Literal import torch -from torch import Tensor, device, nn +from torch import Tensor, nn from .scaler import Scaler @@ -9,11 +9,11 @@ class TemperatureScaler(Scaler): def __init__( self, - model: nn.Module, + model: nn.Module | None = None, init_val: float = 1, lr: float = 0.1, max_iter: int = 100, - device: Literal["cpu", "cuda"] | device | None = None, + device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: """Temperature scaling post-processing for calibrated probabilities. diff --git a/torch_uncertainty/post_processing/calibration/vector_scaler.py b/torch_uncertainty/post_processing/calibration/vector_scaler.py index 875945c0..53ce9551 100644 --- a/torch_uncertainty/post_processing/calibration/vector_scaler.py +++ b/torch_uncertainty/post_processing/calibration/vector_scaler.py @@ -9,8 +9,8 @@ class VectorScaler(Scaler): def __init__( self, - model: nn.Module, num_classes: int, + model: nn.Module | None = None, init_w: float = 1, init_b: float = 0, lr: float = 0.1, diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py new file mode 100644 index 00000000..e0b1203b --- /dev/null +++ b/torch_uncertainty/post_processing/laplace.py @@ -0,0 +1,85 @@ +from importlib import util +from typing import Literal + +from torch import Tensor, nn +from torch.utils.data import DataLoader, Dataset + +if util.find_spec("laplace"): + from laplace import Laplace + + laplace_installed = True + + +class LaplaceApprox(nn.Module): + def __init__( + self, + task: Literal["classification", "regression"], + model: nn.Module | None = None, + weight_subset="last_layer", + hessian_struct="kron", + pred_type: Literal["glm", "nn"] = "glm", + link_approx: Literal[ + "mc", "probit", "bridge", "bridge_norm" + ] = "probit", + batch_size: int = 256, + ) -> None: + """Laplace approximation for uncertainty estimation. + + This class is a wrapper of Laplace classes from the laplace-torch library. + + Args: + task (Literal["classification", "regression"]): task type. + model (nn.Module): model to be converted. + weight_subset (str): subset of weights to be considered. Defaults to + "last_layer". + hessian_struct (str): structure of the Hessian matrix. Defaults to + "kron". + pred_type (Literal["glm", "nn"], optional): type of posterior predictive, + See the Laplace library for more details. Defaults to "glm". + link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): + how to approximate the classification link function for the `'glm'`. + See the Laplace library for more details. Defaults to "probit". + batch_size (int, optional): batch size for the Laplace approximation. + Defaults to 256. + + Reference: + Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. + """ + super().__init__() + if not laplace_installed: # coverage: ignore + raise ImportError( + "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`." + ) + + self.pred_type = pred_type + self.link_approx = link_approx + self.task = task + self.weight_subset = weight_subset + self.hessian_struct = hessian_struct + self.batch_size = batch_size + + if model is not None: + self._setup_model(model) + + def _setup_model(self, model) -> None: + self.la = Laplace( + model=model, + likelihood=self.task, + subset_of_weights=self.weight_subset, + hessian_structure=self.hessian_struct, + ) + + def set_model(self, model: nn.Module) -> None: + self._setup_model(model) + + def fit(self, dataset: Dataset) -> None: + dl = DataLoader(dataset, batch_size=self.batch_size) + self.la.fit(train_loader=dl) + + def forward( + self, + x: Tensor, + ) -> Tensor: + return self.la( + x, pred_type=self.pred_type, link_approx=self.link_approx + ) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index d99fdd7c..b011a058 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -6,18 +6,19 @@ from torch.utils.data import DataLoader, Dataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d +from torch_uncertainty.post_processing import PostProcessing -class MCBatchNorm(nn.Module): +class MCBatchNorm(PostProcessing): counter: int = 0 mc_batch_norm_layers: list[MCBatchNorm2d] = [] trained = False def __init__( self, - model: nn.Module, - num_estimators: int, - convert: bool, + model: nn.Module | None = None, + num_estimators: int = 16, + convert: bool = True, mc_batch_size: int = 32, device: Literal["cpu", "cuda"] | torch.device | None = None, ) -> None: @@ -39,29 +40,31 @@ def __init__( batch normalized deep networks. In ICML 2018. """ super().__init__() - self.mc_batch_size = mc_batch_size - if num_estimators < 1 or not isinstance(num_estimators, int): - raise ValueError( - f"num_estimators must be a positive integer, got {num_estimators}." - ) + self.convert = convert self.num_estimators = num_estimators - - self.model = deepcopy(model) - if not convert and not self._has_mcbn(): - raise ValueError( - "model does not contain any MCBatchNorm2d nor is not to be " - "converted." - ) self.device = device + + if model is not None: + self._setup_model(model) + + def _setup_model(self, model): + _mcbn_checks( + model, self.num_estimators, self.mc_batch_size, self.convert + ) + self.model = deepcopy(model) # Is it necessary? self.model = self.model.eval() - if convert: + if self.convert: self._convert() - if not self._has_mcbn(): + if not has_mcbn(self.model): raise ValueError( "model does not contain any MCBatchNorm2d after conversion." ) + def set_model(self, model: nn.Module) -> None: + self.model = model + self._setup_model(model) + def fit(self, dataset: Dataset) -> None: """Fit the model on the dataset. @@ -99,7 +102,7 @@ def _est_forward(self, x: Tensor) -> Tensor: def forward( self, x: Tensor, - ) -> tuple[Tensor, Tensor]: + ) -> Tensor: if self.training: return self.model(x) if not self.trained: @@ -111,13 +114,6 @@ def forward( [self._est_forward(x) for _ in range(self.num_estimators)], dim=0 ) - def _has_mcbn(self) -> bool: - """Check if the model contains any MCBatchNorm2d layers.""" - for module in self.model.modules(): - if isinstance(module, MCBatchNorm2d): - return True - return False - def _convert(self) -> None: """Convert all BatchNorm2d layers to MCBatchNorm2d layers.""" self.replace_layers(self.model) @@ -171,3 +167,24 @@ def replace_layers(self, model: nn.Module) -> None: # Save pointers to the MC BatchNorm layers self.mc_batch_norm_layers.append(mc_layer) + + +def has_mcbn(model: nn.Module) -> bool: + """Check if the model contains any MCBatchNorm2d layers.""" + return any(isinstance(module, MCBatchNorm2d) for module in model.modules()) + + +def _mcbn_checks(model, num_estimators, mc_batch_size, convert): + if num_estimators < 1 or not isinstance(num_estimators, int): + raise ValueError( + f"num_estimators must be a positive integer, got {num_estimators}." + ) + if mc_batch_size < 1 or not isinstance(mc_batch_size, int): + raise ValueError( + f"mc_batch_size must be a positive integer, got {mc_batch_size}." + ) + if not convert and not has_mcbn(model): + raise ValueError( + "model does not contain any MCBatchNorm2d nor is not to be " + "converted." + ) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0c07b8ab..2d45976a 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -33,10 +33,30 @@ RiskAt80Cov, VariationRatio, ) -from torch_uncertainty.post_processing import TemperatureScaler -from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) +from torch_uncertainty.post_processing import PostProcessing +from torch_uncertainty.transforms import ( + Mixup, + MixupIO, + RegMixup, + RepeatTarget, + WarpingMixup, +) from torch_uncertainty.utils import csv_writer, plot_hist +MIXUP_PARAMS = { + "mixtype": "erm", + "mixmode": "elem", + "dist_sim": "emb", + "kernel_tau_max": 1.0, + "kernel_tau_std": 0.5, + "mixup_alpha": 0, + "cutmix_alpha": 0, +} + class ClassificationRoutine(LightningModule): def __init__( @@ -44,74 +64,62 @@ def __init__( model: nn.Module, num_classes: int, loss: nn.Module, - num_estimators: int = 1, + is_ensemble: bool = False, format_batch_fn: nn.Module | None = None, optim_recipe: dict | Optimizer | None = None, - mixtype: str = "erm", - mixmode: str = "elem", - dist_sim: str = "emb", - kernel_tau_max: float = 1.0, - kernel_tau_std: float = 0.5, - mixup_alpha: float = 0, - cutmix_alpha: float = 0, + mixup_params: dict | None = None, eval_ood: bool = False, eval_grouping_loss: bool = False, ood_criterion: Literal[ "msp", "logit", "energy", "entropy", "mi", "vr" ] = "msp", + post_processing: PostProcessing | None = None, + calibration_set: Literal["val", "test"] = "val", + num_calibration_bins: int = 15, log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, - num_calibration_bins: int = 15, ) -> None: - r"""Routine for training & testing on **classification tasks**. + r"""Routine for training & testing on **classification** tasks. Args: model (torch.nn.Module): Model to train. num_classes (int): Number of classes. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): Number of estimators for the - ensemble. Defaults to ``1`` (single model). + is_ensemble (bool, optional): Indicates whether the model is an + ensemble at test time or not. Defaults to ``False``. format_batch_fn (torch.nn.Module, optional): Function to format the batch. Defaults to :class:`torch.nn.Identity()`. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. - mixtype (str, optional): Mixup type. Defaults to ``"erm"``. - mixmode (str, optional): Mixup mode. Defaults to ``"elem"``. - dist_sim (str, optional): Distance similarity. Defaults to ``"emb"``. - kernel_tau_max (float, optional): Maximum value for the kernel tau. - Defaults to ``1.0``. - kernel_tau_std (float, optional): Standard deviation for the kernel tau. - Defaults to ``0.5``. - mixup_alpha (float, optional): Alpha parameter for Mixup. Defaults to ``0``. - cutmix_alpha (float, optional): Alpha parameter for Cutmix. - Defaults to ``0``. + mixup_params (dict, optional): Mixup parameters. Can include mixup type, + mixup mode, distance similarity, kernel tau max, kernel tau std, + mixup alpha, and cutmix alpha. If None, no augmentations. + Defaults to ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection performance or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. ood_criterion (str, optional): OOD criterion. Available options are - - ``"msp"`` (default): Maximum softmax probability. - ``"logit"``: Maximum logit. - ``"energy"``: Logsumexp of the mean logits. - ``"entropy"``: Entropy of the mean prediction. - ``"mi"``: Mutual information of the ensemble. - ``"vr"``: Variation ratio of the ensemble. - + post_processing (PostProcessing, optional): Post-processing method + to train on the calibration set. No post-processing if None. + Defaults to ``None``. + calibration_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. + num_calibration_bins (int, optional): Number of bins to compute calibration + metrics. Defaults to ``15``. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. Defaults to ``False``. - calibration_set (str, optional): The post-hoc calibration dataset to - use for scaling. If not ``None``, it uses either the validation - set when set to ``"val"`` or the test set when set to ``"test"``. - Defaults to ``None``. Else, no post-hoc calibration. - num_calibration_bins (int, optional): Number of bins to compute calibration - metrics. Defaults to ``15``. Warning: - You must define :attr:`optim_recipe` if you do not use the CLI. + You must define :attr:`optim_recipe` if you do not use the Lightning CLI. Note: :attr:`optim_recipe` can be anything that can be returned by @@ -122,17 +130,19 @@ def __init__( _classification_routine_checks( model=model, num_classes=num_classes, - num_estimators=num_estimators, + is_ensemble=is_ensemble, ood_criterion=ood_criterion, eval_grouping_loss=eval_grouping_loss, num_calibration_bins=num_calibration_bins, + mixup_params=mixup_params, + post_processing=post_processing, + format_batch_fn=format_batch_fn, ) if format_batch_fn is None: format_batch_fn = nn.Identity() self.num_classes = num_classes - self.num_estimators = num_estimators self.eval_ood = eval_ood self.eval_grouping_loss = eval_grouping_loss self.ood_criterion = ood_criterion @@ -140,30 +150,48 @@ def __init__( self.save_in_csv = save_in_csv self.calibration_set = calibration_set self.binary_cls = num_classes == 1 - + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) + self.num_calibration_bins = num_calibration_bins self.model = model self.loss = loss self.format_batch_fn = format_batch_fn self.optim_recipe = optim_recipe + self.is_ensemble = is_ensemble + + self.post_processing = post_processing + if self.post_processing is not None: + self.post_processing.set_model(self.model) + + self._init_metrics() + self.mixup = self._init_mixup(mixup_params) + + self.is_elbo = isinstance(self.loss, ELBOLoss) + if self.is_elbo: + self.loss.set_model(self.model) + self.is_dec = isinstance(self.loss, DECLoss) - # metrics + self.id_logit_storage = None + self.ood_logit_storage = None + + def _init_metrics(self) -> None: task = "binary" if self.binary_cls else "multiclass" cls_metrics = MetricCollection( { - "cls/Acc": Accuracy(task=task, num_classes=num_classes), - "cls/Brier": BrierScore(num_classes=num_classes), + "cls/Acc": Accuracy(task=task, num_classes=self.num_classes), + "cls/Brier": BrierScore(num_classes=self.num_classes), "cls/NLL": CategoricalNLL(), "cal/ECE": CalibrationError( task=task, - num_bins=num_calibration_bins, - num_classes=num_classes, + num_bins=self.num_calibration_bins, + num_classes=self.num_classes, ), "cal/aECE": CalibrationError( task=task, adaptive=True, - num_bins=num_calibration_bins, - num_classes=num_classes, + num_bins=self.num_calibration_bins, + num_classes=self.num_classes, ), "sc/AURC": AURC(), "sc/CovAt5Risk": CovAt5Risk(), @@ -181,7 +209,7 @@ def __init__( self.val_cls_metrics = cls_metrics.clone(prefix="val/") self.test_cls_metrics = cls_metrics.clone(prefix="test/") - if self.calibration_set is not None: + if self.post_processing is not None: self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_") self.test_id_entropy = Entropy() @@ -199,7 +227,7 @@ def __init__( self.test_ood_entropy = Entropy() # metrics for ensembles only - if self.num_estimators > 1: + if self.is_ensemble: ens_metrics = MetricCollection( { "Disagreement": Disagreement(), @@ -213,79 +241,76 @@ def __init__( if self.eval_ood: self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - # Mixup - self.mixtype = mixtype - self.mixmode = mixmode - self.dist_sim = dist_sim - if num_estimators == 1: - if mixup_alpha < 0 or cutmix_alpha < 0: - raise ValueError( - "Cutmix alpha and Mixup alpha must be positive." - f"Got {mixup_alpha} and {cutmix_alpha}." - ) - - self.mixup = self.init_mixup( - mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std + if self.eval_grouping_loss: + grouping_loss = MetricCollection( + {"cls/grouping_loss": GroupingLoss()} ) + self.val_grouping_loss = grouping_loss.clone(prefix="val/") + self.test_grouping_loss = grouping_loss.clone(prefix="test/") - if self.eval_grouping_loss: - grouping_loss = MetricCollection( - {"cls/grouping_loss": GroupingLoss()} - ) - self.val_grouping_loss = grouping_loss.clone(prefix="val/") - self.test_grouping_loss = grouping_loss.clone(prefix="test/") + def _init_mixup(self, mixup_params: dict | None) -> Callable: + if mixup_params is None: + mixup_params = {} + mixup_params = MIXUP_PARAMS | mixup_params + self.mixup_params = mixup_params - self.is_elbo = isinstance(self.loss, ELBOLoss) - if self.is_elbo: - self.loss.set_model(self.model) - self.is_dec = isinstance(self.loss, DECLoss) - - self.id_logit_storage = None - self.ood_logit_storage = None + if mixup_params["mixup_alpha"] < 0 or mixup_params["cutmix_alpha"] < 0: + raise ValueError( + "Cutmix alpha and Mixup alpha must be positive." + f"Got {mixup_params['mixup_alpha']} and {mixup_params['cutmix_alpha']}." + ) - def init_mixup( - self, - mixup_alpha: float, - cutmix_alpha: float, - kernel_tau_max: float, - kernel_tau_std: float, - ) -> Callable: - if self.mixtype == "timm": + if mixup_params["mixtype"] == "timm": return timm_Mixup( - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, - mode=self.mixmode, + mixup_alpha=mixup_params["mixup_alpha"], + cutmix_alpha=mixup_params["cutmix_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "mixup": + if mixup_params["mixtype"] == "mixup": return Mixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "mixup_io": + if mixup_params["mixtype"] == "mixup_io": return MixupIO( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "regmixup": + if mixup_params["mixtype"] == "regmixup": return RegMixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, ) - if self.mixtype == "kernel_warping": + if mixup_params["mixtype"] == "kernel_warping": return WarpingMixup( - alpha=mixup_alpha, - mode=self.mixmode, + alpha=mixup_params["mixup_alpha"], + mode=mixup_params["mixmode"], num_classes=self.num_classes, apply_kernel=True, - tau_max=kernel_tau_max, - tau_std=kernel_tau_std, + tau_max=mixup_params["kernel_tau_max"], + tau_std=mixup_params["kernel_tau_std"], ) return Identity() + def _apply_mixup( + self, batch: tuple[Tensor, Tensor] + ) -> tuple[Tensor, Tensor]: + if not self.is_ensemble: + if self.mixup_params["mixtype"] == "kernel_warping": + if self.mixup_params["dist_sim"] == "emb": + with torch.no_grad(): + feats = self.model.feats_forward(batch[0]).detach() + batch = self.mixup(*batch, feats) + else: # self.mixup_params["dist_sim"] == "inp": + batch = self.mixup(*batch, batch[0]) + else: + batch = self.mixup(*batch) + return batch + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -295,27 +320,33 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + def on_test_start(self) -> None: - if isinstance(self.calibration_set, str) and self.calibration_set in [ - "val", - "test", - ]: + if self.post_processing is not None: calibration_dataset = ( self.trainer.datamodule.val_dataloader().dataset if self.calibration_set == "val" else self.trainer.datamodule.test_dataloader()[0].dataset ) with torch.inference_mode(False): - self.cal_model = TemperatureScaler( - model=self.model, device=self.device - ).fit(calibration_dataset) - else: - self.cal_model = None + self.post_processing.fit(calibration_dataset) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] self.ood_logit_storage = [] + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: """Forward pass of the model. @@ -341,19 +372,7 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: - # Mixup only for single models - if self.num_estimators == 1: - if self.mixtype == "kernel_warping": - if self.dist_sim == "emb": - with torch.no_grad(): - feats = self.model.feats_forward(batch[0]).detach() - - batch = self.mixup(*batch, feats) - elif self.dist_sim == "inp": - batch = self.mixup(*batch, batch[0]) - else: - batch = self.mixup(*batch) - + batch = self._apply_mixup(batch) inputs, target = self.format_batch_fn(batch) if self.is_elbo: @@ -369,18 +388,17 @@ def training_step( loss = self.loss(logits, target) else: loss = self.loss(logits, target, self.current_epoch) - + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - inputs, target = batch - logits = self.forward( - inputs, save_feats=self.eval_grouping_loss - ) # (m*b, c) - logits = rearrange(logits, "(m b) c -> b m c", m=self.num_estimators) + inputs, targets = batch + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) if self.binary_cls: probs_per_est = torch.sigmoid(logits).squeeze(-1) @@ -388,10 +406,10 @@ def validation_step( probs_per_est = F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - self.val_cls_metrics.update(probs, target) + self.val_cls_metrics.update(probs, targets) if self.eval_grouping_loss: - self.val_grouping_loss.update(probs, target, self.features) + self.val_grouping_loss.update(probs, targets, self.features) def test_step( self, @@ -399,18 +417,9 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: - inputs, target = batch - logits = self.forward( - inputs, save_feats=self.eval_grouping_loss - ) # (m*b, c) - if logits.size(0) % self.num_estimators != 0: # coverage: ignore - raise ValueError( - f"The number of predicted samples {logits.size(0)} is not " - "divisible by the reported number of estimators " - f"{self.num_estimators} of the routine. Please check the " - "correspondence between these values." - ) - logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators) + inputs, targets = batch + logits = self.forward(inputs, save_feats=self.eval_grouping_loss) + logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) if self.binary_cls: probs_per_est = torch.sigmoid(logits) @@ -438,20 +447,19 @@ def test_step( else: ood_scores = -confs - # Scaling for single models - if self.num_estimators == 1 and self.cal_model is not None: - cal_logits = self.cal_model(inputs) - cal_probs = F.softmax(cal_logits, dim=-1) - self.ts_cls_metrics.update(cal_probs, target) + if self.post_processing is not None: + pp_logits = self.post_processing(inputs) + pp_probs = F.softmax(pp_logits, dim=-1) + self.ts_cls_metrics.update(pp_probs, targets) if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( probs.squeeze(-1) if self.binary_cls else probs, - target, + targets, ) if self.eval_grouping_loss: - self.test_grouping_loss.update(probs, target, self.features) + self.test_grouping_loss.update(probs, targets, self.features) self.log_dict( self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False @@ -464,19 +472,19 @@ def test_step( add_dataloader_idx=False, ) - if self.num_estimators > 1: + if self.is_ensemble: self.test_id_ens_metrics.update(probs_per_est) if self.eval_ood: self.test_ood_metrics.update( - ood_scores, torch.zeros_like(target) + ood_scores, torch.zeros_like(targets) ) if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) elif self.eval_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(target)) + self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs) self.log( "ood/Entropy", @@ -484,7 +492,7 @@ def test_step( on_epoch=True, add_dataloader_idx=False, ) - if self.num_estimators > 1: + if self.is_ensemble: self.test_ood_ens_metrics.update(probs_per_est) if self.ood_logit_storage is not None: @@ -507,11 +515,7 @@ def on_test_epoch_end(self) -> None: {"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True ) - if ( - self.num_estimators == 1 - and self.calibration_set is not None - and self.cal_model is not None - ): + if self.post_processing is not None: tmp_metrics = self.ts_cls_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -522,7 +526,7 @@ def on_test_epoch_end(self) -> None: sync_dist=True, ) - if self.num_estimators > 1: + if self.is_ensemble: tmp_metrics = self.test_id_ens_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -535,7 +539,7 @@ def on_test_epoch_end(self) -> None: # already logged result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) - if self.num_estimators > 1: + if self.is_ensemble: tmp_metrics = self.test_ood_ens_metrics.compute() self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -549,7 +553,7 @@ def on_test_epoch_end(self) -> None: self.test_cls_metrics["sc/AURC"].plot()[0], ) - if self.cal_model is not None: + if self.post_processing is not None: self.logger.experiment.add_figure( "Reliabity diagram after calibration", self.ts_cls_metrics["cal/ECE"].plot()[0], @@ -598,17 +602,14 @@ def save_results_to_csv(self, results: dict[str, float]) -> None: def _classification_routine_checks( model: nn.Module, num_classes: int, - num_estimators: int, + is_ensemble: bool, ood_criterion: str, eval_grouping_loss: bool, num_calibration_bins: int, + mixup_params: dict | None, + post_processing: PostProcessing | None, + format_batch_fn: nn.Module | None, ) -> None: - if not isinstance(num_estimators, int) or num_estimators < 1: - raise ValueError( - "The number of estimators must be a positive integer >= 1." - f"Got {num_estimators}." - ) - if ood_criterion not in [ "msp", "logit", @@ -622,13 +623,13 @@ def _classification_routine_checks( f" 'mi' or 'vr'. Got {ood_criterion}." ) - if num_estimators == 1 and ood_criterion in ["mi", "vr"]: + if not is_ensemble and ood_criterion in ["mi", "vr"]: raise ValueError( "You cannot use mutual information or variation ratio with a single" " model." ) - if num_estimators != 1 and eval_grouping_loss: + if is_ensemble and eval_grouping_loss: raise NotImplementedError( "Groupng loss for ensembles is not yet implemented. Raise an issue if needed." ) @@ -657,3 +658,13 @@ def _classification_routine_checks( raise ValueError( f"num_calibration_bins must be at least 2, got {num_calibration_bins}." ) + + if mixup_params is not None and isinstance(format_batch_fn, RepeatTarget): + raise ValueError( + "Mixup is not supported for ensembles at training time. Please set mixup_params to None." + ) + + if post_processing is not None and is_ensemble: + raise ValueError( + "Ensembles and post-processing methods cannot be used together. Raise an issue if needed." + ) diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index d729e11b..b9762ffc 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -29,7 +29,15 @@ SILog, ThresholdAccuracy, ) -from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) +from torch_uncertainty.utils.distributions import ( + dist_rearrange, + dist_size, + dist_squeeze, +) class PixelRegressionRoutine(LightningModule): @@ -44,21 +52,44 @@ def __init__( output_dim: int, probabilistic: bool, loss: nn.Module, - num_estimators: int = 1, - optim_recipe: dict | Optimizer | None = None, + is_ensemble: bool = False, format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, num_image_plot: int = 4, + log_plots: bool = False, ) -> None: + """Routine for training & testing on **pixel regression** tasks. + + Args: + model (nn.Module): Model to train. + output_dim (int): Number of outputs of the model. + probabilistic (bool): Whether the model is probabilistic, i.e., + outputs a PyTorch distribution. + loss (nn.Module): Loss function to optimize the :attr:`model`. + is_ensemble (bool, optional): Whether the model is an ensemble. + Defaults to ``False``. + optim_recipe (dict or Optimizer, optional): The optimizer and + optionally the scheduler to use. Defaults to ``None``. + format_batch_fn (nn.Module, optional): The function to format the + batch. Defaults to ``None``. + num_image_plot (int, optional): Number of images to plot. Defaults to ``4``. + log_plots (bool, optional): Indicates whether to log plots from + metrics. Defaults to ``False``. + """ super().__init__() - _depth_routine_checks(num_estimators, output_dim) + _depth_routine_checks(output_dim, num_image_plot, log_plots) self.model = model self.output_dim = output_dim self.one_dim_depth = output_dim == 1 self.probabilistic = probabilistic self.loss = loss - self.num_estimators = num_estimators self.num_image_plot = num_image_plot + self.is_ensemble = is_ensemble + self.log_plots = log_plots + + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -102,6 +133,20 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -116,10 +161,10 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: """ pred = self.model(inputs) if self.probabilistic: - if self.num_estimators == 1: - pred = squeeze_dist(pred, -1) + if not self.is_ensemble: + pred = dist_squeeze(pred, -1) else: - if self.num_estimators == 1: + if not self.is_ensemble: pred = pred.squeeze(-1) return pred @@ -131,54 +176,64 @@ def training_step( target = target.unsqueeze(1) dists = self.model(inputs) + if self.probabilistic: + out_shape = dist_size(dists)[-2:] + else: + out_shape = dists.shape[-2:] target = F.resize( - target, dists.shape[-2:], interpolation=F.InterpolationMode.NEAREST + target, out_shape, interpolation=F.InterpolationMode.NEAREST ) - valid_mask = ~torch.isnan(target) - loss = self.loss(dists[valid_mask], target[valid_mask]) + padding_mask = torch.isnan(target) + if self.probabilistic: + loss = self.loss(dists, target, padding_mask) + else: + loss = self.loss(dists[padding_mask], target[padding_mask]) + + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - inputs, target = batch + inputs, targets = batch if self.one_dim_depth: - target = target.unsqueeze(1) + targets = targets.unsqueeze(1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c h w -> (b c h w)") preds = self.model(inputs) if self.probabilistic: ens_dist = Independent( dist_rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> (b c h w) m", b=batch_size ), - 1, + 0, ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + (dist_size(preds)[0] // batch_size), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators - ) + preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) - if batch_idx == 0: + if batch_idx == 0 and self.log_plots: self._plot_depth( inputs[: self.num_image_plot, ...], preds[: self.num_image_plot, ...], - target[: self.num_image_plot, ...], + targets[: self.num_image_plot, ...], stage="val", ) - valid_mask = ~torch.isnan(target) - self.val_metrics.update(preds[valid_mask], target[valid_mask]) + padding_mask = torch.isnan(targets) + self.val_metrics.update(preds[padding_mask], targets[padding_mask]) if self.probabilistic: - self.val_prob_metrics.update( - mixture[valid_mask], target[valid_mask] - ) + self.val_prob_metrics.update(mixture, targets, padding_mask) def test_step( self, @@ -191,29 +246,29 @@ def test_step( "Depth OOD detection not implemented yet. Raise an issue " "if needed." ) - - inputs, target = batch + inputs, targets = batch if self.one_dim_depth: - target = target.unsqueeze(1) + targets = targets.unsqueeze(1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c h w -> (b c h w)") preds = self.model(inputs) if self.probabilistic: ens_dist = dist_rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators + preds, "(m b) c h w -> (b c h w) m", b=batch_size ) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + (dist_size(preds)[0] // batch_size), device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) - self.test_metrics.nll.update(mixture, target) preds = mixture.mean else: - preds = rearrange( - preds, "(m b) c h w -> b m c h w", m=self.num_estimators - ) + preds = rearrange(preds, "(m b) c h w -> (b c h w) m", b=batch_size) preds = preds.mean(dim=1) - if batch_idx == 0: + if batch_idx == 0 and self.log_plots: num_images = ( self.num_image_plot if self.num_image_plot < inputs.size(0) @@ -222,16 +277,14 @@ def test_step( self._plot_depth( inputs[:num_images, ...], preds[:num_images, ...], - target[:num_images, ...], + targets[:num_images, ...], stage="test", ) - valid_mask = ~torch.isnan(target) - self.test_metrics.update(preds[valid_mask], target[valid_mask]) + padding_mask = torch.isnan(targets) + self.test_metrics.update(preds[padding_mask], targets[padding_mask]) if self.probabilistic: - self.test_prob_metrics.update( - mixture[valid_mask], target[valid_mask] - ) + self.test_prob_metrics.update(mixture, targets, padding_mask) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_metrics.compute(), sync_dist=True) @@ -311,11 +364,12 @@ def colorize( return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 -def _depth_routine_checks(num_estimators: int, output_dim: int) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - +def _depth_routine_checks( + output_dim: int, num_image_plot: int, log_plots: bool +) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") + if num_image_plot < 1 and log_plots: + raise ValueError( + f"num_image_plot must be positive, got {num_image_plot}." + ) diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 55998518..2beeb435 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -6,7 +6,6 @@ from torch.distributions import ( Categorical, Distribution, - Independent, MixtureSameFamily, ) from torch.optim import Optimizer @@ -15,7 +14,15 @@ from torch_uncertainty.metrics import ( DistributionNLL, ) -from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) +from torch_uncertainty.utils.distributions import ( + dist_rearrange, + dist_size, + dist_squeeze, +) class RegressionRoutine(LightningModule): @@ -25,11 +32,11 @@ def __init__( output_dim: int, probabilistic: bool, loss: nn.Module, - num_estimators: int = 1, + is_ensemble: bool = False, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: - r"""Routine for training & testing on **regression tasks**. + r"""Routine for training & testing on **regression** tasks. Args: model (torch.nn.Module): Model to train. @@ -37,8 +44,8 @@ def __init__( probabilistic (bool): Whether the model is probabilistic, i.e., outputs a PyTorch distribution. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to ``1`` (single model). + is_ensemble (bool, optional): Whether the model is an ensemble. + Defaults to ``False``. optim_recipe (dict or torch.optim.Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the @@ -58,13 +65,15 @@ def __init__( `here `_. """ super().__init__() - _regression_routine_checks(num_estimators, output_dim) + _regression_routine_checks(output_dim) self.model = model self.probabilistic = probabilistic self.output_dim = output_dim self.loss = loss - self.num_estimators = num_estimators + self.is_ensemble = is_ensemble + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -102,6 +111,20 @@ def on_train_start(self) -> None: self.hparams, ) + def on_validation_start(self) -> None: + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. @@ -117,13 +140,13 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = self.model(inputs) if self.probabilistic: if self.one_dim_regression: - pred = squeeze_dist(pred, -1) - if self.num_estimators == 1: - pred = squeeze_dist(pred, -1) + pred = dist_squeeze(pred, -1) + if not self.is_ensemble: + pred = dist_squeeze(pred, -1) else: if self.one_dim_regression: pred = pred.squeeze(-1) - if self.num_estimators == 1: + if not self.is_ensemble: pred = pred.squeeze(-1) return pred @@ -137,6 +160,8 @@ def training_step( targets = targets.unsqueeze(-1) loss = self.loss(dists, targets) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss @@ -146,22 +171,22 @@ def validation_step( inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c -> (b c)") preds = self.model(inputs) if self.probabilistic: - ens_dist = Independent( - dist_rearrange( - preds, "(m b) c -> b m c", m=self.num_estimators - ), - 1, - ) + ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + dist_size(preds)[0] // batch_size, device=self.device + ) ) + print(ens_dist, type(ens_dist)) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) preds = preds.mean(dim=1) self.val_metrics.update(preds, targets) @@ -183,22 +208,21 @@ def test_step( inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) + batch_size = targets.size(0) + targets = rearrange(targets, "b c -> (b c)") preds = self.model(inputs) if self.probabilistic: - ens_dist = Independent( - dist_rearrange( - preds, "(m b) c -> b m c", m=self.num_estimators - ), - 1, - ) + ens_dist = dist_rearrange(preds, "(m b) c -> (b c) m", b=batch_size) mix = Categorical( - torch.ones(self.num_estimators, device=self.device) + torch.ones( + dist_size(preds)[0] // batch_size, device=self.device + ) ) mixture = MixtureSameFamily(mix, ens_dist) preds = mixture.mean else: - preds = rearrange(preds, "(m b) c -> b m c", m=self.num_estimators) + preds = rearrange(preds, "(m b) c -> (b c) m", b=batch_size) preds = preds.mean(dim=1) self.test_metrics.update(preds, targets) @@ -225,11 +249,6 @@ def on_test_epoch_end(self) -> None: self.test_prob_metrics.reset() -def _regression_routine_checks(num_estimators: int, output_dim: int) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - +def _regression_routine_checks(output_dim: int) -> None: if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 6e1e11e2..f3ece492 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -15,6 +15,10 @@ CategoricalNLL, MeanIntersectionOverUnion, ) +from torch_uncertainty.models import ( + EPOCH_UPDATE_MODEL, + STEP_UPDATE_MODEL, +) class SegmentationRoutine(LightningModule): @@ -23,21 +27,18 @@ def __init__( model: nn.Module, num_classes: int, loss: nn.Module, - num_estimators: int = 1, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, num_calibration_bins: int = 15, ) -> None: - """Routine for training & testing on segmentation tasks. + r"""Routine for training & testing on **segmentation** tasks. Args: model (torch.nn.Module): Model to train. num_classes (int): Number of classes in the segmentation task. loss (torch.nn.Module): Loss function to optimize the :attr:`model`. - num_estimators (int, optional): The number of estimators for the - ensemble. Defaults to ̀`1` (single model). optim_recipe (dict or Optimizer, optional): The optimizer and optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the @@ -45,7 +46,7 @@ def __init__( metric_subsampling_rate (float, optional): The rate of subsampling for the memory consuming metrics. Defaults to ``1e-2``. log_plots (bool, optional): Indicates whether to log plots from - metrics. Defaults to ``False` + metrics. Defaults to ``False``. num_calibration_bins (int, optional): Number of bins to compute calibration metrics. Defaults to ``15``. @@ -59,7 +60,6 @@ def __init__( """ super().__init__() _segmentation_routine_checks( - num_estimators, num_classes, metric_subsampling_rate, num_calibration_bins, @@ -68,7 +68,8 @@ def __init__( self.model = model self.num_classes = num_classes self.loss = loss - self.num_estimators = num_estimators + self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) + self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) if format_batch_fn is None: format_batch_fn = nn.Identity() @@ -131,6 +132,20 @@ def on_train_start(self) -> None: if self.logger is not None: # coverage: ignore self.logger.log_hyperparams(self.hparams) + def on_validation_start(self) -> None: + if self.needs_epoch_update and not self.trainer.sanity_checking: + self.model.update_wrapper(self.current_epoch) + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + + def on_test_start(self) -> None: + if hasattr(self.model, "need_bn_update"): + self.model.bn_update( + self.trainer.train_dataloader, device=self.device + ) + def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> STEP_OUTPUT: @@ -144,44 +159,50 @@ def training_step( target = target.flatten() valid_mask = target != 255 loss = self.loss(logits[valid_mask], target[valid_mask]) + if self.needs_step_update: + self.model.update_wrapper(self.current_epoch) self.log("train_loss", loss) return loss def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - img, target = batch + img, targets = batch logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + targets = F.resize( + targets, + logits.shape[-2:], + interpolation=F.InterpolationMode.NEAREST, ) logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) ) probs_per_est = logits.softmax(dim=-1) probs = probs_per_est.mean(dim=1) - target = target.flatten() - valid_mask = target != 255 - probs, target = probs[valid_mask], target[valid_mask] - self.val_seg_metrics.update(probs, target) - self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) + targets = targets.flatten() + valid_mask = targets != 255 + probs, targets = probs[valid_mask], targets[valid_mask] + self.val_seg_metrics.update(probs, targets) + self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: - img, target = batch + img, targets = batch logits = self.forward(img) - target = F.resize( - target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST + targets = F.resize( + targets, + logits.shape[-2:], + interpolation=F.InterpolationMode.NEAREST, ) logits = rearrange( - logits, "(m b) c h w -> (b h w) m c", m=self.num_estimators + logits, "(m b) c h w -> (b h w) m c", b=targets.size(0) ) probs_per_est = logits.softmax(dim=-1) probs = probs_per_est.mean(dim=1) - target = target.flatten() - valid_mask = target != 255 - probs, target = probs[valid_mask], target[valid_mask] - self.test_seg_metrics.update(probs, target) - self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) + targets = targets.flatten() + valid_mask = targets != 255 + probs, targets = probs[valid_mask], targets[valid_mask] + self.test_seg_metrics.update(probs, targets) + self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def on_validation_epoch_end(self) -> None: self.log_dict(self.val_seg_metrics.compute(), sync_dist=True) @@ -210,16 +231,10 @@ def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: def _segmentation_routine_checks( - num_estimators: int, num_classes: int, metric_subsampling_rate: float, num_calibration_bins: int, ) -> None: - if num_estimators < 1: - raise ValueError( - f"num_estimators must be positive, got {num_estimators}." - ) - if num_classes < 2: raise ValueError(f"num_classes must be at least 2, got {num_classes}.") diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 58f8532b..1bf2e669 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -3,12 +3,34 @@ import torch from einops import rearrange from torch import Tensor -from torch.distributions import Distribution, Laplace, Normal, constraints +from torch.distributions import ( + Distribution, + Laplace, + Normal, + constraints, +) from torch.distributions.utils import broadcast_all +def dist_size(distribution: Distribution) -> torch.Size: + """Get the size of the distribution. + + Args: + distribution (Distribution): The distribution. + + Returns: + torch.Size: The size of the distribution. + """ + if isinstance(distribution, Normal | Laplace | NormalInverseGamma): + return distribution.loc.size() + raise NotImplementedError( + f"Size of {type(distribution)} distributions is not supported." + "Raise an issue if needed." + ) + + def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: - r"""Concatenate a list of distributions into a single distribution. + """Concatenate a list of distributions into a single distribution. Args: distributions (list[Distribution]): The list of distributions. @@ -44,14 +66,16 @@ def cat_dist(distributions: list[Distribution], dim: int) -> Distribution: betas = torch.cat( [distribution.beta for distribution in distributions], dim=dim ) - return dist_type(loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas) + return NormalInverseGamma( + loc=locs, lmbda=lmbdas, alpha=alphas, beta=betas + ) raise NotImplementedError( f"Concatenation of {dist_type} distributions is not supported." "Raise an issue if needed." ) -def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: +def dist_squeeze(distribution: Distribution, dim: int) -> Distribution: """Squeeze the distribution along a given dimension. Args: @@ -62,16 +86,16 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: Distribution: The squeezed distribution. """ dist_type = type(distribution) - if dist_type in (Normal, Laplace): + if isinstance(distribution, Normal | Laplace): loc = distribution.loc.squeeze(dim) scale = distribution.scale.squeeze(dim) return dist_type(loc=loc, scale=scale) - if dist_type == NormalInverseGamma: + if isinstance(distribution, NormalInverseGamma): loc = distribution.loc.squeeze(dim) lmbda = distribution.lmbda.squeeze(dim) alpha = distribution.alpha.squeeze(dim) beta = distribution.beta.squeeze(dim) - return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( f"Squeezing of {dist_type} distributions is not supported." "Raise an issue if needed." @@ -82,19 +106,18 @@ def dist_rearrange( distribution: Distribution, pattern: str, **axes_lengths: int ) -> Distribution: dist_type = type(distribution) - if dist_type in (Normal, Laplace): + if isinstance(distribution, Normal | Laplace): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) return dist_type(loc=loc, scale=scale) - if dist_type == NormalInverseGamma: + if isinstance(distribution, NormalInverseGamma): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) beta = rearrange(distribution.beta, pattern=pattern, **axes_lengths) - return dist_type(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) + return NormalInverseGamma(loc=loc, lmbda=lmbda, alpha=alpha, beta=beta) raise NotImplementedError( - f"Ensemble distribution of {dist_type} is not supported." - "Raise an issue if needed." + f"Rearrange of {dist_type} is not supported. Raise an issue if needed." ) diff --git a/torch_uncertainty/utils/hub.py b/torch_uncertainty/utils/hub.py index b48bc324..acb7e3f5 100644 --- a/torch_uncertainty/utils/hub.py +++ b/torch_uncertainty/utils/hub.py @@ -7,7 +7,9 @@ from safetensors.torch import load_file -def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]: +def load_hf( + weight_id: str, version: int = 0 +) -> tuple[dict[str, torch.Tensor], dict[str, str]]: """Load a model from the HuggingFace hub. Args: @@ -15,7 +17,7 @@ def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]: version (int): The id of the version when there are several on HF. Returns: - Tuple[Tensor, Dict]: The model weights and config. + tuple[dict[str, torch.Tensor], dict[str, str]]: The model weights and config. Note - License: TorchUncertainty's weights are released under the Apache 2.0 license.