
✨ Refactor wrappers & PP, Add Checkpoint Ensembles, EMA, SWA, & SWAG, Add LaplaceApprox & ABNN #98


Merged: 61 commits, merged on Jun 26, 2024

Changes from 44 commits

Commits (61)
870af02
:sparkles: Add LeNet experiment on MNIST
o-laurent May 29, 2024
ec61536
:bug: Fix notMNIST
o-laurent May 30, 2024
d268911
:bug: Fix MNIST datamodule OODs
o-laurent May 31, 2024
e137610
Merge branch 'main' of github.com:ENSTA-U2IS-AI/torch-uncertainty int…
o-laurent May 31, 2024
19fafbd
:sparkles: Add Laplace wrapper
o-laurent Jun 2, 2024
acf90eb
:books: Add Laplace to the references
o-laurent Jun 3, 2024
63fb874
:hammer: Refactor Mixup params
o-laurent Jun 5, 2024
acbd582
:bug: Fix #99 error in calibration plots
o-laurent Jun 6, 2024
8c2de92
:shirt: Slightly improve dropout
o-laurent Jun 7, 2024
8655bec
:bug: Fix MC Dropout test
o-laurent Jun 7, 2024
91cf1c0
:book: Remove Packed-Ensembles mentioned twice
o-laurent Jun 12, 2024
988a89b
:sparkles: Add Trajectory Ensemble
o-laurent Jun 12, 2024
8e7c188
:sparkles: Add EMA & SWA & Reformat models
o-laurent Jun 12, 2024
8b0a02a
:book: Add SWA to docs
o-laurent Jun 12, 2024
3c231e2
:hammer: Refactor EMA, SWA, & Checkpoint Ens.
o-laurent Jun 12, 2024
1f72ead
:book: Fix conf error
o-laurent Jun 12, 2024
1c7059d
:sparkles: Merge pull request #96 from ENSTA-U2IS-AI/laplace
o-laurent Jun 13, 2024
c091c9a
:shirt: Small changes
o-laurent Jun 13, 2024
4bfd351
:hammer: Refactor the post processing methods
o-laurent Jun 13, 2024
2a15dce
Merge branch 'trajectory' of github.com:ENSTA-U2IS-AI/torch-uncertain…
o-laurent Jun 13, 2024
be85bf8
:hammer: Refactor the AbstractDatamodule
o-laurent Jun 13, 2024
951ff09
:bug: Fix test of abstract methods
o-laurent Jun 13, 2024
83c4cce
:hammer: Refactor pp methods
o-laurent Jun 16, 2024
e48368d
:sparkles: Add first version of SWAG
o-laurent Jun 16, 2024
df5330e
:hammer: Refactor wrappers
o-laurent Jun 16, 2024
1d0c595
:hammer: Refactor the classification routine
o-laurent Jun 16, 2024
6e989f6
:book: Add links to the conf. in ReadMe
o-laurent Jun 16, 2024
fbc8c55
:white_check_mark: Update tests
o-laurent Jun 16, 2024
27bb610
:sparkles: Improve SWAG code
o-laurent Jun 17, 2024
0010d95
:shirt: Minor fix
o-laurent Jun 17, 2024
2c63b3b
:wrench: Fix online install
o-laurent Jun 17, 2024
914599b
:sparkles: Add a full scheduler for SWA & SWAG & update config
o-laurent Jun 17, 2024
a3443e3
:bug: Improve SWA & SWAG
o-laurent Jun 17, 2024
4a4eeac
:hammer: Refactor stochastic models
o-laurent Jun 17, 2024
737d862
:bug: Fix Stochastic MLP error
o-laurent Jun 17, 2024
6f57332
:books: Update documentation
o-laurent Jun 17, 2024
00eb701
:books: Fix bugs in docs
o-laurent Jun 17, 2024
4e1e8af
:white_check_mark: Add first battery of tests
o-laurent Jun 17, 2024
501b5d4
:heavy_check_mark: Fix tests
o-laurent Jun 17, 2024
d135612
:white_check_mark: Improve SWAG tests
o-laurent Jun 17, 2024
f1b6546
:white_check_mark: Improve Stochastic tests
o-laurent Jun 17, 2024
edbb88e
:shirt: Minor changes
o-laurent Jun 17, 2024
fdbaf76
:white_check_mark: Finetune tests
o-laurent Jun 17, 2024
63e874a
Merge pull request #101 from ENSTA-U2IS-AI/trajectory
o-laurent Jun 17, 2024
a601ff9
:ok_hand: Take review comments into account
o-laurent Jun 18, 2024
3ea90e9
:books: Improve documentation & tutorials
o-laurent Jun 18, 2024
84fb04f
:book: Add a tutorial on Packed-Ensembles
o-laurent Jun 18, 2024
f6fb41c
:white_check_mark: Improve tests
o-laurent Jun 18, 2024
676f272
:shirt: Improve ReadMe
o-laurent Jun 18, 2024
80aaaa8
:bug: Fix SWAG
o-laurent Jun 18, 2024
06b990f
:sparkles: Propagate changes to the other routines & update tests
o-laurent Jun 18, 2024
725bd9c
:hammer: rename inference_size to eval_size
o-laurent Jun 18, 2024
fcbfeaa
:bug: Fix regression routines
o-laurent Jun 18, 2024
c9e0404
:sparkles: Add first version for ABNN
o-laurent Jun 18, 2024
4fe4aec
:white_check_mark: Improve coverage
o-laurent Jun 18, 2024
7a57586
:wrench: Lock plt version
o-laurent Jun 18, 2024
0a3a5a7
:bug: Minor changes
o-laurent Jun 19, 2024
41f2f80
:fire: Remove webdataset
o-laurent Jun 19, 2024
17c3071
:books: Improve API Page
o-laurent Jun 21, 2024
f56866a
:white_check_mark: Slightly improve tests
o-laurent Jun 21, 2024
2d84fe6
:ok_hand: Make review modifications before merging
o-laurent Jun 26, 2024
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -65,7 +65,7 @@ jobs:
if: steps.changed-files-specific.outputs.only_changed != 'true'
run: |
python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
python3 -m pip install .[image,dev,docs]
python3 -m pip install .[all]

- name: Check style & format
if: steps.changed-files-specific.outputs.only_changed != 'true'
5 changes: 4 additions & 1 deletion README.md
@@ -18,7 +18,7 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant

:books: Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). :books:

TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **WACV 2024** and **ECCV 2024**.
TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **[WACV](https://wacv2024.thecvf.com/) 2024**, **[HAICON](https://haicon24.de/) 2024** and **[ECCV](https://eccv.ecva.net/) 2024**.

---

@@ -69,6 +69,8 @@ To date, the following deep learning baselines have been implemented:
- MIMO
- Packed-Ensembles (see [Blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html)
- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html)
- Checkpoint Ensembles & Snapshot Ensembles
- Stochastic Weight Averaging & Stochastic Weight Averaging Gaussian
- Regression with Beta Gaussian NLL Loss
- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)

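The Checkpoint/Snapshot Ensembles and SWA/SWAG entries above all build an ensemble or a weight average from a single training trajectory. As a rough illustration of the idea only, here is a minimal SWA sketch using PyTorch's built-in `torch.optim.swa_utils` rather than the TorchUncertainty wrappers added in this PR; the model, data, and hyperparameters are placeholders.

```python
# Minimal SWA sketch with plain PyTorch utilities. This illustrates the idea behind the
# SWA/SWAG wrappers added in this PR; it is NOT the TorchUncertainty API.
import torch
from torch import nn, optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data and model, stand-ins for the real experiment setup.
train_loader = DataLoader(
    TensorDataset(torch.randn(512, 28 * 28), torch.randint(0, 10, (512,))),
    batch_size=128,
)
model = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
criterion = nn.CrossEntropyLoss()

swa_model = AveragedModel(model)        # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.01)
swa_start = 20                          # epoch at which averaging begins

for epoch in range(30):
    for x, y in train_loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # accumulate the weight average
        swa_scheduler.step()

update_bn(train_loader, swa_model)      # refresh BatchNorm statistics of the averaged model
```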
@@ -84,6 +86,7 @@ To date, the following post-processing methods have been implemented:

- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
- Monte Carlo Batch Normalization - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html)
- A wrapper for Laplace approximation using the [Laplace library](https://github.com/aleximmer/Laplace)
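The Laplace entry above delegates to the external [Laplace library](https://github.com/aleximmer/Laplace). Below is a minimal sketch of the underlying library call; the exact signature of TorchUncertainty's `LaplaceApprox` wrapper is not reproduced here, and the tiny model and synthetic data are placeholders.

```python
# Sketch of post-hoc Laplace approximation with the external `laplace` library,
# which the new LaplaceApprox wrapper builds upon. Model and data are placeholders.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace  # pip install laplace-torch

model = nn.Sequential(nn.Linear(20, 10))  # stands in for an already-trained classifier
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 20), torch.randint(0, 10, (256,))), batch_size=64
)

la = Laplace(
    model,
    likelihood="classification",
    subset_of_weights="last_layer",
    hessian_structure="kron",
)
la.fit(train_loader)           # fit the Hessian approximation on the training data
la.optimize_prior_precision()  # tune the prior precision (marginal likelihood by default)

x = torch.randn(4, 20)
probs = la(x)                  # predictive probabilities, marginalized over the posterior
```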

## Tutorials

27 changes: 19 additions & 8 deletions auto_tutorials_source/tutorial_bayesian.py
@@ -55,12 +55,12 @@
# We will use the Adam optimizer with the default learning rate of 0.001.


def optim_lenet(model: nn.Module) -> dict:
def optim_lenet(model: nn.Module):
optimizer = optim.Adam(
model.parameters(),
lr=1e-3,
)
return {"optimizer": optimizer}
return optimizer


# %%
@@ -75,7 +75,7 @@ def optim_lenet(model: nn.Module) -> dict:
trainer = Trainer(accelerator="cpu", enable_progress_bar=False, max_epochs=1)

# datamodule
root = Path("") / "data"
root = Path("data")
datamodule = MNISTDataModule(root=root, batch_size=128, eval_ood=False)

# model
@@ -105,6 +105,7 @@ def optim_lenet(model: nn.Module) -> dict:
num_classes=datamodule.num_classes,
loss=loss,
optim_recipe=optim_lenet(model),
is_ensemble=True
)

# %%
@@ -125,8 +126,10 @@ def optim_lenet(model: nn.Module) -> dict:
# 6. Testing the Model
# ~~~~~~~~~~~~~~~~~~~~
#
# Now that the model is trained, let's test it on MNIST

# Now that the model is trained, let's test it on MNIST.
# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble
# and to the batch. As of TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension
# in this order (num_estimator x batch, classes).
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -148,14 +151,22 @@ def imshow(img):
imshow(torchvision.utils.make_grid(images[:4, ...]))
print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

logits = model(images)
# Put the model in eval mode to use several samples
model = model.eval()
logits = model(images).reshape(16, 128, 10) # num_estimators, batch_size, num_classes

# We apply the softmax on the classes and average over the estimators
probs = torch.nn.functional.softmax(logits, dim=-1)
avg_probs = probs.mean(dim=0)
var_probs = probs.std(dim=0)

_, predicted = torch.max(probs, 1)
_, predicted = torch.max(avg_probs, 1)

print("Predicted digits: ", " ".join(f"{predicted[j]}" for j in range(4)))

print("Std. dev. of the scores over the posterior samples", " ".join(f"{var_probs[j][predicted[j]]:.3}" for j in range(4)))
# %%
# The scores should be quite certain.
#
# References
# ----------
#
5 changes: 4 additions & 1 deletion auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -102,6 +102,9 @@
# .eval() to enable Monte Carlo batch normalization at inference.
# In this tutorial, we plot the most uncertain images, i.e. the images for which
# the variance of the predictions is the highest.
# Please note that we apply a reshape to the logits to determine the dimension corresponding to the ensemble
# and to the batch. As of TorchUncertainty 2.0, the ensemble dimension is merged with the batch dimension
# in this order (num_estimator x batch, classes).

import matplotlib.pyplot as plt
import numpy as np
@@ -121,7 +124,7 @@ def imshow(img):
images, labels = next(dataiter)

routine.eval()
logits = routine(images).reshape(8, 128, 10)
logits = routine(images).reshape(8, 128, 10) # num_estimators, batch_size, num_classes

probs = torch.nn.functional.softmax(logits, dim=-1)
most_uncertain = sorted(probs.var(0).sum(-1).topk(4).indices)
13 changes: 7 additions & 6 deletions auto_tutorials_source/tutorial_mc_dropout.py
@@ -51,7 +51,8 @@
# dataloaders and transforms. We create the model using the
# blueprint from torch_uncertainty.models and we wrap it into mc_dropout.
#
# It is important to specify the arguments,``num_estimators`` and the ``dropout_rate``
# It is important to specify the arguments, ``num_estimators``
# and the ``dropout_rate``
# to use Monte Carlo dropout.

trainer = Trainer(accelerator="cpu", max_epochs=2, enable_progress_bar=False)
@@ -64,7 +65,7 @@
model = lenet(
in_channels=datamodule.num_channels,
num_classes=datamodule.num_classes,
dropout_rate=0.4,
dropout_rate=0.5,
)

mc_model = mc_dropout(model, num_estimators=16, last_layer=False)
@@ -118,22 +119,22 @@ def imshow(img):
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images[:4, ...]))
print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))
imshow(torchvision.utils.make_grid(images[:6, ...], padding=0))
print("Ground truth labels: ", " ".join(f"{labels[j]}" for j in range(6)))

routine.eval()
logits = routine(images).reshape(16, 128, 10)

probs = torch.nn.functional.softmax(logits, dim=-1)


for j in range(4):
for j in range(6):
values, predicted = torch.max(probs[:, j], 1)
print(
f"Predicted digits for the image {j+1}: ",
" ".join([str(image_id.item()) for image_id in predicted]),
)

# %%
# We see that there is some disagreement between the samples of the dropout
# Most of the time, we see that there is some disagreement between the samples of the dropout
# approximation of the posterior distribution.
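# %%
# As a rough add-on sketch (not part of the original tutorial), one way to quantify
# this disagreement is the fraction of sampled predictions that match the majority vote:
votes = probs.argmax(-1)                               # (num_estimators, batch)
agreement = (votes == votes.mode(0).values).float().mean(0)
print("Agreement rate: ", " ".join(f"{agreement[j]:.2f}" for j in range(6)))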
29 changes: 18 additions & 11 deletions docs/source/api.rst
@@ -153,23 +153,21 @@ Models

.. currentmodule:: torch_uncertainty.models

Deep Ensembles
^^^^^^^^^^^^^^
Wrappers
^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

deep_ensembles

Monte Carlo Dropout

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class.rst

CheckpointEnsemble
EMA
StochasticModel
SWA
SWAG
MCDropout
mc_dropout

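The wrappers listed above share a common usage pattern: wrap a base model, put it in eval mode, and reshape the flattened ensemble output of shape (num_estimators × batch, classes). A minimal sketch with `mc_dropout`, whose arguments appear in the tutorials of this PR; the import paths follow those tutorials, and the other wrappers' constructor arguments are not shown here.

```python
# Minimal sketch of the shared wrapper usage pattern, using mc_dropout as the example.
# Import paths follow the tutorials in this PR; the other wrappers (EMA, SWA, SWAG,
# CheckpointEnsemble) wrap a model similarly but take different constructor arguments.
import torch

from torch_uncertainty.models import mc_dropout
from torch_uncertainty.models.lenet import lenet

model = lenet(in_channels=1, num_classes=10, dropout_rate=0.5)
wrapped = mc_dropout(model, num_estimators=16, last_layer=False)

wrapped.eval()                                    # sampling stays active at evaluation
x = torch.rand(128, 1, 28, 28)                    # dummy MNIST-sized batch
logits = wrapped(x)                               # flattened: (num_estimators * batch, classes)
probs = logits.reshape(16, 128, 10).softmax(-1)   # back to (num_estimators, batch, classes)
mean_probs = probs.mean(dim=0)                    # average over ensemble members
```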
Metrics
@@ -242,6 +240,16 @@ Post-Processing Methods

.. currentmodule:: torch_uncertainty.post_processing

.. autosummary::
:toctree: generated/
:nosignatures:
:template: class_inherited.rst
MCBatchNorm
LaplaceApprox

Scaling Methods
^^^^^^^^^^^^^^^

.. autosummary::
:toctree: generated/
:nosignatures:
@@ -250,7 +258,6 @@
TemperatureScaler
VectorScaler
MatrixScaler
MCBatchNorm

Datamodules
-----------
10 changes: 1 addition & 9 deletions docs/source/index.rst
@@ -32,7 +32,7 @@ To install TorchUncertainty with contribution in mind, check the
-----

Official Implementations
^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^

TorchUncertainty also houses multiple official implementations of papers from major conferences & journals.

@@ -56,14 +56,6 @@ TorchUncertainty also houses multiple official implementations of papers from ma
* Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, David Filliat*
* Paper: `BMVC 2022 <https://arxiv.org/abs/2203.01437>`_.

Packed-Ensembles
^^^^^^^^^^^^^^^^

**Packed-Ensembles for Efficient Uncertainty Estimation**

* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi*
* Paper: `here <https://arxiv.org/abs/2210.09184>`_.

.. toctree::
:maxdepth: 2
:caption: Contents:
46 changes: 43 additions & 3 deletions docs/source/references.rst
@@ -41,10 +41,10 @@ For Deep Evidential Regression, consider citing:
* Paper: `NeurIPS 2020 <https://arxiv.org/pdf/1910.02600>`__.


Bayesian Neural Networks
^^^^^^^^^^^^^^^^^^^^^^^^
Variational Inference Bayesian Neural Networks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For Bayesian Neural Networks, consider citing:
For Variational Inference Bayesian Neural Networks, consider citing:

**Weight Uncertainty in Neural Networks**

@@ -73,6 +73,36 @@ For Monte-Carlo Dropout, consider citing:
* Authors: *Yarin Gal and Zoubin Ghahramani*
* Paper: `ICML 2016 <https://arxiv.org/pdf/1506.02142.pdf>`__.

Stochastic Weight Averaging
^^^^^^^^^^^^^^^^^^^^^^^^^^^

For Stochastic Weight Averaging, consider citing:

**Averaging Weights Leads to Wider Optima and Better Generalization**

* Authors: *Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson*
* Paper: `UAI 2018 <https://arxiv.org/pdf/1803.05407.pdf>`__.

Stochastic Weight Averaging Gaussian
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For Stochastic Weight Averaging Gaussian, consider citing:

**A simple baseline for Bayesian uncertainty in deep learning**

* Authors: *Wesley Maddox, Timur Garipov, Pavel Izmailov, Dmitry Vetrov, Andrew Gordon Wilson*
* Paper: `NeurIPS 2019 <https://arxiv.org/pdf/1902.02476.pdf>`__.


CheckpointEnsemble
^^^^^^^^^^^^^^^^^^

For CheckpointEnsemble, consider citing:

**Checkpoint Ensembles: Ensemble Methods from a Single Training Process**

* Authors: *Hugh Chen, Scott Lundberg, Su-In Lee*
* Paper: `ArXiv <https://arxiv.org/pdf/1710.03282>`__.

BatchEnsemble
^^^^^^^^^^^^^
@@ -193,6 +223,16 @@ For Monte-Carlo Batch Normalization, consider citing:
* Authors: *Mathias Teye, Hossein Azizpour, and Kevin Smith*
* Paper: `ICML 2018 <https://arxiv.org/pdf/1802.06455.pdf>`__.

Laplace Approximation
^^^^^^^^^^^^^^^^^^^^^

For Laplace Approximation, consider citing:

**Laplace Redux - Effortless Bayesian Deep Learning**

* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, Philipp Hennig*
* Paper: `NeurIPS 2021 <https://arxiv.org/abs/2106.14806>`__.

Metrics
-------

59 changes: 59 additions & 0 deletions experiments/classification/mnist/configs/lenet.yaml
@@ -0,0 +1,59 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
accelerator: gpu
devices: 1
precision: 16-mixed
max_epochs: 75
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: logs/lenet
name: standard
default_hp_metric: false
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val/cls/Acc
mode: max
save_last: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: lightning.pytorch.callbacks.EarlyStopping
init_args:
monitor: val/cls/Acc
patience: 1000
check_finite: true
model:
model:
class_path: torch_uncertainty.models.lenet._LeNet
init_args:
in_channels: 1
num_classes: 10
linear_layer: torch.nn.Linear
conv2d_layer: torch.nn.Conv2d
activation: torch.nn.ReLU
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
num_classes: 10
loss: CrossEntropyLoss
data:
root: ./data
batch_size: 128
optimizer:
lr: 0.05
momentum: 0.9
weight_decay: 5e-4
nesterov: true
lr_scheduler:
class_path: torch.optim.lr_scheduler.MultiStepLR
init_args:
milestones:
- 25
- 50
gamma: 0.1
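A config like this one is meant to be driven by a Lightning CLI entry point. Below is a hedged sketch of such an entry point, assuming the `ClassificationRoutine` and `MNISTDataModule` classes referenced elsewhere in this PR; the actual experiment script shipped in the repository may differ.

```python
# Hedged sketch of a LightningCLI entry point that could consume the config above.
# Typical invocation:  python lenet_exp.py fit --config lenet.yaml
from lightning.pytorch.cli import LightningCLI

from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.routines import ClassificationRoutine


def cli_main() -> None:
    # The "model:" section of the YAML configures the routine, "data:" the datamodule,
    # and "trainer:" the Lightning Trainer.
    LightningCLI(ClassificationRoutine, MNISTDataModule)


if __name__ == "__main__":
    cli_main()
```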