Skip to content

⚡ Version 0.5.1 #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 44 additions & 40 deletions auto_tutorial_source/Bayesian_Methods/tutorial_bayesian.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
# ruff: noqa: E402, E703, D212, D415, T201
"""
Train a Bayesian Neural Network in Three Minutes
Training a Bayesian Neural Network in 20 seconds
================================================

In this tutorial, we will train a variational inference Bayesian Neural Network (BNN) LeNet classifier on the MNIST dataset.
In this tutorial, we will train a variational inference Bayesian Neural Network (viBNN) LeNet classifier on the MNIST dataset.

Foreword on Bayesian Neural Networks
------------------------------------

Bayesian Neural Networks (BNNs) are a class of neural networks that estimate the uncertainty on their predictions via uncertainty
on their weights. This is achieved by considering the weights of the neural network as random variables, and by learning their
posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights, which can be
seen as Dirac distributions on the weights.
posterior distribution. This is in contrast to standard neural networks, which only learn a single set of weights (this can be
seen as Dirac distributions on the weights).

For more information on Bayesian Neural Networks, we refer the reader to the following resources:
For more information on Bayesian Neural Networks, we refer to the following resources:

- Weight Uncertainty in Neural Networks `ICML2015 <https://arxiv.org/pdf/1505.05424.pdf>`_
- Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine <https://arxiv.org/pdf/2007.06823.pdf>`_
- Hands-on Bayesian Neural Networks - a Tutorial for Deep Learning Users `IEEE Computational Intelligence Magazine
<https://arxiv.org/pdf/2007.06823.pdf>`_

Training a Bayesian LeNet using TorchUncertainty models and Lightning
---------------------------------------------------------------------

In this part, we train a Bayesian LeNet, based on the model and routines already implemented in TU.
In this first part, we train a Bayesian LeNet, based on the model and routines already implemented in TU.

1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~

To train a BNN using TorchUncertainty, we have to load the following modules:

- our TUTrainer
- the model: bayesian_lenet, which lies in the torch_uncertainty.model
- the classification training routine from torch_uncertainty.routines
- our TUTrainer to improve the display of our metrics
- the model: bayesian_lenet, which lies in the torch_uncertainty.model.classification.lenet module
- the classification training routine from torch_uncertainty.routines module
- the Bayesian objective: the ELBOLoss, which lies in the torch_uncertainty.losses file
- the datamodule that handles dataloaders: MNISTDataModule from torch_uncertainty.datamodules

Expand All @@ -46,39 +47,43 @@
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.models.classification import bayesian_lenet
from torch_uncertainty.models.classification.lenet import bayesian_lenet
from torch_uncertainty.routines import ClassificationRoutine

# We also define the main hyperparameters, with just one epoch for the sake of time
BATCH_SIZE = 512
MAX_EPOCHS = 2

# %%
# 2. Creating the necessary variables
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the following, we instantiate our trainer, define the root of the datasets and the logs.
# We also create the datamodule that handles the MNIST dataset, dataloaders and transforms.
# Please note that the datamodules can also handle OOD detection by setting the eval_ood
# parameter to True. Finally, we create the model using the blueprint from torch_uncertainty.models.
# Please note that the datamodules can also handle OOD detection by setting the `eval_ood`
# parameter to True, as well as distribution shift with `eval_shift`.
# Finally, we create the model using the blueprint from torch_uncertainty.models.

trainer = TUTrainer(accelerator="gpu", devices=1, enable_progress_bar=False, max_epochs=1)
trainer = TUTrainer(accelerator="gpu", devices=1, enable_progress_bar=False, max_epochs=MAX_EPOCHS)

# datamodule
root = Path("data")
datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False)
datamodule = MNISTDataModule(root=root, batch_size=BATCH_SIZE, num_workers=8)

# model
model = bayesian_lenet(datamodule.num_channels, datamodule.num_classes)

# %%
# 3. The Loss and the Training Routine
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Then, we just have to define the loss to be used during training. To do this,
# we redefine the default parameters from the ELBO loss using the partial
# function from functools. We use the hyperparameters proposed in the blitz
# library. As we are train a classification model, we use the CrossEntropyLoss
# as the likelihood.
# We then define the training routine using the classification training routine
# from torch_uncertainty.classification. We provide the model, the ELBO
#
# Then, we just define the loss to be used during training, which is a bit special and called
# the evidence lower bound. We use the hyperparameters proposed in the blitz
# library. As we are training a classification model, we use the CrossEntropyLoss
# as the negative log likelihood. We then define the training routine using the classification
# training routine from torch_uncertainty.classification. We provide the model, the ELBO
# loss and the optimizer to the routine.
# We will use the Adam optimizer with the default learning rate of 0.001.
# We use an Adam optimizer with a learning rate of 0.02.

loss = ELBOLoss(
model=model,
Expand All @@ -91,10 +96,7 @@
model=model,
num_classes=datamodule.num_classes,
loss=loss,
optim_recipe=optim.Adam(
model.parameters(),
lr=1e-3,
),
optim_recipe=optim.Adam(model.parameters(), lr=2e-2),
is_ensemble=True,
)

Expand All @@ -105,25 +107,26 @@
# Now that we have prepared all of this, we just have to gather everything in
# the main function and to train the model using our wrapper of Lightning Trainer.
# Specifically, it needs the routine, that includes the model as well as the
# training/eval logic and the datamodule
# training/eval logic and the datamodule.
# The dataset will be downloaded automatically in the root/data folder, and the
# logs will be saved in the root/logs folder.

trainer.fit(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule)

# %%
# 5. Testing the Model
# ~~~~~~~~~~~~~~~~~~~~
#
# Now that the model is trained, let's test it on MNIST.
# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble
# and to the batch. As for TorchUncertainty 0.2.0, the ensemble dimension is merged with the batch dimension
# and to the batch. As for TorchUncertainty 0.5.1, the ensemble dimension is merged with the batch dimension
# in this order (num_estimator x batch, classes).

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from einops import rearrange


def imshow(img) -> None:
Expand All @@ -134,32 +137,33 @@ def imshow(img) -> None:
plt.show()


dataiter = iter(datamodule.val_dataloader())
images, labels = next(dataiter)
images, labels = next(iter(datamodule.val_dataloader()))

# print images
imshow(torchvision.utils.make_grid(images[:4, ...]))
print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

# Put the model in eval mode to use several samples
model = model.eval()
logits = model(images).reshape(16, 128, 10) # num_estimators, batch_size, num_classes
model = routine.eval()
logits = routine(images[:4, ...])
print("Output logit shape (Num predictions x Batch) x Classes: ", logits.shape)
logits = rearrange(logits, "(m b) c -> b m c", b=4) # batch_size, num_estimators, num_classes

# We apply the softmax on the classes and average over the estimators
# We apply the softmax on the classes then average over the estimators
probs = torch.nn.functional.softmax(logits, dim=-1)
avg_probs = probs.mean(dim=0)
var_probs = probs.std(dim=0)
avg_probs = probs.mean(dim=1)
var_probs = probs.std(dim=1)

_, predicted = torch.max(avg_probs, 1)
predicted = torch.argmax(avg_probs, -1)

print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4)))
print(
"Std. dev. of the scores over the posterior samples",
" ".join(f"{var_probs[j][predicted[j]]:.3}" for j in range(4)),
" ".join(f"{var_probs[j][predicted[j]]:.3f}" for j in range(4)),
)
# %%
# Here, we show the variance of the top prediction. This is a non-standard but intuitive way to show the diversity of the predictions
# of the ensemble. Ideally, the variance should be high when the average top prediction is incorrect.
# of the ensemble. Ideally, the variance should be high when the prediction is incorrect.
#
# References
# ----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def enet_weighing(dataloader, num_classes, c=1.02):
optim_recipe={"optimizer": optimizer, "lr_scheduler": lr_updater},
)

trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=NB_EPOCHS, enable_progress_bar=True)
trainer = TUTrainer(accelerator="gpu", devices=1, max_epochs=NB_EPOCHS, enable_progress_bar=False)
# %%
# 6. Training the model
# ~~~~~~~~~~~~~~~~~~~~~
Expand Down
174 changes: 0 additions & 174 deletions auto_tutorial_source/Classification/tutorial_bayesian.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@
# We specify the maximum number of epochs, the precision and the device to be used.

# Initialize the TUTrainer with a maximum of 10 epochs and the specified device
trainer = TUTrainer(max_epochs=10, precision="16-mixed", accelerator="cuda", devices=1)
trainer = TUTrainer(
max_epochs=10, precision="16-mixed", accelerator="cuda", devices=1, enable_progress_bar=False
)

# Begin training the model using the CIFAR-10 DataModule
trainer.fit(routine, datamodule=datamodule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pathlib import Path

import torch
from torch import nn, optim
from torch import optim

from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import MNISTDataModule
Expand Down
Loading