
✨ Multi-dataset general OOD eval #162


Draft · wants to merge 28 commits into dev from OodEval

Commits (28)
53ea81e
added initial support for multi-dataset ood eval
fira7s Apr 9, 2025
9bae3eb
:sparkles: added scale ood eval method
fira7s Apr 14, 2025
3a069e9
:hammer: fixed ood table output and defined near and far ood
fira7s Apr 15, 2025
bf9c384
:hammer: restructure + added eval ood for imagenet
fira7s Apr 16, 2025
6ac1fff
Restructure + added various ood eval methods
fira7s Apr 23, 2025
1fb58d8
Changed ood splits
fira7s Apr 30, 2025
7f1d876
:hammer: Code restructure and fixes
fira7s May 7, 2025
30e05b4
:hammer: fixed missing dataset
fira7s May 7, 2025
ad05e35
:hammer: minor fixes
fira7s May 7, 2025
be802f0
:hammer: minor fixes
fira7s May 7, 2025
d994edc
:bug: Fix ImageNet Variation download
o-laurent May 14, 2025
bd80851
:hammer: fixed tests and ood eval
fira7s May 14, 2025
f9ab886
Merge remote-tracking branch 'origin/OodEval' into OodEval
fira7s May 14, 2025
85a89e9
Merge branch 'dev' into OodEval
fira7s May 14, 2025
f695223
:hammer: minor fixes
fira7s May 15, 2025
36211da
:hammer: fixed tmpscale/confromal for ensemble methods
fira7s May 16, 2025
622bb11
:hammer: fixed issues with postprocess+ensembling
fira7s May 19, 2025
8faf5a0
minor fixes
fira7s May 19, 2025
d4c0a00
:hammer: fixing tests
fira7s May 19, 2025
4607b0f
:hammer: minor fixes
fira7s May 23, 2025
8300f3b
:hammer: minor fixes
fira7s May 23, 2025
e26a274
Merge
fira7s May 23, 2025
c60b3c2
minor fixes
fira7s May 23, 2025
51be518
:hammer: ood eval fixes
fira7s Jun 2, 2025
bec0916
:hammer: fixed dependecies
fira7s Jun 4, 2025
2f219f5
:hammer: fixed ood splits
fira7s Jun 11, 2025
8c5d53d
:hammer: minor fixes
fira7s Jun 11, 2025
59a0240
:hammer: fix tests
fira7s Jun 18, 2025
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -24,6 +24,7 @@ repos:
       name: check-added-large-files
       entry: check-added-large-files
       language: system
+      args: ['--maxkb=2048']
   - id: ruff-check
       name: ruff-check
       entry: ruff check
224 changes: 224 additions & 0 deletions auto_tutorial_source/Bayesian_Methods/tuto_ood.py
@@ -0,0 +1,224 @@
"""
Simple OOD Evaluation
================================================


In this tutorial, we’ll demonstrate how to perform out-of-distribution (OOD) evaluation using TorchUncertainty’s datamodules and routines. You’ll learn to:

1. **Set up a CIFAR-100 datamodule** that automatically handles in-distribution, near-OOD, and far-OOD splits.
2. **Run the `ClassificationRoutine`** to compute both in-distribution accuracy and OOD metrics (AUROC, AUPR, FPR95).
3. **Plug in your own OOD datasets** for fully custom evaluation.

Foreword on Out-of-Distribution Detection
-----------------------------------------

Out-of-Distribution (OOD) detection measures a model’s ability to recognize inputs that differ from its training distribution. TorchUncertainty integrates common OOD metrics directly into the Lightning test loop, including:

- **AUROC** (Area Under the ROC Curve)
- **AUPR** (Area Under the Precision-Recall Curve)
- **FPR95** (False Positive Rate at 95% True Positive Rate)

With just a few lines of code, you can compare in-distribution performance to OOD detection performance under both “near” and “far” shifts. By default, TorchUncertainty uses the
popular OpenOOD library to define the near- and far-OOD datasets and splits. You can also use your own datasets by passing them to the datamodule.
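
As a point of reference, here is a minimal sketch of how these three metrics can be computed from raw detector scores with scikit-learn (added as a dependency by this PR). The routine computes them internally; this sketch only illustrates one common convention, where OOD is the positive class and higher scores mean “more OOD”:

.. code-block:: python

    import numpy as np
    from sklearn.metrics import average_precision_score, roc_auc_score, roc_curve

    def ood_metrics(id_scores: np.ndarray, ood_scores: np.ndarray) -> dict:
        labels = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])
        scores = np.concatenate([id_scores, ood_scores])
        fpr, tpr, _ = roc_curve(labels, scores)
        # FPR95: false positive rate at the first threshold reaching 95% TPR
        fpr95 = fpr[np.argmax(tpr >= 0.95)]
        return {
            "AUROC": roc_auc_score(labels, scores),
            "AUPR": average_precision_score(labels, scores),
            "FPR95": float(fpr95),
        }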

Supported Datamodules and Default OOD Splits
--------------------------------------------

.. list-table:: Datamodules & Default OOD Splits
:header-rows: 1
:widths: 20 15 20 20

* - **Datamodule**
- **In-Domain**
- **Default Near-OOD (Hard)**
- **Default Far-OOD (Easy)**
* - ``CIFAR10DataModule``
- CIFAR-10
- CIFAR-100, Tiny ImageNet
- MNIST, SVHN, Textures, Places365
* - ``CIFAR100DataModule``
- CIFAR-100
- CIFAR-10, Tiny ImageNet
- MNIST, SVHN, Textures, Places365
* - ``ImageNetDataModule``
- ImageNet-1K
- SSB-hard, NINCO
- iNaturalist, Textures, OpenImage-O
* - ``ImageNet200DataModule``
- ImageNet-200
- SSB-hard, NINCO
- iNaturalist, Textures, OpenImage-O

Supported OOD Criteria
----------------------

.. list-table:: Supported OOD Criteria
:header-rows: 1
:widths: 15 50

* - **Criterion**
- **Original Reference (Year, Venue)**
* - ``msp``
- Hendrycks & Gimpel, A Baseline for Detecting Misclassified and Out-of-Distribution Examples in Neural Networks `ICLR 2017 <https://arxiv.org/abs/1610.02136>`_.
* - ``maxlogit``
- /
* - ``energy``
- Liu et al., Energy-based Out-of-Distribution Detection `NeurIPS 2020 <https://arxiv.org/abs/2010.03759>`_.
* - ``odin``
- Liang, Li & Srikant, Enhancing the Reliability of Out-of-Distribution Image Detection in Neural Networks `ICLR 2018 <https://arxiv.org/abs/1706.02690>`_.
* - ``entropy``
- /
* - ``mutual_information``
- /
* - ``variation_ratio``
- /
* - ``scale``
- Hendrycks et al., Scaling Out-of-Distribution Detection for Real-World Settings `ICML 2022 <https://proceedings.mlr.press/v162/hendrycks22a/hendrycks22a.pdf>`_.
* - ``ash``
- Djurisic et al., Extremely Simple Activation Shaping for Out-of-Distribution Detection (ASH) `ICLR 2023 <https://arxiv.org/pdf/2209.09858>`_.
* - ``react``
- Sun et al., ReAct: Out-of-Distribution Detection with Rectified Activations `NeurIPS 2021 <https://proceedings.neurips.cc/paper/2021/file/01894d6f048493d2cacde3c579c315a3-Paper.pdf>`_.
* - ``adascale_a``
- Regmi et al., AdaSCALE: Adaptive Scaling for OOD Detection `arXiv 2025 <https://arxiv.org/pdf/2503.08023>`_.
* - ``vim``
- Wang et al., ViM: Out-of-Distribution with Virtual-Logit Matching `CVPR 2022 <https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_ViM_Out-of-Distribution_With_Virtual-Logit_Matching_CVPR_2022_paper.pdf>`_.
* - ``knn``
- Sun et al., Out-of-Distribution Detection with Deep Nearest Neighbors `ICML 2022 <https://arxiv.org/abs/2106.01477>`_.
* - ``gen``
- Liu et al., GEN: Pushing the Limits of Softmax-Based Out-of-Distribution Detection `CVPR 2023 <https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_GEN_Pushing_the_Limits_of_Softmax-Based_Out-of-Distribution_Detection_CVPR_2023_paper.pdf>`_.
* - ``nnguide``
- Park et al., Nearest Neighbor Guidance for Out-of-Distribution Detection `ICCV 2023 <https://openaccess.thecvf.com/content/ICCV2023/papers/Park_Nearest_Neighbor_Guidance_for_Out-of-Distribution_Detection_ICCV_2023_paper.pdf>`_.

.. note::

- All of these criteria can be passed as the ``ood_criterion`` argument to
``ClassificationRoutine``.
- Ensemble criteria such as ``mutual_information`` and ``variation_ratio`` require multiple stochastic forward passes and are only meaningful for ensemble-like models.



.. note::

- **Near-OOD** splits are semantically similar to the in-domain data.
- **Far-OOD** splits come from more distant distributions (e.g., ImageNet variants).
- Override defaults by passing your own ``near_ood_datasets`` / ``far_ood_datasets``.
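
To make the criteria in the table above concrete, here is a hedged sketch of how a few of the single-forward-pass scores can be derived from a batch of logits. The actual implementations live in TorchUncertainty's OOD criteria module and may differ in details such as temperature or sign conventions:

.. code-block:: python

    import torch

    def illustrative_scores(logits: torch.Tensor) -> dict[str, torch.Tensor]:
        probs = logits.softmax(dim=-1)
        return {
            "msp": probs.max(dim=-1).values,            # max softmax probability
            "maxlogit": logits.max(dim=-1).values,      # max raw logit
            "energy": torch.logsumexp(logits, dim=-1),  # negative energy score
            "entropy": -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1),
        }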


1. Loading the utilities
~~~~~~~~~~~~~~~~~~~~~~~~

To evaluate OOD detection with TorchUncertainty, we have to load the following:

- the model: a CIFAR-style ResNet-18 (32x32 inputs) trained on the in-domain CIFAR-100 data,
- the classification routine: ``ClassificationRoutine`` from ``torch_uncertainty.routines``,
- the datamodule that handles the dataloaders: ``CIFAR100DataModule`` from ``torch_uncertainty.datamodules``.
"""

# %%
from pathlib import Path

# %%
# 2. Load the trained model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# In this tutorial, we load a pretrained model, but you can also train your own model with the
# same classification routine and still get OOD-related metrics during the test phase.


import torch
from torch_uncertainty.models.resnet import resnet
from huggingface_hub import hf_hub_download

net = resnet(in_channels=3, arch=18, num_classes=100, style="cifar", conv_bias=False)

# load the model
path = hf_hub_download(repo_id="torch-uncertainty/resnet18_c100", filename="resnet18_c100.ckpt")
net.load_state_dict(torch.load(path))

net.cuda()
net.eval()
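
# If you are on a CPU-only machine, a hedged alternative is to load the weights with
# ``map_location`` and skip the ``.cuda()`` call:
#     net.load_state_dict(torch.load(path, map_location="cpu"))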


# %%
# 3. Defining the necessary datamodules
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In the following, we instantiate the trainer and define the root directory of the datasets.
# We also create the datamodule that handles the CIFAR-100 dataset, its dataloaders, and transforms.
# Datamodules handle OOD detection when the ``eval_ood`` parameter is set to ``True``, and
# ``eval_shift=True`` additionally evaluates the model under distribution shift.

import torch.nn as nn
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import CIFAR100DataModule
from torch_uncertainty.routines import ClassificationRoutine


root = Path("data1")
datamodule = CIFAR100DataModule(root=root, batch_size=200, eval_ood=True, eval_shift=True)
trainer = TUTrainer(accelerator="gpu", enable_progress_bar=True)


# %%
# 4. Define the classification routine and launch the test
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Define the classification routine for evaluation. We use the CrossEntropyLoss
# as the loss function since we are working on a classification task.
# The routine is configured to handle OOD detection and distributional shifts using the specified model, loss function, and evaluation criteria.

routine = ClassificationRoutine(
num_classes=datamodule.num_classes,
eval_ood=True,
model=net,
loss=nn.CrossEntropyLoss(),
eval_shift=True,
ood_criterion="ash",
)

# Perform testing using the defined routine and datamodule.
results = trainer.test(model=routine, datamodule=datamodule)


# %%
# Here, we show the standard test metrics along with the OOD evaluation metrics (AUROC, AUPR,
# and FPR95) on the near- and far-OOD datasets defined by default according to the OpenOOD
# splits (see the OpenOOD reference at the end of this tutorial).
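

# %%
# ``trainer.test`` returns a list with one dictionary of metrics per test dataloader.
# As a minimal sketch for inspecting the results (the exact metric key names depend on
# the routine's logging and are not guaranteed by this tutorial):

for metrics in results:
    for name, value in sorted(metrics.items()):
        print(f"{name}: {value:.4f}")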


# %%
# 5. Defining custom OOD datasets
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# If you don't want to use the OpenOOD datasets or dataset splits, you can pass your own datasets
# as a list to the ``near_ood_datasets`` or ``far_ood_datasets`` datamodule arguments and use them
# for OOD evaluation. Make sure they inherit from the ``Dataset`` class of ``torch.utils.data``;
# below is an example of such a case, followed by a sketch of a minimal custom dataset class.

from torchvision.datasets import CIFAR10, MNIST
from torchvision.transforms import v2


test_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize(32),
        v2.CenterCrop(32),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),  # CIFAR-100 statistics
    ]
)

# MNIST images are grayscale, so we replicate the channel to match the
# 3-channel normalization statistics
mnist_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Grayscale(num_output_channels=3),
        v2.Resize(32),
        v2.CenterCrop(32),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    ]
)

custom_dataset1 = CIFAR10(root=root, train=False, download=True, transform=test_transform)
custom_dataset2 = MNIST(root=root, train=False, download=True, transform=mnist_transform)

datamodule = CIFAR100DataModule(
root=root,
batch_size=200,
eval_ood=True,
eval_shift=True,
near_ood_datasets=[custom_dataset1],
far_ood_datasets=[custom_dataset2],
)

# Perform testing using the custom OOD datasets.
results = trainer.test(model=routine, datamodule=datamodule)
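

# %%
# Any object inheriting from ``torch.utils.data.Dataset`` can serve as a custom OOD dataset.
# As an illustrative sketch (this class is not part of TorchUncertainty), a minimal dataset
# wrapping a pre-loaded tensor of images could look like this:

from torch.utils.data import Dataset


class TensorImageDataset(Dataset):
    """Wrap a tensor of images; labels are placeholders since OOD evaluation scores only the inputs."""

    def __init__(self, images: torch.Tensor) -> None:
        self.images = images

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, int]:
        return self.images[index], 0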


# %%
# References
# ----------
# - **OpenOOD:** Jingyang Zhang et al., OpenOOD v1.5: Enhanced Benchmark for Out-of-Distribution Detection (`arXiv 2023 <https://arxiv.org/pdf/2306.09301>`_).
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -37,6 +37,10 @@ dependencies = [
     "torchvision>=0.16",
     "einops",
     "seaborn",
+    "scikit-learn",
+    "scipy",
+    "faiss-cpu",
+    "statsmodels"
 ]

 [project.optional-dependencies]
@@ -62,9 +66,7 @@ docs = [
 ]
 all = [
     "torch_uncertainty[dev,docs,tabular]",
-    "scikit-learn",
     "laplace-torch",
-    "scipy",
     "glest==0.0.1a1",
 ]
2 changes: 1 addition & 1 deletion tests/_dummies/baseline.py
@@ -3,7 +3,7 @@
 from torch import nn

 from torch_uncertainty.models import EMA, SWA, deep_ensembles
-from torch_uncertainty.ood_criteria import TUOODCriterion
+from torch_uncertainty.ood.ood_criteria import TUOODCriterion
 from torch_uncertainty.optim_recipes import optim_cifar10_resnet18
 from torch_uncertainty.post_processing import TemperatureScaler
 from torch_uncertainty.routines import (