
🔨 SegmentationRoutine supports OOD detection #177


Merged · 11 commits · May 18, 2025
@@ -56,6 +56,7 @@ data:
root: ./data
batch_size: 16
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
@@ -28,7 +28,8 @@ model:
bilinear: true
num_estimators: 4
task: segmentation
ckpt_paths: ./logs/muad/unet/deep_ensembles/version_0/checkpoints
ckpt_paths: ./logs/muad/unet/deep_ensembles/version_2/checkpoints
use_tu_ckpt_format: true
num_classes: 15
loss:
class_path: torch.nn.CrossEntropyLoss
@@ -60,6 +61,7 @@ data:
root: ./data
batch_size: 32
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
@@ -4,6 +4,7 @@ trainer:
accelerator: gpu
precision: bf16-mixed
max_epochs: 100
accumulate_grad_batches: 4
logger:
class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
@@ -53,8 +54,9 @@ model:
num_repeats: 4
data:
root: ./data
batch_size: 32
batch_size: 8
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
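Note: with the batch size lowered from 32 to 8 and accumulate_grad_batches: 4, the effective batch size stays at 32 (8 × 4 accumulation steps); the change presumably trades a little wall-clock time for lower peak GPU memory.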
76 changes: 76 additions & 0 deletions experiments/segmentation/muad/configs/muad/unet/mc_dropout.yaml
@@ -0,0 +1,76 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
  accelerator: gpu
  precision: bf16-mixed
  max_epochs: 100
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs/muad/unet
      name: mc_dropout
      default_hp_metric: false
  callbacks:
    - class_path: torch_uncertainty.callbacks.TUSegCheckpoint
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
model:
  model:
    class_path: torch_uncertainty.models.mc_dropout
    init_args:
      model:
        class_path: torch_uncertainty.models.segmentation.unet
        init_args:
          in_channels: 3
          num_classes: 15
          bilinear: true
          dropout_rate: 0.1
      num_estimators: 10
      on_batch: false
  num_classes: 15
  loss:
    class_path: torch.nn.CrossEntropyLoss
    init_args:
      weight:
        class_path: torch.Tensor
        dict_kwargs:
          data:
            - 4.1712
            - 19.4603
            - 3.2345
            - 49.2588
            - 36.2490
            - 34.0272
            - 47.0651
            - 49.7145
            - 12.4178
            - 48.3962
            - 14.3876
            - 32.8862
            - 5.2729
            - 17.8703
            - 50.4984
data:
  root: ./data
  batch_size: 32
  crop_size: 256
  eval_ood: true
  eval_size:
    - 512
    - 1024
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
    weight_decay: 1e-4
lr_scheduler:
  class_path: torch.optim.lr_scheduler.MultiStepLR
  init_args:
    milestones:
      - 20
      - 40
      - 60
      - 80
    gamma: 0.5
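For reference, a minimal Python sketch of what this config assembles, assuming the mc_dropout class_path above resolves to a callable wrapper taking the inner model as its first argument (as the init_args suggest):

    from torch_uncertainty.models import mc_dropout
    from torch_uncertainty.models.segmentation import unet

    # Inner UNet with dropout at rate 0.1, matching the init_args above.
    backbone = unet(in_channels=3, num_classes=15, bilinear=True, dropout_rate=0.1)

    # Keep dropout active at inference and draw 10 stochastic forward passes
    # per input (with on_batch: false, estimators run sequentially rather
    # than stacked on the batch dimension).
    model = mc_dropout(backbone, num_estimators=10, on_batch=False)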
1 change: 1 addition & 0 deletions experiments/segmentation/muad/configs/muad/unet/mimo.yaml
@@ -57,6 +57,7 @@ data:
root: ./data
batch_size: 32
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
@@ -57,6 +57,7 @@ data:
root: ./data
batch_size: 32
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
2 changes: 2 additions & 0 deletions experiments/segmentation/muad/configs/muad/unet/standard.yaml
@@ -23,6 +23,7 @@ model:
in_channels: 3
num_classes: 15
bilinear: true
dropout_rate: 0.5
num_classes: 15
loss:
class_path: torch.nn.CrossEntropyLoss
@@ -50,6 +51,7 @@ data:
root: ./data
batch_size: 32
crop_size: 256
eval_ood: true
eval_size:
- 512
- 1024
@@ -0,0 +1,72 @@
# lightning.pytorch==2.1.3
seed_everything: false
eval_after_fit: true
trainer:
  accelerator: gpu
  devices: 1
  max_epochs: 50
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: logs/muad-small/unet
      name: standard
      default_hp_metric: false
  callbacks:
    - class_path: torch_uncertainty.callbacks.TUSegCheckpoint
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
model:
  model:
    class_path: torch_uncertainty.models.mc_dropout
    init_args:
      model:
        class_path: torch_uncertainty.models.segmentation.small_unet
        init_args:
          in_channels: 3
          num_classes: 15
          bilinear: true
          dropout_rate: 0.5
      num_estimators: 4
  num_classes: 15
  loss:
    class_path: torch.nn.CrossEntropyLoss
    init_args:
      weight:
        class_path: torch.Tensor
        dict_kwargs:
          data:
            - 4.3817
            - 19.7927
            - 3.3011
            - 48.8031
            - 36.2141
            - 33.0049
            - 47.5130
            - 48.8560
            - 12.4401
            - 48.0600
            - 14.4807
            - 30.8762
            - 4.7467
            - 19.3913
            - 50.4984
data:
  root: ./data
  batch_size: 10
  version: small
  eval_ood: true
  eval_size:
    - 256
    - 512
  num_workers: 10
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1e-3
    weight_decay: 2e-4
lr_scheduler:
  class_path: torch.optim.lr_scheduler.StepLR
  init_args:
    step_size: 20
    gamma: 0.1
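To launch this experiment end to end, a command along these lines should work, assuming the harness follows TorchUncertainty's usual LightningCLI pattern (the script and config paths below are illustrative, not taken from this diff):

    python experiments/segmentation/muad/unet.py fit --config path/to/this/config.yaml

With eval_after_fit: true, evaluation (including the OOD pass when eval_ood is set) runs automatically once training finishes.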
@@ -50,6 +50,7 @@ data:
root: ./data
batch_size: 10
version: small
eval_ood: true
eval_size:
- 256
- 512
42 changes: 41 additions & 1 deletion torch_uncertainty/datamodules/segmentation/muad.py
@@ -27,6 +27,7 @@
batch_size: int,
version: Literal["full", "small"] = "full",
eval_batch_size: int | None = None,
eval_ood: bool = False,
crop_size: _size_2_t = 1024,
eval_size: _size_2_t = (1024, 2048),
train_transform: nn.Module | None = None,
@@ -45,6 +46,9 @@
``full`` or ``small``. Defaults to ``full``.
eval_batch_size (int | None): Number of samples per batch during evaluation (val
and test). Set to batch_size if None. Defaults to None.
eval_ood (bool): Whether to evaluate on the OOD dataset. Defaults to
``False``. If set to ``True``, the OOD dataset is used for
evaluation in addition to the test dataset.
crop_size (sequence or int, optional): Desired input image and
segmentation mask sizes during training. If :attr:`crop_size` is an
int instead of sequence like :math:`(H, W)`, a square crop
@@ -137,9 +141,14 @@

self.dataset = MUAD
self.version = version
self.eval_ood = eval_ood
self.crop_size = _pair(crop_size)
self.eval_size = _pair(eval_size)

# FIXME: should be the same split names (update huggingface dataset)
self.test_split = "test" if version == "small" else "test_id"
self.ood_split = "ood" if version == "small" else "test_ood"

if train_transform is not None:
self.train_transform = train_transform
else:
@@ -212,6 +221,22 @@
self.dataset(
root=self.root, split="val", version=self.version, target_type="semantic", download=True
)
self.dataset(
root=self.root,
split=self.test_split,
version=self.version,
target_type="semantic",
download=True,
)

if self.eval_ood:
self.dataset(
root=self.root,
split=self.ood_split,
version=self.version,
target_type="semantic",
download=True,
)

def setup(self, stage: str | None = None) -> None:
if stage == "fit" or stage is None:
@@ -242,11 +267,26 @@
if stage == "test" or stage is None:
self.test = self.dataset(
root=self.root,
split="val",
split=self.test_split,
version=self.version,
target_type="semantic",
transforms=self.test_transform,
)
if self.eval_ood:
self.ood = self.dataset(
root=self.root,
split=self.ood_split,
version=self.version,
target_type="semantic",
transforms=self.test_transform,
)

if stage not in ["fit", "test", None]:
raise ValueError(f"Stage {stage} is not supported.")

def test_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns the test dataloader."""
dataloader = [self._data_loader(self.get_test_set(), training=False, shuffle=False)]
if self.eval_ood:
dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False))
return dataloader
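As a usage sketch, assuming the datamodule shown here is exported as MUADDataModule under torch_uncertainty.datamodules.segmentation:

    from torch_uncertainty.datamodules.segmentation import MUADDataModule

    dm = MUADDataModule(
        root="./data",
        batch_size=16,
        crop_size=256,
        eval_size=(512, 1024),
        eval_ood=True,
    )
    dm.prepare_data()   # downloads the evaluation splits, incl. the OOD split when eval_ood=True
    dm.setup("test")
    loaders = dm.test_dataloader()  # [in-distribution test loader, OOD loader]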
3 changes: 1 addition & 2 deletions torch_uncertainty/datasets/muad.py
@@ -61,7 +61,6 @@ class MUAD(VisionDataset):
"val": 492,
"test_id": 551,
"test_ood": 1668,
"test_id_no_shadow": 102,
"test_id_low_adv": 605,
"test_id_high_adv": 602,
"test_ood_low_adv": 1552,
@@ -168,7 +167,7 @@ def __init__(

if split not in self.huggingface_splits[version]:
raise ValueError(
f"split must be one of {self.huggingface_splits[version].keys()}. Got {split}."
f"split must be one of {self.huggingface_splits[version]}. Got {split}."
)
self.split = split
self.version = version
5 changes: 5 additions & 0 deletions torch_uncertainty/metrics/__init__.py
@@ -29,4 +29,9 @@
SILog,
ThresholdAccuracy,
)
from .segmentation import (
SegmentationBinaryAUROC,
SegmentationBinaryAveragePrecision,
SegmentationFPR95,
)
from .sparsification import AUSE
4 changes: 4 additions & 0 deletions torch_uncertainty/metrics/segmentation/__init__.py
@@ -0,0 +1,4 @@
# ruff: noqa: F401
from .seg_binary_auroc import SegmentationBinaryAUROC
from .seg_binary_average_precision import SegmentationBinaryAveragePrecision
from .seg_fpr95 import SegmentationFPR95
42 changes: 42 additions & 0 deletions torch_uncertainty/metrics/segmentation/seg_binary_auroc.py
@@ -0,0 +1,42 @@
from typing import Any

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.classification import BinaryAUROC


class SegmentationBinaryAUROC(Metric):
    is_differentiable = False
    higher_is_better = True
    full_state_update = False

    def __init__(
        self,
        max_fpr: float | None = None,
        thresholds: int | list[float] | Tensor | None = None,
        ignore_index: int | None = None,
        validate_args: bool = True,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.auroc_metric = BinaryAUROC(
            max_fpr=max_fpr,
            thresholds=thresholds,
            ignore_index=ignore_index,
            validate_args=validate_args,
            **kwargs,
        )
        self.add_state("binary_auroc", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        batch_size = preds.size(0)
        auroc = self.auroc_metric(preds, target)
        self.binary_auroc += auroc * batch_size
        self.total += batch_size

    def compute(self) -> Tensor:
        if self.total == 0:
            return torch.tensor(0.0, device=self.binary_auroc.device)
        return self.binary_auroc / self.total
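A quick usage sketch for the new metric (shapes illustrative): preds are per-pixel scores in [0, 1] and targets a binary mask (e.g., marking OOD pixels); compute() returns the batch-size-weighted mean of the per-update BinaryAUROC values.

    import torch
    from torch_uncertainty.metrics import SegmentationBinaryAUROC

    metric = SegmentationBinaryAUROC()
    preds = torch.rand(4, 256, 256)               # per-pixel confidence/OOD scores
    target = torch.randint(0, 2, (4, 256, 256))   # binary ground-truth mask
    metric.update(preds, target)
    print(metric.compute())                       # weighted mean AUROC over updates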