From 7c645bd7674ecd2c7b5d427aa597714219c92f12 Mon Sep 17 00:00:00 2001 From: fira7s Date: Wed, 12 Mar 2025 13:58:25 +0100 Subject: [PATCH 01/47] :hammer: make batch size and other parameters configurable in scaler fit method --- .../post_processing/calibration/scaler.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index a5398df2..96ea93e7 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -48,6 +48,9 @@ def __init__( def fit( self, calibration_set: Dataset, + batch_size: int = 32, + shuffle: bool = False, + drop_last: bool = False, save_logits: bool = False, progress: bool = True, ) -> None: @@ -55,6 +58,9 @@ def fit( Args: calibration_set (Dataset): Calibration dataset. + batch_size (int, optional): Batch size for the calibration dataset. Defaults to 32. + shuffle (bool, optional): Whether to shuffle the calibration dataset. Defaults to False. + drop_last (bool, optional): Whether to drop the last batch if it's smaller than batch_size. Defaults to False. save_logits (bool, optional): Whether to save the logits and labels. Defaults to False. progress (bool, optional): Whether to show a progress bar. @@ -67,7 +73,9 @@ def fit( logits_list = [] labels_list = [] - calibration_dl = DataLoader(calibration_set, batch_size=32, shuffle=False, drop_last=False) + calibration_dl = DataLoader( + calibration_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): logits = self.model(inputs.to(self.device)) From 76b899fc63ccf370eba4c703e4e2d617844313b0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Mar 2025 13:12:11 +0100 Subject: [PATCH 02/47] :hammer: Improve memory efficiency & small fix --- .../post_processing/calibration/scaler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 96ea93e7..ed5d8e36 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -71,18 +71,18 @@ def fit( "Cannot fit a Scaler method without model. Call .set_model(model) first." 
) - logits_list = [] - labels_list = [] + all_logits = [] + all_labels = [] calibration_dl = DataLoader( calibration_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last ) with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): logits = self.model(inputs.to(self.device)) - logits_list.append(logits) - labels_list.append(labels) - all_logits = torch.cat(logits_list).detach().to(self.device) - all_labels = torch.cat(labels_list).detach().to(self.device) + all_logits.append(logits) + all_labels.append(labels) + all_logits = torch.cat(all_logits).to(self.device) + all_labels = torch.cat(all_labels).to(self.device) optimizer = optim.LBFGS(self.temperature, lr=self.lr, max_iter=self.max_iter) @@ -95,8 +95,8 @@ def calib_eval() -> float: optimizer.step(calib_eval) self.trained = True if save_logits: - self.logits = logits - self.labels = labels + self.logits = all_logits + self.labels = all_labels @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: From 28085aeefea502a96324e0b95b4c737246ff180a Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 12 Mar 2025 13:12:29 +0100 Subject: [PATCH 03/47] :hammer: Rename x to inputs --- torch_uncertainty/post_processing/abstract.py | 2 +- torch_uncertainty/post_processing/laplace.py | 4 ++-- torch_uncertainty/post_processing/mc_batch_norm.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index 9c7908cc..5b4ae9a6 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -20,6 +20,6 @@ def fit(self, dataset: Dataset) -> None: @abstractmethod def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: pass diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 582a4bd8..9264da0a 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -86,6 +86,6 @@ def fit(self, dataset: Dataset) -> None: def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: - return self.la(x, pred_type=self.pred_type, link_approx=self.link_approx) + return self.la(inputs, pred_type=self.pred_type, link_approx=self.link_approx) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index 0a1d250d..67c8dabb 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -93,14 +93,14 @@ def _est_forward(self, x: Tensor) -> Tensor: def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: if self.training: - return self.model(x) + return self.model(inputs) if not self.trained: raise RuntimeError("MCBatchNorm has not been trained. 
Call .fit() first.")
         self.reset_counters()
-        return torch.cat([self._est_forward(x) for _ in range(self.num_estimators)], dim=0)
+        return torch.cat([self._est_forward(inputs) for _ in range(self.num_estimators)], dim=0)
 
     def _convert(self) -> None:
         """Convert all BatchNorm2d layers to MCBatchNorm2d layers."""

From 55e677ab82e6a85649100d66ba2b7e6d437725c2 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Wed, 12 Mar 2025 13:15:21 +0100
Subject: [PATCH 04/47] :shirt: Some more small improvements

---
 torch_uncertainty/post_processing/laplace.py       | 2 +-
 torch_uncertainty/post_processing/mc_batch_norm.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py
index 9264da0a..7a918bd5 100644
--- a/torch_uncertainty/post_processing/laplace.py
+++ b/torch_uncertainty/post_processing/laplace.py
@@ -51,7 +51,7 @@ def __init__(
         Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021.
         """
         super().__init__()
-        if not laplace_installed:  # coverage: ignore
+        if not laplace_installed:
             raise ImportError(
                 "The laplace-torch library is not installed. Please install"
                 "torch_uncertainty with the all option:"
diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index 67c8dabb..33dbf35d 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -50,7 +50,7 @@ def __init__(
 
     def _setup_model(self, model):
         _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert)
-        self.model = deepcopy(model)  # Is it necessary?
+        self.model = deepcopy(model)  # TODO: Is it necessary?
         self.model = self.model.eval()
         if self.convert:
             self._convert()
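For readers following the series: the first two patches above make `Scaler.fit` configurable. Below is a minimal usage sketch. It assumes a `TemperatureScaler` subclass of the `Scaler` shown in `calibration/scaler.py`, exported from `torch_uncertainty.post_processing`; the class name and import path are assumptions, while the `set_model` and `fit` calls mirror the signatures visible in the diffs above.

```python
# Hypothetical usage sketch of the configurable fit() from PATCH 01/02.
# TemperatureScaler and its import path are assumed, not taken from this series.
import torch
from torch import nn
from torch.utils.data import TensorDataset

from torch_uncertainty.post_processing import TemperatureScaler

model = nn.Linear(16, 4)  # stand-in classifier producing logits
calibration_set = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))

scaler = TemperatureScaler()
scaler.set_model(model)  # required before fit(), per the RuntimeError message above
scaler.fit(calibration_set, batch_size=64, shuffle=False, drop_last=False)

with torch.no_grad():
    calibrated = scaler(torch.randn(8, 16))  # temperature-scaled logits
```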
""" super().__init__(severity) - self.scale = [500, 250, 100, 75, 50][severity - 1] + self.scale = [60, 25, 12, 5, 3][severity - 1] def forward(self, img: Tensor): if self.severity == 0: @@ -130,21 +125,23 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. """ super().__init__(severity) - if not skimage_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.scale = [0, 0.01, 0.02, 0.03, 0.05, 0.07][severity] + self.aug = RandomSaltAndPepperNoise( + amount=[0.03, 0.06, 0.09, 0.17, 0.27][severity - 1], salt_vs_pepper=0.5, p=1 + ) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img return torch.clamp( - torch.as_tensor(random_noise(img, mode="s&p", amount=self.scale)), - torch.zeros(1), - torch.ones(1), - ) + input=torch.as_tensor(self.aug(img.unsqueeze(0))), + min=torch.zeros(1), + max=torch.ones(1), + ).squeeze(0) class DefocusBlur(TUCorruption): @@ -183,18 +180,18 @@ def forward(self, img: Tensor) -> Tensor: class GlassBlur(TUCorruption): # TODO: batch def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed or not cv2_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.sigma = [0.05, 0.25, 0.4, 0.25, 0.4][severity - 1] - self.max_delta = 1 + self.sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] self.iterations = [1, 1, 1, 2, 2][severity - 1] + self.max_delta = [1, 2, 2, 3, 4][severity - 1] def forward(self, img: Tensor) -> Tensor: img_size = img.shape - img = torch.as_tensor(gaussian(img, sigma=self.sigma)) + img = gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma) for _ in range(self.iterations): for h in range(img_size[0] - self.max_delta, self.max_delta, -1): for w in range(img_size[1] - self.max_delta, self.max_delta, -1): @@ -204,7 +201,7 @@ def forward(self, img: Tensor) -> Tensor: img[h_prime, w_prime], img[h, w], ) - return torch.clamp(torch.as_tensor(gaussian(img, sigma=self.sigma)), 0, 1) + return torch.clamp(gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma), 0, 1) def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): @@ -227,7 +224,7 @@ def __init__(self, severity: int) -> None: Note: Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency with with `Wand` we changed the transform to a simpler motion blur and kept the values of - sigma as the new half kernel sizes. + sigma as the new kernel radius sizes. 
""" super().__init__(severity) self.rng = np.random.default_rng() @@ -255,20 +252,23 @@ def forward(self, img: Tensor) -> Tensor: def clipped_zoom(img, zoom_factor): - h = img.shape[0] + h, w = img.shape[:2] # ceil crop height(= crop width) - ch = int(np.ceil(h / zoom_factor)) + ceil_crop_height = int(np.ceil(h / zoom_factor)) + left_crop_width = int(np.ceil(w / zoom_factor)) - top = (h - ch) // 2 + top = (h - ceil_crop_height) // 2 + left = (w - left_crop_width) // 2 img = scizoom( - img[top : top + ch, top : top + ch], + img[top : top + ceil_crop_height, left : left + left_crop_width], (zoom_factor, zoom_factor, 1), order=1, ) # trim off any extra pixels trim_top = (img.shape[0] - h) // 2 + trim_left = (img.shape[1] - w) // 2 - return img[trim_top : trim_top + h, trim_top : trim_top + h] + return img[trim_top : trim_top + h, trim_left : trim_left + w] class ZoomBlur(TUCorruption): @@ -357,7 +357,7 @@ class Frost(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) self.rng = np.random.default_rng() - self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1] + self.mix = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1] self.frost_ds = FrostImages("./data", download=True, transform=ToTensor()) def forward(self, img: Tensor) -> Tensor: @@ -365,7 +365,7 @@ def forward(self, img: Tensor) -> Tensor: return img _, height, width = img.shape frost_img = RandomResizedCrop((height, width))( - self.frost_ds[self.rng.integers(low=0, high=4)] + self.frost_ds[self.rng.integers(low=0, high=5)] ) return torch.clamp(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) @@ -422,29 +422,23 @@ def filldiamonds(): class Fog(TUCorruption): - def __init__(self, severity: int, size: int = 256) -> None: + def __init__(self, severity: int) -> None: super().__init__(severity) - if (size & (size - 1) == 0) and size != 0: - self.size = size - self.resize = Resize((size, size), InterpolationMode.BICUBIC) - else: - raise ValueError(f"Size must be a power of 2. Got {size}.") self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img _, height, width = img.shape - if height != width: - raise ValueError(f"Image must be square. 
Got {height}x{width}.") - img = self.resize(img) max_val = img.max() + random_height_map_size = int(2 ** (m.ceil(m.log2(max(height, width) - 1)))) fog = ( self.mix[0] - * plasma_fractal(height=height, width=width, wibbledecay=self.mix[1])[:height, :width] + * plasma_fractal( + height=random_height_map_size, width=random_height_map_size, wibbledecay=self.mix[1] + )[:height, :width] ) - final = torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) - return Resize((height, width), InterpolationMode.BICUBIC)(final) + return torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) class Brightness(IBrightness, TUCorruption): @@ -472,24 +466,26 @@ def forward(self, img: Tensor) -> Tensor | Image.Image: class Pixelate(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - self.quality = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1] + self.quality = [0.6, 0.5, 0.4, 0.3, 0.25][severity - 1] + self.to_pil = ToPILImage() + self.to_tensor = ToTensor() def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img _, height, width = img.shape - img = ToPILImage()(img) + img = self.to_pil(img) img = Resize( (int(height * self.quality), int(width * self.quality)), InterpolationMode.BOX, )(img) - return ToTensor()(Resize((height, width), InterpolationMode.BOX)(img)) + return self.to_tensor(Resize((height, width), InterpolationMode.BOX)(img)) class JPEGCompression(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - self.quality = [80, 65, 58, 50, 40][severity - 1] + self.quality = [25, 18, 15, 10, 7][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -551,27 +547,33 @@ def forward(self, img: Tensor) -> Tensor: borderMode=cv2.BORDER_REFLECT_101, ) + sigma = self.mix[1] * shape_size[0] dx = ( - gaussian( - self.rng.uniform(-1, 1, size=shape[:2]), - self.mix[1] * shape_size[0], - mode="reflect", - truncate=3, + ( + gaussian_blur2d( + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])), + kernel_size=sigma * 3, + sigma=sigma, + ) + * self.mix[0] + * shape_size[0] ) - * self.mix[0] - * shape_size[0] - ).astype(np.float32) + .numpy() + .astype(np.float32)[..., np.newaxis] + ) dy = ( - gaussian( - self.rng.uniform(-1, 1, size=shape[:2]), - self.mix[1] * shape_size[0], - mode="reflect", - truncate=3, + ( + gaussian_blur2d( + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])), + kernel_size=sigma * 3, + sigma=self.mix[1] * shape_size[0], + ) + * self.mix[0] + * shape_size[0] ) - * self.mix[0] - * shape_size[0] - ).astype(np.float32) - dx, dy = dx[..., np.newaxis], dy[..., np.newaxis] + .numpy() + .astype(np.float32)[..., np.newaxis] + ) x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) indices = ( @@ -590,7 +592,7 @@ def forward(self, img: Tensor) -> Tensor: class SpeckleNoise(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - self.scale = [0.06, 0.1, 0.12, 0.16, 0.2][severity - 1] + self.scale = [0.15, 0.2, 0.35, 0.45, 0.6][severity - 1] self.rng = np.random.default_rng() def forward(self, img: Tensor) -> Tensor: @@ -606,7 +608,7 @@ def forward(self, img: Tensor) -> Tensor: class GaussianBlur(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -617,7 +619,7 @@ def forward(self, img: Tensor) 
-> Tensor: if self.severity == 0: return img return torch.clamp( - torch.as_tensor(gaussian(img, sigma=self.sigma)), + gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma), min=0, max=1, ) From 3ce9286840cdcbe42aac688d26cd9bcbcf55515c Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 15:50:58 +0100 Subject: [PATCH 06/47] :heavy_minus_sign: Remove scikit-image dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e404e640..77dd1ee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ experiments = [ "safetensors", ] image = [ - "scikit-image", "kornia", "h5py", "opencv-python", From d7bb9882cf4c5af6a232a0afb21c9d0e7df78626 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 15:52:45 +0100 Subject: [PATCH 07/47] :hammer: Start improving corrupted dataset --- torch_uncertainty/datasets/corrupted.py | 31 ++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index dbdff04a..96eb3d6b 100644 --- a/torch_uncertainty/datasets/corrupted.py +++ b/torch_uncertainty/datasets/corrupted.py @@ -1,11 +1,10 @@ from copy import deepcopy from pathlib import Path -from PIL import Image from torch import nn from torchvision.datasets import VisionDataset from torchvision.transforms import ToPILImage, ToTensor -from tqdm.auto import tqdm +from tqdm import tqdm, trange from tqdm.contrib.logging import logging_redirect_tqdm from torch_uncertainty.transforms.corruption import corruption_transforms @@ -18,6 +17,7 @@ def __init__( shift_severity: int, on_the_fly: bool = False, ) -> None: + """Generate the corrupted version of any VisionDataset.""" super().__init__() self.core_dataset = core_dataset if shift_severity <= 0: @@ -32,29 +32,38 @@ def __init__( self.root = Path(core_dataset.root) dataset_name = str(type(core_dataset)).split(".")[-1][:-2].lower() - self.root /= dataset_name + "_corrupted" - self.root /= f"severity_{self.shift_severity}" - self.root.mkdir(parents=True) + self.root /= dataset_name + "-corrupted" + self.root /= f"severity-{self.shift_severity}" + self.root.mkdir(parents=True, exist_ok=True) if not on_the_fly: self.to_tensor = ToTensor() self.to_pil = ToPILImage() self.samples = [] - self.targets = self.core_dataset.targets * 10 + if hasattr(self.core_dataset, "targets"): + self.targets = self.core_dataset.targets + elif hasattr(self.core_dataset, "labels"): + self.targets = self.core_dataset.labels + elif hasattr(self.core_dataset, "_labels"): + self.targets = self.core_dataset._labels + else: + raise ValueError("The dataset should implement either targets, labels, or _labels.") + + self.targets = self.targets * len(corruption_transforms) self.prepare_data() def prepare_data(self): with logging_redirect_tqdm(): - for corruption in tqdm(corruption_transforms): + pbar = tqdm(corruption_transforms) + for corruption in pbar: corruption_name = corruption.__name__.lower() - (self.root / corruption_name).mkdir(parents=True) + pbar.set_description(f"Processing {corruption.__name__}") + (self.root / corruption_name).mkdir(parents=True, exist_ok=True) self.save_corruption(self.root / corruption_name, corruption(self.shift_severity)) def save_corruption(self, root: Path, corruption: nn.Module) -> None: - for i in range(self.core_length): + for i in trange(self.core_length, leave=False): img, tgt = self.core_dataset[i] - if isinstance(img, str | Path): - img = 
Image.open(img).convert("RGB") img = corruption(self.to_tensor(img)) self.to_pil(img).save(root / f"{i}.png") self.samples.append(root / f"{i}.png") From 3bef45b778fd1d3b8af6a44481418c6acb946986 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 16:02:51 +0100 Subject: [PATCH 08/47] :bug: Follow Hendryck more closely for Impulse noise --- torch_uncertainty/transforms/corruption.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 6b0237b3..362b108d 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -118,11 +118,13 @@ def forward(self, img: Tensor): class ImpulseNoise(TUCorruption): - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, black_white: bool = False) -> None: """Add impulse noise to an image. Args: severity (int): Severity level of the corruption. + black_white (bool): If black and white, set all pixel channel values to 0 or 1. + Defaults to ``False`` (as in the original paper). """ super().__init__(severity) if not kornia_installed: @@ -131,17 +133,23 @@ def __init__(self, severity: int) -> None: """pip install -U "torch_uncertainty[image]".""" ) self.aug = RandomSaltAndPepperNoise( - amount=[0.03, 0.06, 0.09, 0.17, 0.27][severity - 1], salt_vs_pepper=0.5, p=1 + amount=[0.03, 0.06, 0.09, 0.17, 0.27][severity - 1], + salt_vs_pepper=0.5, + p=1, + same_on_batch=False, ) + self.black_white = black_white def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clamp( - input=torch.as_tensor(self.aug(img.unsqueeze(0))), + img = img.unsqueeze(0) if self.black_white else img.unsqueeze(1) + img = torch.clamp( + input=torch.as_tensor(self.aug(img)), min=torch.zeros(1), max=torch.ones(1), - ).squeeze(0) + ) + return img.squeeze(0) if self.black_white else img.squeeze(1) class DefocusBlur(TUCorruption): @@ -638,8 +646,8 @@ def forward(self, img: Tensor) -> Tensor: corruption_transforms = ( - GaussianNoise, - ShotNoise, + # GaussianNoise, + # ShotNoise, ImpulseNoise, DefocusBlur, GlassBlur, From a80d06ac6ab90b931e1b685b96b3f987e0b344b8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 16:37:41 +0100 Subject: [PATCH 09/47] :bug: Continue improving corruption and dataset --- torch_uncertainty/datasets/corrupted.py | 28 ++-- torch_uncertainty/transforms/corruption.py | 164 ++++++++++++++++++--- 2 files changed, 155 insertions(+), 37 deletions(-) diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index 96eb3d6b..f774eeb8 100644 --- a/torch_uncertainty/datasets/corrupted.py +++ b/torch_uncertainty/datasets/corrupted.py @@ -30,10 +30,8 @@ def __init__( self.core_dataset.transform = None self.core_dataset.target_transform = None - self.root = Path(core_dataset.root) dataset_name = str(type(core_dataset)).split(".")[-1][:-2].lower() - self.root /= dataset_name + "-corrupted" - self.root /= f"severity-{self.shift_severity}" + self.root = Path(core_dataset.root) / (dataset_name + "-C") self.root.mkdir(parents=True, exist_ok=True) if not on_the_fly: @@ -56,17 +54,22 @@ def prepare_data(self): with logging_redirect_tqdm(): pbar = tqdm(corruption_transforms) for corruption in pbar: - corruption_name = corruption.__name__.lower() - pbar.set_description(f"Processing {corruption.__name__}") - (self.root / corruption_name).mkdir(parents=True, exist_ok=True) - self.save_corruption(self.root / 
corruption_name, corruption(self.shift_severity)) + corruption_name = corruption.name + pbar.set_description(f"Processing {corruption_name}") + (self.root / corruption_name / f"{self.shift_severity}").mkdir( + parents=True, exist_ok=True + ) + self.save_corruption( + self.root / corruption_name / f"{self.shift_severity}", + corruption(self.shift_severity), + ) def save_corruption(self, root: Path, corruption: nn.Module) -> None: for i in trange(self.core_length, leave=False): img, tgt = self.core_dataset[i] img = corruption(self.to_tensor(img)) - self.to_pil(img).save(root / f"{i}.png") - self.samples.append(root / f"{i}.png") + self.to_pil(img).save(root / f"{i}.jpg") + self.samples.append(root / f"{i}.jpg") self.targets.append(tgt) def __len__(self): @@ -103,8 +106,7 @@ def __getitem__(self, idx: int): if __name__ == "__main__": - from torchvision.datasets import CIFAR10 + from torchvision.datasets import OxfordIIITPet - dataset = CIFAR10(root="data", download=True) - corrupted_dataset = CorruptedDataset(dataset, shift_severity=1) - print(len(corrupted_dataset)) + dataset = OxfordIIITPet(root="data", split="test", download=True) + corrupted_dataset = CorruptedDataset(dataset, shift_severity=5) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 362b108d..68398e60 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -72,7 +72,7 @@ class TUCorruption(nn.Module): def __init__(self, severity: int) -> None: - """Base class for corruptions.""" + """Base class for corruption transforms.""" super().__init__() if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") @@ -86,14 +86,16 @@ def __repr__(self) -> str: class GaussianNoise(TUCorruption): + name = "gaussian_noise" + def __init__(self, severity: int) -> None: - """Add Gaussian noise to an image. + """Apply a Gaussian noise corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. """ super().__init__(severity) - self.scale = [0.08, 0.12, 0.18, 0.26, 0.38][severity] + self.scale = [0.08, 0.12, 0.18, 0.26, 0.38][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -102,8 +104,10 @@ def forward(self, img: Tensor) -> Tensor: class ShotNoise(TUCorruption): + name = "shot_noise" + def __init__(self, severity: int) -> None: - """Add shot noise to an image. + """Apply a shot (Poisson) noise corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. @@ -118,8 +122,11 @@ def forward(self, img: Tensor): class ImpulseNoise(TUCorruption): + name = "impulse_noise" + def __init__(self, severity: int, black_white: bool = False) -> None: - """Add impulse noise to an image. + """Apply an impulse (channel-independent Salt & Pepper) noise corruption to unbatched + tensor images. Args: severity (int): Severity level of the corruption. @@ -153,8 +160,10 @@ def forward(self, img: Tensor) -> Tensor: class DefocusBlur(TUCorruption): + name = "defocus_blur" + def __init__(self, severity: int) -> None: - """Add defocus blur to an image. + """Apply a defocus blur corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. @@ -186,30 +195,47 @@ def forward(self, img: Tensor) -> Tensor: class GlassBlur(TUCorruption): # TODO: batch + name = "glass_blur" + def __init__(self, severity: int) -> None: + """Apply a glass blur corruption to unbatched tensor images. 
+ + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] + sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = m.ceil(sigma * 4) self.iterations = [1, 1, 1, 2, 2][severity - 1] self.max_delta = [1, 2, 2, 3, 4][severity - 1] def forward(self, img: Tensor) -> Tensor: img_size = img.shape - img = gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma) + img = gaussian_blur2d( + img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + ).squeeze(0) for _ in range(self.iterations): - for h in range(img_size[0] - self.max_delta, self.max_delta, -1): - for w in range(img_size[1] - self.max_delta, self.max_delta, -1): + for h in range(img_size[1] - self.max_delta, self.max_delta, -1): + for w in range(img_size[2] - self.max_delta, self.max_delta, -1): dx, dy = torch.randint(-self.max_delta, self.max_delta, size=(2,)) h_prime, w_prime = h + dy, w + dx - img[h, w], img[h_prime, w_prime] = ( - img[h_prime, w_prime], - img[h, w], + img[:, h, w], img[:, h_prime, w_prime] = ( + img[:, h_prime, w_prime], + img[:, h, w], ) - return torch.clamp(gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma), 0, 1) + return torch.clamp( + gaussian_blur2d( + img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + ).squeeze(0), + 0, + 1, + ) def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): @@ -226,8 +252,13 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): class MotionBlur(TUCorruption): + name = "motion_blur" + def __init__(self, severity: int) -> None: - """Apply a motion blur corruption on the image. + """Apply a motion blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. Note: Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency with @@ -280,7 +311,14 @@ def clipped_zoom(img, zoom_factor): class ZoomBlur(TUCorruption): + name = "zoom_blur" + def __init__(self, severity: int) -> None: + """Apply a zoom blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.zooms = [ np.arange(1, 1.11, 0.01), @@ -308,8 +346,13 @@ def forward(self, img: Tensor) -> Tensor: class Snow(TUCorruption): + name = "snow" + def __init__(self, severity: int) -> None: - """Apply a snow effect on the image. + """Apply a snow effect on unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. Note: The transformation has been slightly modified, see MotionBlur for details. @@ -362,7 +405,14 @@ def forward(self, img: Tensor) -> Tensor: class Frost(TUCorruption): + name = "frost" + def __init__(self, severity: int) -> None: + """Apply a frost corruption effect on unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.rng = np.random.default_rng() self.mix = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1] @@ -430,7 +480,14 @@ def filldiamonds(): class Fog(TUCorruption): + name = "fog" + def __init__(self, severity: int) -> None: + """Apply a fog corruption effect on unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. 
+ """ super().__init__(severity) self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][severity - 1] @@ -450,7 +507,14 @@ def forward(self, img: Tensor) -> Tensor: class Brightness(IBrightness, TUCorruption): + name = "brightness" + def __init__(self, severity: int) -> None: + """Apply a brightness corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ TUCorruption.__init__(self, severity) self.level = [1.1, 1.2, 1.3, 1.4, 1.5][severity - 1] @@ -461,7 +525,14 @@ def forward(self, img: Tensor) -> Tensor: class Contrast(IContrast, TUCorruption): + name = "contrast" + def __init__(self, severity: int) -> None: + """Apply a contrast corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ TUCorruption.__init__(self, severity) self.level = [0.4, 0.3, 0.2, 0.1, 0.05][severity - 1] @@ -472,7 +543,14 @@ def forward(self, img: Tensor) -> Tensor | Image.Image: class Pixelate(TUCorruption): + name = "pixelate" + def __init__(self, severity: int) -> None: + """Apply a pixelation corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.quality = [0.6, 0.5, 0.4, 0.3, 0.25][severity - 1] self.to_pil = ToPILImage() @@ -491,7 +569,14 @@ def forward(self, img: Tensor) -> Tensor: class JPEGCompression(TUCorruption): + name = "jpeg_compression" + def __init__(self, severity: int) -> None: + """Apply a JPEG compression corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.quality = [25, 18, 15, 10, 7][severity - 1] @@ -504,7 +589,14 @@ def forward(self, img: Tensor) -> Tensor: class Elastic(TUCorruption): + name = "elastic" + def __init__(self, severity: int) -> None: + """Apply an elastic corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) if not cv2_installed or not scipy_installed: raise ImportError( @@ -559,10 +651,10 @@ def forward(self, img: Tensor) -> Tensor: dx = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])), + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])).unsqueeze(0), kernel_size=sigma * 3, - sigma=sigma, - ) + sigma=(sigma, sigma), + ).squeeze(0) * self.mix[0] * shape_size[0] ) @@ -572,10 +664,10 @@ def forward(self, img: Tensor) -> Tensor: dy = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])), + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])).unsqueeze(0), kernel_size=sigma * 3, - sigma=self.mix[1] * shape_size[0], - ) + sigma=(sigma, sigma), + ).squeeze(0) * self.mix[0] * shape_size[0] ) @@ -597,8 +689,18 @@ def forward(self, img: Tensor) -> Tensor: return torch.as_tensor(img).permute(2, 0, 1) +# Additional corruption transforms + + class SpeckleNoise(TUCorruption): + name = "speckle_noise" + def __init__(self, severity: int) -> None: + """Apply speckle noise to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.scale = [0.15, 0.2, 0.35, 0.45, 0.6][severity - 1] self.rng = np.random.default_rng() @@ -614,7 +716,14 @@ def forward(self, img: Tensor) -> Tensor: class GaussianBlur(TUCorruption): + name = "gaussian_blur" + def __init__(self, severity: int) -> None: + """Apply a Gaussian blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. 
+ """ super().__init__(severity) if not kornia_installed: raise ImportError( @@ -634,7 +743,14 @@ def forward(self, img: Tensor) -> Tensor: class Saturation(ISaturation, TUCorruption): + name = "saturation" + def __init__(self, severity: int) -> None: + """Apply a saturation corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ TUCorruption.__init__(self, severity) self.severity = severity self.level = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] @@ -646,8 +762,8 @@ def forward(self, img: Tensor) -> Tensor: corruption_transforms = ( - # GaussianNoise, - # ShotNoise, + GaussianNoise, + ShotNoise, ImpulseNoise, DefocusBlur, GlassBlur, From 42a38285b3d9e6bcd8132145655be636e06422b2 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 16:59:23 +0100 Subject: [PATCH 10/47] :bug: Continue fixing transforms --- torch_uncertainty/transforms/corruption.py | 36 +++++++++++++++------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 68398e60..a1f0f202 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -211,7 +211,7 @@ def __init__(self, severity: int) -> None: ) sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] self.sigma = (sigma, sigma) - self.kernel_size = m.ceil(sigma * 4) + self.kernel_size = int(sigma * 4 // 2 * 2 + 1) self.iterations = [1, 1, 1, 2, 2][severity - 1] self.max_delta = [1, 2, 2, 3, 4][severity - 1] @@ -596,6 +596,9 @@ def __init__(self, severity: int) -> None: Args: severity (int): Severity level of the corruption. + + Note: + mix[0][1] has been changed to 0.5 to avoid errors when dealing with small images. """ super().__init__(severity) if not cv2_installed or not scipy_installed: @@ -606,7 +609,7 @@ def __init__(self, severity: int) -> None: # The following pertubation values are based on the original repo but # are quite strange, notably for the severities 3 and 4 self.mix = [ - (2, 0.7, 0.1), + (2, 0.5, 0.1), (2, 0.08, 0.2), (0.05, 0.01, 0.02), (0.07, 0.01, 0.02), @@ -648,13 +651,18 @@ def forward(self, img: Tensor) -> Tensor: ) sigma = self.mix[1] * shape_size[0] + ks = int((sigma * 3 // 2) * 2 + 1) dx = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])).unsqueeze(0), - kernel_size=sigma * 3, + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])) + .unsqueeze(0) + .unsqueeze(0), + kernel_size=ks, sigma=(sigma, sigma), - ).squeeze(0) + ) + .squeeze(0) + .squeeze(0) * self.mix[0] * shape_size[0] ) @@ -664,10 +672,14 @@ def forward(self, img: Tensor) -> Tensor: dy = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])).unsqueeze(0), - kernel_size=sigma * 3, + torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2])) + .unsqueeze(0) + .unsqueeze(0), + kernel_size=ks, sigma=(sigma, sigma), - ).squeeze(0) + ) + .squeeze(0) + .squeeze(0) * self.mix[0] * shape_size[0] ) @@ -730,16 +742,18 @@ def __init__(self, severity: int) -> None: "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = int(sigma // 2 * 2 * 4 + 1) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img return torch.clamp( - gaussian_blur2d(img, kernel_size=self.sigma * 4, sigma=self.sigma), + 
gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma), min=0, max=1, - ) + ).squeeze(0) class Saturation(ISaturation, TUCorruption): From 95f445b34dde42bc0f71374df5e76447a765fa31 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 17:02:03 +0100 Subject: [PATCH 11/47] :white_check_mark: Fix corruption test --- tests/transforms/test_corruption.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index ce6a76cc..9076648f 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -135,18 +135,11 @@ def test_snow(self): def test_fog(self): inputs = torch.rand(3, 32, 32) - transform = Fog(1, size=32) + transform = Fog(1) transform(inputs) - - with pytest.raises(ValueError, match="Image must be square. Got "): - transform(torch.rand(3, 32, 12)) - - transform = Fog(0, size=32) + transform = Fog(0) transform(inputs) - with pytest.raises(ValueError, match="Size must be a power of 2. Got "): - _ = Fog(1, size=15) - def test_brightness(self): inputs = torch.rand(3, 32, 32) transform = Brightness(1) From 6e56109d8ebcc114389ddef43b0a9dad580a1a73 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 14 Mar 2025 17:24:34 +0100 Subject: [PATCH 12/47] :zap: Update cls dms to v2 --- tests/_dummies/datamodule.py | 14 ++++---- .../datamodules/classification/cifar10.py | 29 ++++++++--------- .../datamodules/classification/cifar100.py | 25 ++++++++------- .../datamodules/classification/imagenet.py | 29 +++++++++-------- .../datamodules/classification/mnist.py | 32 +++++++++++-------- .../classification/tiny_imagenet.py | 25 ++++++++------- 6 files changed, 81 insertions(+), 73 deletions(-) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 66555c6c..8da6b391 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -2,10 +2,10 @@ import numpy as np import torch -import torchvision.transforms.v2 as T from numpy.typing import ArrayLike from torch.utils.data import DataLoader from torchvision import tv_tensors +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule @@ -52,8 +52,8 @@ def __init__( self.ood_dataset = DummyClassificationDataset self.shift_dataset = DummyClassificationDataset - self.train_transform = T.ToTensor() - self.test_transform = T.ToTensor() + self.train_transform = v2.ToTensor() + self.test_transform = v2.ToTensor() def prepare_data(self) -> None: pass @@ -207,7 +207,7 @@ def __init__( self.dataset = DummySegmentationDataset - self.train_transform = T.ToDtype( + self.train_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -215,7 +215,7 @@ def __init__( }, scale=True, ) - self.test_transform = T.ToDtype( + self.test_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -296,7 +296,7 @@ def __init__( self.dataset = DummPixelRegressionDataset - self.train_transform = T.ToDtype( + self.train_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.float32, @@ -304,7 +304,7 @@ def __init__( }, scale=True, ) - self.test_transform = T.ToDtype( + self.test_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.float32, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 1e1441c2..17461976 100644 --- 
a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -2,12 +2,13 @@ from typing import Literal import numpy as np -import torchvision.transforms as T +import torch from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset @@ -100,10 +101,10 @@ def __init__( ) if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -116,25 +117,21 @@ def __init__( else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), + v2.ToDtype(torch.float32), basic_transform, main_transform, - T.Normalize( - self.mean, - self.std, - ), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Normalize( - self.mean, - self.std, - ), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 11d0a7fa..520e0498 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -3,12 +3,12 @@ import numpy as np import torch -import torchvision.transforms as T from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import AggregatedDataset @@ -94,10 +94,10 @@ def __init__( ) if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -106,25 +106,26 @@ def __init__( if cutout: main_transform = Cutout(cutout) elif randaugment: - main_transform = T.RandAugment(num_ops=2, magnitude=20) + main_transform = v2.RandAugment(num_ops=2, magnitude=20) elif auto_augment: main_transform = rand_augment_transform(auto_augment, {}) else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), + v2.ToDtype(torch.float32), basic_transform, main_transform, - T.ConvertImageDtype(torch.float32), - T.Normalize(mean=self.mean, std=self.std), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 24fd5c6d..6f2ad4b9 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -2,13 +2,14 @@ 
from pathlib import Path from typing import Literal -import torchvision.transforms as T +import torch import yaml from timm.data.auto_augment import rand_augment_transform from timm.data.mixup import Mixup from torch import nn from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( @@ -137,10 +138,10 @@ def __init__( self.procedure = procedure if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomResizedCrop(train_size, interpolation=self.interpolation), - T.RandomHorizontalFlip(), + v2.RandomResizedCrop(train_size, interpolation=self.interpolation), + v2.RandomHorizontalFlip(), ] ) else: @@ -153,7 +154,7 @@ def __init__( main_transform = nn.Identity() elif self.procedure == "ViT": train_size = 224 - main_transform = T.Compose( + main_transform = v2.Compose( [ Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), rand_augment_transform("rand-m9-n2-mstd0.5", {}), @@ -165,21 +166,23 @@ def __init__( else: raise ValueError("The procedure is unknown") - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), + v2.ToDtype(torch.float32), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Resize(256, interpolation=self.interpolation), - T.CenterCrop(224), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.Resize(256, interpolation=self.interpolation), + v2.CenterCrop(224), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index f6879c1a..62442bb5 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,10 +1,11 @@ from pathlib import Path from typing import Literal -import torchvision.transforms as T +import torch from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST @@ -82,32 +83,35 @@ def __init__( self.shift_dataset = MNISTC self.shift_severity = 1 - basic_transform = T.RandomCrop(28, padding=4) if basic_augment else nn.Identity() + basic_transform = v2.RandomCrop(28, padding=4) if basic_augment else nn.Identity() main_transform = Cutout(cutout) if cutout else nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), + v2.ToDtype(torch.float32), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.CenterCrop(28), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.CenterCrop(28), + v2.Normalize(mean=self.mean, std=self.std), ] ) if self.eval_ood: # NotMNIST has 3 channels - self.ood_transform = T.Compose( + self.ood_transform = v2.Compose( [ - T.ToTensor(), - T.Grayscale(num_output_channels=1), - T.CenterCrop(28), - 
T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.Grayscale(num_output_channels=1), + v2.CenterCrop(28), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index bf95159e..937c7b63 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -2,12 +2,13 @@ from typing import Literal import numpy as np -import torchvision.transforms as T +import torch from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( @@ -71,10 +72,10 @@ def __init__( raise ValueError(f"OOD dataset {ood_ds} not supported for TinyImageNet.") self.shift_dataset = TinyImageNetC if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(64, padding=4), - T.RandomHorizontalFlip(), + v2.RandomCrop(64, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -85,20 +86,22 @@ def __init__( else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), + v2.ToDtype(torch.float32), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Resize(64, interpolation=self.interpolation), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(torch.float32), + v2.Resize(64, interpolation=self.interpolation), + v2.Normalize(mean=self.mean, std=self.std), ] ) From d8c6a104d702e65a4a832c3b334f2bb5cb6a870f Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 14 Mar 2025 22:51:37 +0100 Subject: [PATCH 13/47] :hammer: Update transforms in dms for cls and seg --- torch_uncertainty/datamodules/classification/cifar10.py | 4 ++-- torch_uncertainty/datamodules/classification/cifar100.py | 4 ++-- torch_uncertainty/datamodules/classification/imagenet.py | 4 ++-- torch_uncertainty/datamodules/classification/mnist.py | 6 +++--- .../datamodules/classification/tiny_imagenet.py | 4 ++-- torch_uncertainty/datamodules/segmentation/camvid.py | 2 -- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 17461976..4aa12bfd 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -120,9 +120,9 @@ def __init__( self.train_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), basic_transform, main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) @@ -130,7 +130,7 @@ def __init__( self.test_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 520e0498..70224847 100644 --- 
a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -115,16 +115,16 @@ def __init__( self.train_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), basic_transform, main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 6f2ad4b9..157bd8a2 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -169,9 +169,9 @@ def __init__( self.train_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), basic_transform, main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) @@ -179,9 +179,9 @@ def __init__( self.test_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), v2.Resize(256, interpolation=self.interpolation), v2.CenterCrop(224), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index 62442bb5..ed31116a 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -90,17 +90,17 @@ def __init__( self.train_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), basic_transform, main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) self.test_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) @@ -108,9 +108,9 @@ def __init__( self.ood_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), v2.Grayscale(num_output_channels=1), v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 937c7b63..b8aa34e5 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -89,9 +89,9 @@ def __init__( self.train_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), basic_transform, main_transform, + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) @@ -99,8 +99,8 @@ def __init__( self.test_transform = v2.Compose( [ v2.ToImage(), - v2.ToDtype(torch.float32), v2.Resize(64, interpolation=self.interpolation), + v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 08fd432a..66d774b8 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -125,7 +125,6 @@ def __init__( self.train_transform = v2.Compose( [ - v2.ToImage(), basic_transform, v2.ToDtype( dtype={ @@ -140,7 +139,6 @@ def __init__( ) 
self.test_transform = v2.Compose( [ - v2.ToImage(), v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ From d98b8d9601aa3ff2b92dbc6e971858d1a5a82c94 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 15 Mar 2025 13:44:45 +0100 Subject: [PATCH 14/47] :bug: Finish fixing corruptions --- torch_uncertainty/transforms/corruption.py | 41 +++++++++++----------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index a1f0f202..6e01e75a 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -359,11 +359,11 @@ def __init__(self, severity: int) -> None: """ super().__init__(severity) self.mix = [ - (0.1, 0.3, 3, 0.5, 4, 0.8), - (0.2, 0.3, 2, 0.5, 4, 0.7), - (0.55, 0.3, 4, 0.9, 8, 0.7), - (0.55, 0.3, 4.5, 0.85, 8, 0.65), - (0.55, 0.3, 2.5, 0.85, 12, 0.55), + (0.1, 3, 0.5, 4, 0.8), + (0.2, 2, 0.5, 4, 0.7), + (0.55, 4, 0.9, 8, 0.7), + (0.55, 4.5, 0.85, 8, 0.65), + (0.55, 2.5, 0.85, 12, 0.55), ][severity - 1] self.rng = np.random.default_rng() @@ -377,31 +377,32 @@ def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img _, height, width = img.shape - x = img.numpy() - snow_layer = self.rng.normal(size=x.shape[1:], loc=self.mix[0], scale=self.mix[1])[ - ..., np.newaxis - ] - snow_layer = clipped_zoom(snow_layer, self.mix[2]) - snow_layer[snow_layer < self.mix[3]] = 0 + x = img.permute(1, 2, 0).numpy() + snow_layer = self.rng.normal(size=x.shape[:2], loc=self.mix[0], scale=0.3)[..., np.newaxis] + snow_layer = clipped_zoom(snow_layer, self.mix[1]) + snow_layer[snow_layer < self.mix[2]] = 0 snow_layer = np.clip(snow_layer.squeeze(), 0, 1) snow_layer = ( motion_blur( torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), - kernel_size=self.mix[4] * 2 + 1, + kernel_size=self.mix[3] * 2 + 1, angle=self.rng.uniform(-135, -45), direction=0, ) .squeeze(0) + .squeeze(0) + .unsqueeze(-1) .numpy() ) - - x = self.mix[5] * x + (1 - self.mix[5]) * np.maximum( + x = self.mix[4] * x + (1 - self.mix[4]) * np.maximum( x, - cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape(1, height, width) * 1.5 - + 0.5, + cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(height, width, 1) * 1.5 + 0.5, ) - return torch.clamp(torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1) + + return torch.clamp( + torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1 + ).permute(2, 0, 1) class Frost(TUCorruption): @@ -516,7 +517,7 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. 
""" TUCorruption.__init__(self, severity) - self.level = [1.1, 1.2, 1.3, 1.4, 1.5][severity - 1] + self.level = [1.3, 1.6, 1.9, 2.2, 2.5][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -606,8 +607,6 @@ def __init__(self, severity: int) -> None: "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" ) - # The following pertubation values are based on the original repo but - # are quite strange, notably for the severities 3 and 4 self.mix = [ (2, 0.5, 0.1), (2, 0.08, 0.2), @@ -651,7 +650,7 @@ def forward(self, img: Tensor) -> Tensor: ) sigma = self.mix[1] * shape_size[0] - ks = int((sigma * 3 // 2) * 2 + 1) + ks = min(int((sigma * 3 // 2) * 2 + 1), min(shape_size[:2]) // 2 * 2 - 1) dx = ( ( gaussian_blur2d( From dba4bcfc1bded1eb76005338a2a918cd8e4217bf Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 15 Mar 2025 14:01:07 +0100 Subject: [PATCH 15/47] :fire: Temporarily remove tj-actions/changed-files@v42 --- .github/workflows/run-tests.yml | 44 +++++++++++++++++---------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 8435ac79..09095375 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -36,23 +36,23 @@ jobs: echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" >> $GITHUB_ENV - - name: Get changed files - id: changed-files-specific - uses: tj-actions/changed-files@v42 - with: - files: | - auto_tutorials_source/** - data/** - experiments/** - docs/** - *.md - *.yaml - *.yml - LICENSE - .gitignore + # - name: Get changed files + # id: changed-files-specific + # uses: tj-actions/changed-files@v42 + # with: + # files: | + # auto_tutorials_source/** + # data/** + # experiments/** + # docs/** + # *.md + # *.yaml + # *.yml + # LICENSE + # .gitignore - name: Cache folder for TorchUncertainty - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' uses: actions/cache@v4 id: cache-folder with: @@ -61,24 +61,25 @@ jobs: key: torch-uncertainty-${{ runner.os }} - name: Install dependencies - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Check style & format - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m ruff check torch_uncertainty --no-fix --statistics python3 -m ruff format torch_uncertainty --check - name: Test with pytest and compute coverage - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pytest --cov --cov-report xml --durations 10 --junitxml=junit.xml - name: Upload coverage to Codecov - if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + # if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + if: github.event_name != 'pull_request' || github.base_ref == 'dev' uses: codecov/codecov-action@v4 
continue-on-error: true with: @@ -89,7 +90,8 @@ jobs: env_vars: PYTHON_VERSION - name: Upload test results to Codecov - if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + # if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + if: github.event_name != 'pull_request' || github.base_ref == 'dev' uses: codecov/test-results-action@v1 continue-on-error: true with: @@ -98,6 +100,6 @@ jobs: env_vars: PYTHON_VERSION - name: Test sphinx build without tutorials - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | cd docs && make clean && make html-noplot From f4a59e0adb25bb817ada1960434396d855bca08f Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 15 Mar 2025 15:38:06 +0100 Subject: [PATCH 16/47] :book: Add some documentation --- torch_uncertainty/transforms/corruption.py | 25 ++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 6e01e75a..ee3cb303 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -1,4 +1,17 @@ -"""Adapted from https://github.com/hendrycks/robustness.""" +"""These corruptive transformations are mostly PyTorch portings of the originals provided by +Dan Hendrycks and Thomas Dietterich in "Benchmarking neural network robustness to common +corruptions and perturbations" published at ICLR 2019 through their GitHub repository +https://github.com/hendrycks/robustness. + +However, please note that these transforms have been rewritten with more modern tools to improve +their efficiency as well as reduce the number of dependencies. As a result, some parameters had +to be modified to remain as close as possible to the original transforms. + +The authors of the library advise avoiding using the stochastic transforms to generate your dataset +to avoid reproducibility issues. It may be preferable to first check if the corrupted dataset is +available on TorchUncertainty's Hugging Face https://huggingface.co/torch-uncertainty. File an +issue if you would like one specific and missing dataset to be published on this page. +""" from importlib import util from io import BytesIO @@ -202,6 +215,10 @@ def __init__(self, severity: int) -> None: Args: severity (int): Severity level of the corruption. + + Note: + We have changed the number of iterations that was too high given the size of the + images. """ super().__init__(severity) if not kornia_installed: @@ -515,6 +532,10 @@ def __init__(self, severity: int) -> None: Args: severity (int): Severity level of the corruption. + + Note: + The values have been changed to better reflect the magnitude of the original + transformation replaced with the more principled torchvision adjust_brightness. 
""" TUCorruption.__init__(self, severity) self.level = [1.3, 1.6, 1.9, 2.2, 2.5][severity - 1] @@ -741,7 +762,7 @@ def __init__(self, severity: int) -> None: "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + sigma = [1, 2, 3, 4, 6][severity - 1] self.sigma = (sigma, sigma) self.kernel_size = int(sigma // 2 * 2 * 4 + 1) From 06ef3074a10ced597bf9cd03dc8f0341cb6a0261 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sat, 15 Mar 2025 16:22:27 +0100 Subject: [PATCH 17/47] :book: Add a ref to the new datasets on the RM --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e0f735a4..1e476b03 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,9 @@ This package provides a multi-level API, including: - easy-to-use :zap: lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR ( :construction: work in progress :construction: ). - **layers**, **models**, **metrics**, & **losses** available for use in your networks - scikit-learn style post-processing methods such as Temperature Scaling. +- transformations, including corruptions resulting in additional "corrupted datasets" available on [HuggingFace](https://huggingface.co/torch-uncertainty) Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. From 6c7285816c24306d4223397d90af3840a299f139 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Mar 2025 15:12:08 +0100 Subject: [PATCH 18/47] :bug: Fix speckle noise --- torch_uncertainty/transforms/corruption.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index ee3cb303..8875556e 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -726,22 +726,29 @@ def forward(self, img: Tensor) -> Tensor: class SpeckleNoise(TUCorruption): name = "speckle_noise" + batched = True - def __init__(self, severity: int) -> None: - """Apply speckle noise to unbatched tensor images. + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply speckle noise to tensor images. Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. """ super().__init__(severity) self.scale = [0.15, 0.2, 0.35, 0.45, 0.6][severity - 1] - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: + """Apply speckle noise on images. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W). 
+ """ if self.severity == 0: return img return torch.clamp( - img + img * self.rng.normal(img, self.scale), + img * self.rng.normal(1, self.scale, size=img.shape), 0, 1, ) From cfdb70eb5e22b68487e1d265836aacab14c95618 Mon Sep 17 00:00:00 2001 From: Olivier Date: Sun, 16 Mar 2025 16:38:26 +0100 Subject: [PATCH 19/47] :hammer: Continue improving corruptions --- tests/transforms/test_corruption.py | 69 ++++--- torch_uncertainty/transforms/corruption.py | 220 ++++++++++++--------- 2 files changed, 176 insertions(+), 113 deletions(-) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 9076648f..51b75d3c 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -34,84 +34,103 @@ def test_gaussian_noise(self): _ = GaussianNoise(0.1) inputs = torch.rand(3, 32, 32) transform = GaussianNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = GaussianNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 + print(transform) def test_shot_noise(self): inputs = torch.rand(3, 32, 32) transform = ShotNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = ShotNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 def test_impulse_noise(self): inputs = torch.rand(3, 32, 32) - transform = ImpulseNoise(1) - transform(inputs) + transform = ImpulseNoise(1, black_white=True) + assert transform(inputs).ndim == 3 transform = ImpulseNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + transform = ImpulseNoise(1, black_white=False) + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 def test_speckle_noise(self): inputs = torch.rand(3, 32, 32) transform = SpeckleNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = SpeckleNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + transform = MotionBlur(1) + assert transform(inputs).ndim == 4 def test_gaussian_blur(self): inputs = torch.rand(3, 32, 32) transform = GaussianBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = GaussianBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_glass_blur(self): inputs = torch.rand(3, 32, 32) transform = GlassBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = GlassBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_defocus_blur(self): inputs = torch.rand(3, 32, 32) transform = DefocusBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = DefocusBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + transform = DefocusBlur(1) + assert transform(inputs).ndim == 4 def test_motion_blur(self): inputs = torch.rand(3, 32, 32) transform = MotionBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = MotionBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 - inputs = torch.rand(1, 3, 32, 32) + inputs = torch.rand(3, 3, 32, 32) transform = MotionBlur(1) - transform(inputs) + assert transform(inputs).ndim == 4 def test_zoom_blur(self): inputs = torch.rand(3, 32, 32) transform = ZoomBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = ZoomBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 def 
test_jpeg_compression(self): inputs = torch.rand(3, 32, 32) transform = JPEGCompression(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = JPEGCompression(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_pixelate(self): inputs = torch.rand(3, 32, 32) transform = Pixelate(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = Pixelate(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_frost(self): try: diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 8875556e..b1435da1 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -16,6 +16,8 @@ from importlib import util from io import BytesIO +from einops import rearrange + if util.find_spec("cv2"): import cv2 @@ -48,7 +50,8 @@ ) if util.find_spec("kornia"): - from kornia.filters import gaussian_blur2d, motion_blur + from kornia.color import rgb_to_grayscale + from kornia.filters import filter2d, gaussian_blur2d, motion_blur kornia_installed = True else: # coverage: ignore @@ -84,6 +87,9 @@ class TUCorruption(nn.Module): + name: str + batched: bool = False + def __init__(self, severity: int) -> None: """Base class for corruption transforms.""" super().__init__() @@ -100,9 +106,10 @@ def __repr__(self) -> str: class GaussianNoise(TUCorruption): name = "gaussian_noise" + batched = True def __init__(self, severity: int) -> None: - """Apply a Gaussian noise corruption to unbatched tensor images. + """Apply a Gaussian noise corruption to tensor images. Args: severity (int): Severity level of the corruption. @@ -111,6 +118,11 @@ def __init__(self, severity: int) -> None: self.scale = [0.08, 0.12, 0.18, 0.26, 0.38][severity - 1] def forward(self, img: Tensor) -> Tensor: + """Apply Gaussian noise on an input image. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W) + """ if self.severity == 0: return img return torch.clamp(torch.normal(img, self.scale), 0, 1) @@ -118,9 +130,10 @@ def forward(self, img: Tensor) -> Tensor: class ShotNoise(TUCorruption): name = "shot_noise" + batched = True def __init__(self, severity: int) -> None: - """Apply a shot (Poisson) noise corruption to unbatched tensor images. + """Apply a shot (Poisson) noise corruption to tensor images. Args: severity (int): Severity level of the corruption. @@ -129,6 +142,11 @@ def __init__(self, severity: int) -> None: self.scale = [60, 25, 12, 5, 3][severity - 1] def forward(self, img: Tensor): + """Apply Poisson noise on an input image. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W) + """ if self.severity == 0: return img return torch.clamp(torch.poisson(img * self.scale) / self.scale, 0, 1) @@ -136,6 +154,7 @@ def forward(self, img: Tensor): class ImpulseNoise(TUCorruption): name = "impulse_noise" + batched = True def __init__(self, severity: int, black_white: bool = False) -> None: """Apply an impulse (channel-independent Salt & Pepper) noise corruption to unbatched @@ -163,17 +182,46 @@ def __init__(self, severity: int, black_white: bool = False) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = img.unsqueeze(0) if self.black_white else img.unsqueeze(1) + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + channels = img.shape[1] + if not self.black_white: + img = rearrange(img, "b c ... 
-> (b c) 1 ...") img = torch.clamp( - input=torch.as_tensor(self.aug(img)), + input=self.aug(img), min=torch.zeros(1), max=torch.ones(1), ) + if not self.black_white: + img = rearrange(img, "(b c) 1 ... -> b c ... ", c=channels) + + if no_batch: + img = img.squeeze(0) return img.squeeze(0) if self.black_white else img.squeeze(1) +def disk(radius: int, alias_blur: float = 0.1, dtype=torch.float32): + """Generate a Gaussian disk of shape (1, radius, radius) for filtering.""" + if radius <= 8: + size = torch.arange(-8, 8 + 1) + ksize = (3, 3) + else: # coverage: ignore + size = torch.arange(-radius, radius + 1) + ksize = (5, 5) + xs, ys = torch.meshgrid(size, size, indexing="xy") + + aliased_disk = ((xs**2 + ys**2) <= radius**2).to(dtype=dtype) + aliased_disk /= aliased_disk.sum() + return gaussian_blur2d( + aliased_disk.unsqueeze(0).unsqueeze(0), kernel_size=ksize, sigma=(alias_blur, alias_blur) + ).squeeze(0) + + class DefocusBlur(TUCorruption): name = "defocus_blur" + batched = True def __init__(self, severity: int) -> None: """Apply a defocus blur corruption to unbatched tensor images. @@ -182,43 +230,37 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. """ super().__init__(severity) - if not cv2_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.radius = [3, 4, 6, 8, 10][severity - 1] - self.alias_blur = [0.1, 0.5, 0.5, 0.5, 0.5][severity - 1] + radius = [3, 4, 6, 8, 10][severity - 1] + alias_blur = [0.1, 0.5, 0.5, 0.5, 0.5][severity - 1] + self.disk = disk(radius, alias_blur=alias_blur) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = img.numpy() - channels = [ - torch.as_tensor( - cv2.filter2D( - img[ch, :, :], - -1, - disk(self.radius, alias_blur=self.alias_blur), - ) - ) - for ch in range(3) - ] - return torch.clamp(torch.stack(channels), 0, 1) + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = torch.clamp(filter2d(img, kernel=self.disk), 0, 1) + if no_batch: + out = out.squeeze(0) + return out -class GlassBlur(TUCorruption): # TODO: batch +class GlassBlur(TUCorruption): name = "glass_blur" - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a glass blur corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. - - Note: - We have changed the number of iterations that was too high given the size of the - images. + seed (int | None): Optional seed for the rng. 
""" super().__init__(severity) if not kornia_installed: @@ -229,53 +271,56 @@ def __init__(self, severity: int) -> None: sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] self.sigma = (sigma, sigma) self.kernel_size = int(sigma * 4 // 2 * 2 + 1) - self.iterations = [1, 1, 1, 2, 2][severity - 1] + self.iterations = [2, 1, 3, 2, 2][severity - 1] self.max_delta = [1, 2, 2, 3, 4][severity - 1] + if seed is None: + self.rng = None + else: + self.rng = torch.Generator(device="cpu").manual_seed(seed) + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img img_size = img.shape - img = gaussian_blur2d( - img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma - ).squeeze(0) - for _ in range(self.iterations): - for h in range(img_size[1] - self.max_delta, self.max_delta, -1): - for w in range(img_size[2] - self.max_delta, self.max_delta, -1): - dx, dy = torch.randint(-self.max_delta, self.max_delta, size=(2,)) + img = ( + gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma) + .squeeze(0) + .permute(1, 2, 0) + ) + rands = torch.randint( + -self.max_delta, + self.max_delta, + size=(self.iterations, img_size[1] - self.max_delta, img_size[2] - self.max_delta, 2), + generator=self.rng, + ) + + for iteration in range(self.iterations): + for i, h in enumerate(range(img_size[1] - self.max_delta, self.max_delta, -1)): + for j, w in enumerate(range(img_size[2] - self.max_delta, self.max_delta, -1)): + dx, dy = rands[iteration, i, j, :] h_prime, w_prime = h + dy, w + dx - img[:, h, w], img[:, h_prime, w_prime] = ( - img[:, h_prime, w_prime], - img[:, h, w], - ) + img[h, w, :], img[h_prime, w_prime, :] = img[h_prime, w_prime, :], img[h, w, :] + return torch.clamp( gaussian_blur2d( - img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + img.permute(2, 0, 1).unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma ).squeeze(0), 0, 1, ) -def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): - if radius <= 8: - size = np.arange(-8, 8 + 1) - ksize = (3, 3) - else: # coverage: ignore - size = np.arange(-radius, radius + 1) - ksize = (5, 5) - xs, ys = np.meshgrid(size, size) - aliased_disk = np.array((xs**2 + ys**2) <= radius**2, dtype=dtype) - aliased_disk /= np.sum(aliased_disk) - return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) - - class MotionBlur(TUCorruption): name = "motion_blur" + batched = True - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a motion blur corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. Note: Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency with @@ -283,7 +328,7 @@ def __init__(self, severity: int) -> None: sigma as the new kernel radius sizes. """ super().__init__(severity) - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) self.radius = [3, 5, 8, 12, 15][severity - 1] if not kornia_installed: @@ -365,11 +410,12 @@ def forward(self, img: Tensor) -> Tensor: class Snow(TUCorruption): name = "snow" - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a snow effect on unbatched tensor images. Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. Note: The transformation has been slightly modified, see MotionBlur for details. 
@@ -382,9 +428,9 @@ def __init__(self, severity: int) -> None: (0.55, 4.5, 0.85, 8, 0.65), (0.55, 2.5, 0.85, 12, 0.55), ][severity - 1] - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) - if not kornia_installed: + if not kornia_installed or not scipy_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -393,60 +439,53 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - _, height, width = img.shape - x = img.permute(1, 2, 0).numpy() - snow_layer = self.rng.normal(size=x.shape[:2], loc=self.mix[0], scale=0.3)[..., np.newaxis] + snow_layer = self.rng.normal(size=img.shape[1:], loc=self.mix[0], scale=0.3)[ + ..., np.newaxis + ] snow_layer = clipped_zoom(snow_layer, self.mix[1]) snow_layer[snow_layer < self.mix[2]] = 0 snow_layer = np.clip(snow_layer.squeeze(), 0, 1) - snow_layer = ( - motion_blur( - torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), - kernel_size=self.mix[3] * 2 + 1, - angle=self.rng.uniform(-135, -45), - direction=0, - ) - .squeeze(0) - .squeeze(0) - .unsqueeze(-1) - .numpy() - ) - x = self.mix[4] * x + (1 - self.mix[4]) * np.maximum( - x, - cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(height, width, 1) * 1.5 + 0.5, + snow_layer = motion_blur( + torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), + kernel_size=self.mix[3] * 2 + 1, + angle=self.rng.uniform(-135, -45), + direction=0, + ).squeeze(0) + + x = self.mix[4] * img + (1 - self.mix[4]) * torch.maximum( + img, + rgb_to_grayscale(img) * 1.5 + 0.5, ) - return torch.clamp( - torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1 - ).permute(2, 0, 1) + return torch.clamp(x + snow_layer + snow_layer.flip(dims=(1, 2)), 0, 1) class Frost(TUCorruption): name = "frost" - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a frost corruption effect on unbatched tensor images. Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. """ super().__init__(severity) - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) self.mix = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1] self.frost_ds = FrostImages("./data", download=True, transform=ToTensor()) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - _, height, width = img.shape - frost_img = RandomResizedCrop((height, width))( + frost_img = RandomResizedCrop(img.shape[1:])( self.frost_ds[self.rng.integers(low=0, high=5)] ) return torch.clamp(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) -def plasma_fractal(height, width, wibbledecay=3): +def plasma_fractal(height, width, rng, wibbledecay=3): """Generate a heightmap using diamond-square algorithm. Return square 2d array, side length 'mapsize', of floats in range 0-1. 'mapsize' must be a power of two. @@ -455,7 +494,6 @@ def plasma_fractal(height, width, wibbledecay=3): maparray[0, 0] = 0 stepsize = height wibble = 100 - rng = np.random.default_rng() def wibbledmean(array): return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape) @@ -500,14 +538,16 @@ def filldiamonds(): class Fog(TUCorruption): name = "fog" - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a fog corruption effect on unbatched tensor images. 
Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. """ super().__init__(severity) self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][severity - 1] + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -518,7 +558,10 @@ def forward(self, img: Tensor) -> Tensor: fog = ( self.mix[0] * plasma_fractal( - height=random_height_map_size, width=random_height_map_size, wibbledecay=self.mix[1] + height=random_height_map_size, + width=random_height_map_size, + wibbledecay=self.mix[1], + rng=self.rng, )[:height, :width] ) return torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) @@ -613,11 +656,12 @@ def forward(self, img: Tensor) -> Tensor: class Elastic(TUCorruption): name = "elastic" - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: """Apply an elastic corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. Note: mix[0][1] has been changed to 0.5 to avoid errors when dealing with small images. @@ -635,7 +679,7 @@ def __init__(self, severity: int) -> None: (0.07, 0.01, 0.02), (0.12, 0.01, 0.02), ][severity - 1] - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: From 2a6893d718c5cb18135d8f657baee6a5e1dd460c Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 01:01:31 +0100 Subject: [PATCH 20/47] :bug: Fix saturation --- torch_uncertainty/transforms/corruption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index b1435da1..22b5dab1 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -838,7 +838,7 @@ def __init__(self, severity: int) -> None: """ TUCorruption.__init__(self, severity) self.severity = severity - self.level = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] + self.level = [0.8, 0.6, 0.4, 0.2, 0.1][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: From 0f285961f0aaffeef8101f13731f9431cc723c81 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 01:12:50 +0100 Subject: [PATCH 21/47] :sparkles: Add new implementation of GlassBlur --- tests/transforms/test_corruption.py | 7 ++ torch_uncertainty/transforms/corruption.py | 130 ++++++++++++++++++++- 2 files changed, 136 insertions(+), 1 deletion(-) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 51b75d3c..7cdac5cd 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -15,6 +15,7 @@ ImpulseNoise, JPEGCompression, MotionBlur, + OriginalGlassBlur, Pixelate, Saturation, ShotNoise, @@ -89,6 +90,12 @@ def test_glass_blur(self): transform = GlassBlur(0) assert transform(inputs).ndim == 3 + inputs = torch.rand(3, 32, 32) + transform = OriginalGlassBlur(1) + assert transform(inputs).ndim == 3 + transform = OriginalGlassBlur(0) + assert transform(inputs).ndim == 3 + def test_defocus_blur(self): inputs = torch.rand(3, 32, 32) transform = DefocusBlur(1) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 22b5dab1..4be0f403 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -16,7 +16,9 @@ 
 from importlib import util
 from io import BytesIO
 
+import torch.nn.functional as F
 from einops import rearrange
+from torch.distributions import Categorical
 
 if util.find_spec("cv2"):
     import cv2
@@ -252,12 +254,137 @@ def forward(self, img: Tensor) -> Tensor:
         return out
 
 
+def generate_offset_distribution(max_delta, iterations):
+    """Symmetrized version of the glass blur swapping algorithm.
+
+    The original implementation is sequential and extremely slow on large images. This version
+    should be statistically equivalent. The sketch of proof will be provided in
+    TorchUncertainty's paper.
+    """
+    interval_length = 2 * max_delta + 1
+    diagram_size = 12 * max_delta  # sufficient for a proper density estimation
+    tab = torch.zeros((diagram_size, diagram_size), dtype=torch.float32)
+    tab[0, max_delta] = 1
+    for pivot, t in enumerate(range(1, diagram_size)):
+        # the pivot gets 1/interval_length of all the accessible previous densities
+        for i in range(-max_delta, max_delta + 1):
+            if 0 <= pivot + i < diagram_size:
+                tab[t, pivot] += tab[t - 1, pivot + i]
+
+        # the other values keep (interval_length - 1)/interval_length of their previous densities
+        # and receive 1/interval_length of the pivot's previous density
+        for i in range(-max_delta, max_delta + 1):
+            if i != 0 and 0 <= pivot + i < diagram_size:
+                tab[t, pivot + i] += (interval_length - 1) * tab[t - 1, pivot + i] + tab[
+                    t - 1, pivot
+                ]
+        tab[t, :] /= interval_length
+    density = torch.diag(tab, -max_delta - 1)
+
+    # reduce the dimension of the distribution
+    idx = torch.clamp(density, 1e-4).argmin()
+    density = density[:idx]
+
+    padded_density = F.pad(density, (len(density) - 2 * max_delta - 1, 0))
+    sym_density = 1 / 2 * padded_density + 1 / 2 * padded_density.flip(-1)
+
+    # Convolve the density in lieu of iterating
+    sym_density = sym_density.unsqueeze(0).unsqueeze(0)
+    sym_density_iter = sym_density.clone()
+    for _ in range(iterations - 1):
+        sym_density_iter = F.conv1d(
+            sym_density_iter, torch.flip(sym_density, (-1,)), padding=sym_density.shape[-1] // 2
+        )
+    return Categorical(probs=sym_density_iter.squeeze(0, 1))
+
+
 class GlassBlur(TUCorruption):
     name = "glass_blur"
 
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a glass blur corruption to unbatched tensor images.
 
+        Faster implementation using a symmetrized offset distribution.
+
         Args:
             severity (int): Severity level of the corruption.
             seed (int | None): Optional seed for the rng.
 
+        Note:
+            The hyperparameters have been adapted to output images calibrated with the original
+            implementation despite the fixes that increase the power of the transformation.
+ """ + super().__init__(severity) + if not kornia_installed: + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" + ) + sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = int(sigma * 6 // 2 * 2 + 1) + iterations = [1, 2, 3, 2, 3][severity - 1] + max_delta = [1, 1, 1, 2, 3][severity - 1] + self.max_delta = max_delta + + self.offset_dist = generate_offset_distribution(max_delta, iterations) + + if seed is None: + self.rng = None + else: + self.rng = torch.Generator(device="cpu").manual_seed(seed) + + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + + img = gaussian_blur2d( + img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + ).squeeze(0) + + img = img.permute(1, 2, 0) # HWC + height, width, _ = img.shape + max_d = self.max_delta + + valid_h = height - max_d + valid_w = width - max_d + + # Generate random offsets + rand_offsets = ( + self.offset_dist.sample(sample_shape=(valid_h, valid_w, 2)) + - self.offset_dist.param_shape[0] // 2 + ) + + # Create base indices + hs = ( + torch.arange(max_d, height, device=img.device)[:valid_h].unsqueeze(1).repeat(1, valid_w) + ) + ws = torch.arange(max_d, width, device=img.device)[:valid_w].unsqueeze(0).repeat(valid_h, 1) + + dy = rand_offsets[..., 0] + dx = rand_offsets[..., 1] + hs_prime = (hs + dy).clamp(0, height - 1) + ws_prime = (ws + dx).clamp(0, width - 1) + + flat_idx = hs.flatten(), ws.flatten() + flat_idx_prime = hs_prime.flatten(), ws_prime.flatten() + + tmp = img[flat_idx].clone() + img[flat_idx] = img[flat_idx_prime] + img[flat_idx_prime] = tmp + + img = img.permute(2, 0, 1).unsqueeze(0) # Back to BCHW + img = gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=self.sigma).squeeze(0) + return torch.clamp(img, 0, 1) + + +class OriginalGlassBlur(TUCorruption): + name = "glass_blur" + + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a glass blur corruption to unbatched tensor images. + + Original, likely incorrect and very slow implementation. + Args: severity (int): Severity level of the corruption. seed (int | None): Optional seed for the rng. @@ -485,7 +612,7 @@ def forward(self, img: Tensor) -> Tensor: return torch.clamp(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) -def plasma_fractal(height, width, rng, wibbledecay=3): +def plasma_fractal(height, width, rng, wibbledecay): """Generate a heightmap using diamond-square algorithm. Return square 2d array, side length 'mapsize', of floats in range 0-1. 'mapsize' must be a power of two. 
@@ -852,6 +979,7 @@ def forward(self, img: Tensor) -> Tensor: ImpulseNoise, DefocusBlur, GlassBlur, + OriginalGlassBlur, MotionBlur, ZoomBlur, Snow, From e58c3bb052844a75f2dc54377681f5234b88ff9f Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 17 Mar 2025 11:38:02 +0100 Subject: [PATCH 22/47] :hammer: Replace ``.permute()`` and ``.repeat()`` with ``einops.rearrange`` ``einops.repeat`` --- torch_uncertainty/transforms/corruption.py | 30 ++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 4be0f403..45d71635 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -17,7 +17,7 @@ from io import BytesIO import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, repeat from torch.distributions import Categorical if util.find_spec("cv2"): @@ -341,7 +341,7 @@ def forward(self, img: Tensor) -> Tensor: img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma ).squeeze(0) - img = img.permute(1, 2, 0) # HWC + img = rearrange(img, "c h w -> h w c") height, width, _ = img.shape max_d = self.max_delta @@ -355,10 +355,8 @@ def forward(self, img: Tensor) -> Tensor: ) # Create base indices - hs = ( - torch.arange(max_d, height, device=img.device)[:valid_h].unsqueeze(1).repeat(1, valid_w) - ) - ws = torch.arange(max_d, width, device=img.device)[:valid_w].unsqueeze(0).repeat(valid_h, 1) + hs = repeat(torch.arange(max_d, height, device=img.device)[:valid_h], "h -> h w", w=valid_w) + ws = repeat(torch.arange(max_d, width, device=img.device)[:valid_w], "w -> h w", h=valid_h) dy = rand_offsets[..., 0] dx = rand_offsets[..., 1] @@ -372,7 +370,7 @@ def forward(self, img: Tensor) -> Tensor: img[flat_idx] = img[flat_idx_prime] img[flat_idx_prime] = tmp - img = img.permute(2, 0, 1).unsqueeze(0) # Back to BCHW + img = rearrange(img, "h w c -> 1 c h w") # Back to BCHW img = gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=self.sigma).squeeze(0) return torch.clamp(img, 0, 1) @@ -410,11 +408,11 @@ def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img img_size = img.shape - img = ( - gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma) - .squeeze(0) - .permute(1, 2, 0) + img = rearrange( + gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma), + "1 c h w -> h w c", ) + rands = torch.randint( -self.max_delta, self.max_delta, @@ -431,7 +429,7 @@ def forward(self, img: Tensor) -> Tensor: return torch.clamp( gaussian_blur2d( - img.permute(2, 0, 1).unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + rearrange(img, "h w c -> 1 c h w"), kernel_size=self.kernel_size, sigma=self.sigma ).squeeze(0), 0, 1, @@ -526,12 +524,12 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = img.permute(1, 2, 0).numpy() + img = rearrange(img, "c h w -> h w c").numpy() out = np.zeros_like(img) for zoom_factor in self.zooms: out += clipped_zoom(img, zoom_factor) img = (img + out) / (len(self.zooms) + 1) - return torch.clamp(torch.as_tensor(img).permute(2, 0, 1), 0, 1) + return torch.clamp(rearrange(torch.as_tensor(img), "h w c -> c h w"), 0, 1) class Snow(TUCorruption): @@ -811,7 +809,7 @@ def __init__(self, severity: int, seed: int | None = None) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - image = 
np.array(img.permute(1, 2, 0), dtype=np.float32) + image = np.array(rearrange(img, "c h w -> h w c"), dtype=np.float32) shape = image.shape shape_size = shape[:2] @@ -889,7 +887,7 @@ def forward(self, img: Tensor) -> Tensor: 0, 1, ) - return torch.as_tensor(img).permute(2, 0, 1) + return rearrange(torch.as_tensor(img), "h w c -> c h w") # Additional corruption transforms From 0e750192a612c5434d597779505d46ab6af242fb Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 11:56:36 +0100 Subject: [PATCH 23/47] :fire: Remove name & rename batch var --- torch_uncertainty/datasets/corrupted.py | 2 +- torch_uncertainty/transforms/corruption.py | 59 +++++----------------- 2 files changed, 14 insertions(+), 47 deletions(-) diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index f774eeb8..ae614516 100644 --- a/torch_uncertainty/datasets/corrupted.py +++ b/torch_uncertainty/datasets/corrupted.py @@ -54,7 +54,7 @@ def prepare_data(self): with logging_redirect_tqdm(): pbar = tqdm(corruption_transforms) for corruption in pbar: - corruption_name = corruption.name + corruption_name = corruption.__name__ pbar.set_description(f"Processing {corruption_name}") (self.root / corruption_name / f"{self.shift_severity}").mkdir( parents=True, exist_ok=True diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 4be0f403..f213f69b 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -17,7 +17,7 @@ from io import BytesIO import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, repeat from torch.distributions import Categorical if util.find_spec("cv2"): @@ -89,7 +89,6 @@ class TUCorruption(nn.Module): - name: str batched: bool = False def __init__(self, severity: int) -> None: @@ -107,8 +106,7 @@ def __repr__(self) -> str: class GaussianNoise(TUCorruption): - name = "gaussian_noise" - batched = True + batchable = True def __init__(self, severity: int) -> None: """Apply a Gaussian noise corruption to tensor images. @@ -131,8 +129,7 @@ def forward(self, img: Tensor) -> Tensor: class ShotNoise(TUCorruption): - name = "shot_noise" - batched = True + batchable = True def __init__(self, severity: int) -> None: """Apply a shot (Poisson) noise corruption to tensor images. @@ -155,8 +152,7 @@ def forward(self, img: Tensor): class ImpulseNoise(TUCorruption): - name = "impulse_noise" - batched = True + batchable = True def __init__(self, severity: int, black_white: bool = False) -> None: """Apply an impulse (channel-independent Salt & Pepper) noise corruption to unbatched @@ -222,8 +218,7 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=torch.float32): class DefocusBlur(TUCorruption): - name = "defocus_blur" - batched = True + batchable = True def __init__(self, severity: int) -> None: """Apply a defocus blur corruption to unbatched tensor images. @@ -298,8 +293,6 @@ def generate_offset_distribution(max_delta, iterations): class GlassBlur(TUCorruption): - name = "glass_blur" - def __init__(self, severity: int, seed: int | None = None) -> None: """Apply a glass blur corruption to unbatched tensor images. @@ -310,8 +303,10 @@ def __init__(self, severity: int, seed: int | None = None) -> None: seed (int | None): Optional seed for the rng. Note: - The hyperparameters have been adapted to output images calibrated with the original - implementation despite the fixes that increase the power of the transformation. 
+                The hyperparameters have been adapted to output images qualitatively calibrated with
+                the original implementation despite the changes in implementation that increase the
+                power of the transformation. This is notably due to discarding the correlation between
+                the offsets to simplify the derivation.
         """
         super().__init__(severity)
         if not kornia_installed:
@@ -355,10 +350,8 @@ def forward(self, img: Tensor) -> Tensor:
         )
 
         # Create base indices
-        hs = (
-            torch.arange(max_d, height, device=img.device)[:valid_h].unsqueeze(1).repeat(1, valid_w)
-        )
-        ws = torch.arange(max_d, width, device=img.device)[:valid_w].unsqueeze(0).repeat(valid_h, 1)
+        hs = repeat(torch.arange(max_d, height, device=img.device)[:valid_h], "h -> h w", w=valid_w)
+        ws = repeat(torch.arange(max_d, width, device=img.device)[:valid_w], "w -> h w", h=valid_h)
 
         dy = rand_offsets[..., 0]
         dx = rand_offsets[..., 1]
@@ -378,8 +371,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class OriginalGlassBlur(TUCorruption):
-    name = "glass_blur"
-
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a glass blur corruption to unbatched tensor images.
 
@@ -439,8 +430,7 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class MotionBlur(TUCorruption):
-    name = "motion_blur"
-    batched = True
+    batchable = True
 
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a motion blur corruption to unbatched tensor images.
 
@@ -500,8 +490,6 @@ def clipped_zoom(img, zoom_factor):
 
 
 class ZoomBlur(TUCorruption):
-    name = "zoom_blur"
-
     def __init__(self, severity: int) -> None:
         """Apply a zoom blur corruption to unbatched tensor images.
 
@@ -535,8 +523,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Snow(TUCorruption):
-    name = "snow"
-
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a snow effect on unbatched tensor images.
 
@@ -589,8 +575,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Frost(TUCorruption):
-    name = "frost"
-
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a frost corruption effect on unbatched tensor images.
 
@@ -663,8 +647,6 @@ def filldiamonds():
 
 
 class Fog(TUCorruption):
-    name = "fog"
-
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply a fog corruption effect on unbatched tensor images.
 
@@ -695,8 +677,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Brightness(IBrightness, TUCorruption):
-    name = "brightness"
-
     def __init__(self, severity: int) -> None:
         """Apply a brightness corruption to unbatched tensor images.
 
@@ -717,8 +697,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Contrast(IContrast, TUCorruption):
-    name = "contrast"
-
     def __init__(self, severity: int) -> None:
         """Apply a contrast corruption to unbatched tensor images.
 
@@ -735,8 +713,6 @@ def forward(self, img: Tensor) -> Tensor | Image.Image:
 
 
 class Pixelate(TUCorruption):
-    name = "pixelate"
-
     def __init__(self, severity: int) -> None:
         """Apply a pixelation corruption to unbatched tensor images.
 
@@ -761,8 +737,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class JPEGCompression(TUCorruption):
-    name = "jpeg_compression"
-
     def __init__(self, severity: int) -> None:
         """Apply a JPEG compression corruption to unbatched tensor images.
 
@@ -781,8 +755,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Elastic(TUCorruption):
-    name = "elastic"
-
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply an elastic corruption to unbatched tensor images.
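For the einops conversions in the two patches above, the correspondence with the replaced tensor methods is mechanical. A small sketch with made-up shapes:

    import torch
    from einops import rearrange, repeat

    img = torch.rand(3, 4, 5)                # CHW
    hwc = rearrange(img, "c h w -> h w c")   # same result as img.permute(1, 2, 0)
    back = rearrange(hwc, "h w c -> c h w")  # and back again

    row = torch.arange(4)
    grid = repeat(row, "h -> h w", w=5)      # same as row.unsqueeze(1).repeat(1, 5)
    print(hwc.shape, grid.shape)             # torch.Size([4, 5, 3]) torch.Size([4, 5])

The named axes make the intent explicit and avoid the positional bookkeeping of permute and repeat.
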
@@ -896,8 +868,7 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class SpeckleNoise(TUCorruption):
-    name = "speckle_noise"
-    batched = True
+    batchable = True
 
     def __init__(self, severity: int, seed: int | None = None) -> None:
         """Apply speckle noise to tensor images.
@@ -926,8 +897,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class GaussianBlur(TUCorruption):
-    name = "gaussian_blur"
-
     def __init__(self, severity: int) -> None:
         """Apply a Gaussian blur corruption to unbatched tensor images.
 
@@ -955,8 +924,6 @@ def forward(self, img: Tensor) -> Tensor:
 
 
 class Saturation(ISaturation, TUCorruption):
-    name = "saturation"
-
     def __init__(self, severity: int) -> None:
         """Apply a saturation corruption to unbatched tensor images.
 

From 442b3b735732124ff978dfc1458d6a648e22e48d Mon Sep 17 00:00:00 2001
From: Olivier
Date: Mon, 17 Mar 2025 12:17:35 +0100
Subject: [PATCH 24/47] :ok_hand: Comply with comments

---
 torch_uncertainty/transforms/corruption.py | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py
index b7718239..9e3374a7 100644
--- a/torch_uncertainty/transforms/corruption.py
+++ b/torch_uncertainty/transforms/corruption.py
@@ -818,14 +818,10 @@ def forward(self, img: Tensor) -> Tensor:
         dx = (
             (
                 gaussian_blur2d(
-                    torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2]))
-                    .unsqueeze(0)
-                    .unsqueeze(0),
+                    torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape[:2]))),
                     kernel_size=ks,
                     sigma=(sigma, sigma),
-                )
-                .squeeze(0)
-                .squeeze(0)
+                ).squeeze(0, 1)
                 * self.mix[0]
                 * shape_size[0]
             )
@@ -835,14 +831,10 @@ def forward(self, img: Tensor) -> Tensor:
         dy = (
             (
                 gaussian_blur2d(
-                    torch.as_tensor(self.rng.uniform(-1, 1, size=shape[:2]))
-                    .unsqueeze(0)
-                    .unsqueeze(0),
+                    torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape[:2]))),
                     kernel_size=ks,
                     sigma=(sigma, sigma),
-                )
-                .squeeze(0)
-                .squeeze(0)
+                ).squeeze(0, 1)
                 * self.mix[0]
                 * shape_size[0]
             )

From 55eb8aec127e51ad5791e3bba8c2053533aa8cf5 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Mon, 17 Mar 2025 13:24:53 +0100
Subject: [PATCH 25/47] :hammer: Rework the corrupted dataset

---
 torch_uncertainty/datasets/corrupted.py    | 90 ++++++++++++--------
 torch_uncertainty/transforms/corruption.py | 23 +++++-
 2 files changed, 77 insertions(+), 36 deletions(-)

diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py
index ae614516..4185514c 100644
--- a/torch_uncertainty/datasets/corrupted.py
+++ b/torch_uncertainty/datasets/corrupted.py
@@ -1,8 +1,10 @@
+import re
 from copy import deepcopy
 from pathlib import Path
 
 from torch import nn
 from torchvision.datasets import VisionDataset
+from torchvision.datasets.folder import default_loader
 from torchvision.transforms import ToPILImage, ToTensor
 from tqdm import tqdm, trange
 from tqdm.contrib.logging import logging_redirect_tqdm
@@ -15,61 +17,90 @@ def __init__(
         self,
         core_dataset: VisionDataset,
         shift_severity: int,
+        generate: bool = False,
         on_the_fly: bool = False,
     ) -> None:
-        """Generate the corrupted version of any VisionDataset."""
+        """Generate the corrupted version of any VisionDataset.
+
+        Args:
+            core_dataset (VisionDataset): dataset to be corrupted.
+            shift_severity (int): intensity of the corruption. Should be in [1, 5].
+            generate (bool): Equivalent of the ``download`` attribute of the datasets. If ``True``,
+                generate a new dataset with all the corrupted images. Defaults to ``False``.
+            on_the_fly (bool): Generate the corrupted version of the dataset on the fly, without
+                saving the images on disk. This is discouraged since the experiment won't be fully
+                reproducible.
+
+        Note:
+            The corrupted dataset will use the `transforms` of :attr:`core_dataset`.
+        """
         super().__init__()
         self.core_dataset = core_dataset
-        if shift_severity <= 0:
-            raise ValueError(f"Severity must be greater than 0. Got {shift_severity}.")
+        if shift_severity < 0:
+            raise ValueError(f"Severity must be non-negative. Got {shift_severity}.")
+        if not generate and on_the_fly:
+            raise ValueError("generate must be True if on_the_fly is True.")
+
         self.shift_severity = shift_severity
 
         self.core_length = len(core_dataset)
+        self.generate = generate
         self.on_the_fly = on_the_fly
+
         self.transforms = deepcopy(core_dataset.transforms)
-        self.target_transforms = deepcopy(core_dataset.target_transform)
         self.core_dataset.transform = None
+        self.core_dataset.transforms = None
         self.core_dataset.target_transform = None
 
-        dataset_name = str(type(core_dataset)).split(".")[-1][:-2].lower()
+        dataset_name = str(type(core_dataset)).split(".")[-1][:-2]
         self.root = Path(core_dataset.root) / (dataset_name + "-C")
-        self.root.mkdir(parents=True, exist_ok=True)
 
-        if not on_the_fly:
+        if hasattr(self.core_dataset, "targets"):
+            self.targets = self.core_dataset.targets
+        elif hasattr(self.core_dataset, "labels"):
+            self.targets = self.core_dataset.labels
+        elif hasattr(self.core_dataset, "_labels"):
+            self.targets = self.core_dataset._labels
+        else:
+            raise ValueError("The dataset should implement either targets, labels, or _labels.")
+
+        self.targets = self.targets * len(corruption_transforms)
+
+        if not generate:
+            paths = sorted(self.root.glob(f"**/{self.shift_severity}/*.jpg"), key=lambda x: x.stem)
+            self.samples = list(zip(paths, self.targets, strict=False))
+            if len(paths) != 15 * self.core_length:
+                raise ValueError(
+                    "The corrupted dataset is not complete. Download it from HuggingFace or set generate=True."
+ ) + + if generate and not on_the_fly: + self.root.mkdir(parents=True, exist_ok=True) self.to_tensor = ToTensor() self.to_pil = ToPILImage() self.samples = [] - if hasattr(self.core_dataset, "targets"): - self.targets = self.core_dataset.targets - elif hasattr(self.core_dataset, "labels"): - self.targets = self.core_dataset.labels - elif hasattr(self.core_dataset, "_labels"): - self.targets = self.core_dataset._labels - else: - raise ValueError("The dataset should implement either targets, labels, or _labels.") - - self.targets = self.targets * len(corruption_transforms) + self.prepare_data() def prepare_data(self): with logging_redirect_tqdm(): pbar = tqdm(corruption_transforms) for corruption in pbar: - corruption_name = corruption.__name__ + corruption_name = re.sub(r"([a-z])([A-Z])", r"\1_\2", corruption.__name__).lower() pbar.set_description(f"Processing {corruption_name}") (self.root / corruption_name / f"{self.shift_severity}").mkdir( parents=True, exist_ok=True ) - self.save_corruption( + self._save_corruption( self.root / corruption_name / f"{self.shift_severity}", corruption(self.shift_severity), ) - def save_corruption(self, root: Path, corruption: nn.Module) -> None: + def _save_corruption(self, root: Path, corruption: nn.Module) -> None: for i in trange(self.core_length, leave=False): img, tgt = self.core_dataset[i] img = corruption(self.to_tensor(img)) self.to_pil(img).save(root / f"{i}.jpg") - self.samples.append(root / f"{i}.jpg") + self.samples.append((root / f"{i}.jpg", tgt)) self.targets.append(tgt) def __len__(self): @@ -88,20 +119,15 @@ def __getitem__(self, idx: int): img, target = self.core_dataset[idx] img = corrupt(img) - if self.transform is not None: - img = self.transform(img) + if self.transforms is not None: + img, target = self.transforms(img, target) - if self.target_transform is not None: - target = self.target_transform(target) return img, target - img, target = self.core_dataset[idx] - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - + path, target = self.samples[idx] + img = default_loader(path) + if self.transforms is not None: + img, target = self.transforms(img, target) return img, target diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 9e3374a7..d9fa0de6 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -677,6 +677,8 @@ def forward(self, img: Tensor) -> Tensor: class Brightness(IBrightness, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: """Apply a brightness corruption to unbatched tensor images. @@ -697,6 +699,8 @@ def forward(self, img: Tensor) -> Tensor: class Contrast(IContrast, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: """Apply a contrast corruption to unbatched tensor images. @@ -889,6 +893,8 @@ def forward(self, img: Tensor) -> Tensor: class GaussianBlur(TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: """Apply a Gaussian blur corruption to unbatched tensor images. 
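To make the reworked generate/on_the_fly workflow concrete, here is a hypothetical usage sketch. The CIFAR-10 choice, root path, and severity value are illustrative and not taken from the patch, and the import path assumes the module location used in this series; as a later patch notes, the class is deliberately not re-exported from the package __init__ because of a circular import.

    from torchvision.datasets import CIFAR10

    from torch_uncertainty.datasets.corrupted import CorruptedDataset

    core = CIFAR10(root="./data", train=False, download=True)
    # First run: generate and cache the 15 corrupted copies on disk.
    corrupted = CorruptedDataset(core, shift_severity=3, generate=True)
    # Later runs: reload the cached images; completeness is checked at init.
    corrupted = CorruptedDataset(core, shift_severity=3)
    img, target = corrupted[0]
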
@@ -908,14 +914,23 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clamp( + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = torch.clamp( gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma), - min=0, - max=1, - ).squeeze(0) + 0, + 1, + ) + if no_batch: + out = out.squeeze(0) + return out class Saturation(ISaturation, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: """Apply a saturation corruption to unbatched tensor images. From a7d9cef7b25fffab5846b78f59347be6032fb33d Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:26:36 +0100 Subject: [PATCH 26/47] :white_check_mark: Improve coverage --- tests/transforms/test_corruption.py | 2 +- torch_uncertainty/transforms/corruption.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 7cdac5cd..80fbdcbc 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -91,7 +91,7 @@ def test_glass_blur(self): assert transform(inputs).ndim == 3 inputs = torch.rand(3, 32, 32) - transform = OriginalGlassBlur(1) + transform = OriginalGlassBlur(1, seed=1) assert transform(inputs).ndim == 3 transform = OriginalGlassBlur(0) assert transform(inputs).ndim == 3 diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index d9fa0de6..9acf57d4 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -323,11 +323,6 @@ def __init__(self, severity: int, seed: int | None = None) -> None: self.offset_dist = generate_offset_distribution(max_delta, iterations) - if seed is None: - self.rng = None - else: - self.rng = torch.Generator(device="cpu").manual_seed(seed) - def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img From 22ca13c00a52481822b44a13dc5910d34dc03f13 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:30:08 +0100 Subject: [PATCH 27/47] :book: Add some doc on the hyperparameters --- torch_uncertainty/transforms/corruption.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 9acf57d4..1593546d 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -11,6 +11,10 @@ to avoid reproducibility issues. It may be preferable to first check if the corrupted dataset is available on TorchUncertainty's Hugging Face https://huggingface.co/torch-uncertainty. File an issue if you would like one specific and missing dataset to be published on this page. + +In most of the cases, we have chosen to follow the hyperparameters used for ImageNet-C, which +differ from those of TinyImageNet-C, CIFAR-C or even the Inception version of ImageNet-C. However, +this may not be entirely suitable in the case of datasets with much smaller or bigger images. 
""" from importlib import util From 70b5c002db959a4349d71e6b8f67fb8d67a41181 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:31:18 +0100 Subject: [PATCH 28/47] :bug: Fixinvoluntary double unsqueeze --- torch_uncertainty/transforms/corruption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 1593546d..2e8a5b17 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -918,7 +918,7 @@ def forward(self, img: Tensor) -> Tensor: no_batch = True img = img.unsqueeze(0) out = torch.clamp( - gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma), + gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=self.sigma), 0, 1, ) From 2fad90291bdceaaa8df423c979162f0a39196b6c Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:36:44 +0100 Subject: [PATCH 29/47] :white_check_mark: Finish coverage --- tests/transforms/test_corruption.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 80fbdcbc..512f6801 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -83,6 +83,10 @@ def test_gaussian_blur(self): transform = GaussianBlur(0) assert transform(inputs).ndim == 3 + inputs = torch.rand(3, 3, 32, 32) + transform = MotionBlur(1) + assert transform(inputs).ndim == 4 + def test_glass_blur(self): inputs = torch.rand(3, 32, 32) transform = GlassBlur(1) From 1ea1478f1583381a5a80cbf49fcd4fc0c1cf904e Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:40:51 +0100 Subject: [PATCH 30/47] :white_check_mark: Finish coverage --- tests/transforms/test_corruption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 512f6801..bf1184a5 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -84,7 +84,7 @@ def test_gaussian_blur(self): assert transform(inputs).ndim == 3 inputs = torch.rand(3, 3, 32, 32) - transform = MotionBlur(1) + transform = GaussianBlur(1) assert transform(inputs).ndim == 4 def test_glass_blur(self): From d349855b6dfe61ce6596288516552ade1ff174df Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:46:43 +0100 Subject: [PATCH 31/47] :fire: Remove useless flaky test --- tests/datamodules/segmentation/test_muad.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py index 97d4f6d0..862206f0 100644 --- a/tests/datamodules/segmentation/test_muad.py +++ b/tests/datamodules/segmentation/test_muad.py @@ -35,7 +35,3 @@ def test_camvid_main(self): dm.setup() dm.train_dataloader() dm.val_dataloader() - - def test_small_muad_accessibility(self): - dataset = MUAD(root="./data/", split="test", version="small", download=True) - assert len(dataset.samples) > 0, "Dataset is not found" From c7d022e11d66a6a85dc6fe13905cb31e41df9444 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:54:38 +0100 Subject: [PATCH 32/47] :fire: Remove OriginalGlassBlur from corruption transforms --- torch_uncertainty/transforms/corruption.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 2e8a5b17..3399da6c 100644 --- 
a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -952,7 +952,6 @@ def forward(self, img: Tensor) -> Tensor: ImpulseNoise, DefocusBlur, GlassBlur, - OriginalGlassBlur, MotionBlur, ZoomBlur, Snow, From 602475123986badfb8b1a6a11d726629060fe0b5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 13:57:33 +0100 Subject: [PATCH 33/47] :hammer: Add corrupted datasets to init --- torch_uncertainty/datasets/__init__.py | 1 + torch_uncertainty/datasets/muad.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 5acc7735..75aa9852 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset +from .corrupted import CorruptedDataset from .fractals import Fractals from .frost import FrostImages from .kitti import KITTIDepth diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index cc93c3d9..162be185 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -244,7 +244,7 @@ def _make_dataset(self, path: Path) -> None: f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}." ) - def _download(self, split: str) -> None: + def _download(self, split: str) -> None: # coverage: ignore """Download and extract the chosen split of the dataset.""" if self.version == "small": filename = f"{split}.zip" From 2e6061ad52b8623127e0882b75372059fb3f0347 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 14:00:03 +0100 Subject: [PATCH 34/47] :wrench: Use one flag for codecov --- .github/workflows/run-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 09095375..594a30cb 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -85,7 +85,7 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml - flags: cpu,pytest + flags: pytest name: CPU-coverage env_vars: PYTHON_VERSION @@ -96,7 +96,7 @@ jobs: continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} - flags: cpu,pytest + flags: pytest env_vars: PYTHON_VERSION - name: Test sphinx build without tutorials From 46bdab85afba91a753d649ec24d3e75e4d00668e Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 14:04:00 +0100 Subject: [PATCH 35/47] :fire: Remove corrupted dataset from init due to circular import --- torch_uncertainty/datasets/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 75aa9852..5acc7735 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,6 +1,5 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset -from .corrupted import CorruptedDataset from .fractals import Fractals from .frost import FrostImages from .kitti import KITTIDepth From 77f19f7e147667fa430d569e02b01ded27aa77c3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 14:07:45 +0100 Subject: [PATCH 36/47] :shirt: Various improvements --- torch_uncertainty/datasets/corrupted.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index 4185514c..8c8a4197 100644 --- 
a/torch_uncertainty/datasets/corrupted.py
+++ b/torch_uncertainty/datasets/corrupted.py
@@ -29,14 +29,14 @@ def __init__(
             generate a new dataset with all the corrupted images. Defaults to ``False``.
             on_the_fly (bool): Generate the corrupted version of the dataset on the fly, without
                 saving the images on disk. This is discouraged since the experiment won't be fully
-                reproducible.
+                reproducible. Defaults to ``False``.

         Note:
             The corrupted dataset will use `transforms` of :attr:`core_dataset`.
         """
         super().__init__()
         self.core_dataset = core_dataset
-        if shift_severity < 0:
+        if shift_severity <= 0:
             raise ValueError(f"Severity must be strictly greater than 0. Got {shift_severity}.")
         if not generate and on_the_fly:
             raise ValueError("generate must be True if on_the_fly is True.")
@@ -79,9 +79,10 @@ def __init__(
         self.to_pil = ToPILImage()
         self.samples = []

-        self.prepare_data()
+        self._generate_data()

-    def prepare_data(self):
+    def _generate_data(self):
+        """Generate the corrupted data."""
         with logging_redirect_tqdm():
             pbar = tqdm(corruption_transforms)
             for corruption in pbar:
@@ -96,6 +97,12 @@ def __init__(
         )

     def _save_corruption(self, root: Path, corruption: nn.Module) -> None:
+        """Apply the given corruption to every image and save the results to disk.
+
+        Args:
+            root (Path): The directory where the corrupted images are saved.
+            corruption (nn.Module): The corruption module to apply to the images.
+        """
         for i in trange(self.core_length, leave=False):
             img, tgt = self.core_dataset[i]
             img = corruption(self.to_tensor(img))
@@ -104,7 +111,7 @@ def _save_corruption(self, root: Path, corruption: nn.Module) -> None:
             self.targets.append(tgt)

     def __len__(self):
-        """The length of the corrupted dataset."""
+        """Get the length of the corrupted dataset."""
         return len(self.core_dataset) * len(corruption_transforms)

     def __getitem__(self, idx: int):
@@ -129,10 +136,3 @@ def __getitem__(self, idx: int):
         if self.transforms is not None:
             img, target = self.transforms(img, target)
         return img, target
-
-
-if __name__ == "__main__":
-    from torchvision.datasets import OxfordIIITPet
-
-    dataset = OxfordIIITPet(root="data", split="test", download=True)
-    corrupted_dataset = CorruptedDataset(dataset, shift_severity=5)

From 3a3941ab8a551a81697ecc993cfcaad3fc943758 Mon Sep 17 00:00:00 2001
From: alafage
Date: Mon, 17 Mar 2025 16:21:49 +0100
Subject: [PATCH 37/47] :hammer: Post-processing `fit()` method takes a `DataLoader` instead of a `Dataset`

---
 tests/post_processing/test_laplace.py         |  6 ++---
 tests/post_processing/test_mc_batch_norm.py   | 16 ++++++------
 tests/post_processing/test_scalers.py         |  6 +++--
 torch_uncertainty/post_processing/abnn.py     | 12 +++------
 torch_uncertainty/post_processing/abstract.py |  4 +--
 .../post_processing/calibration/scaler.py     | 21 ++++++----------
 torch_uncertainty/post_processing/laplace.py  | 11 +++-----
 .../post_processing/mc_batch_norm.py          | 25 ++++++++++---------
 torch_uncertainty/routines/classification.py  |  8 +++---
 9 files changed, 47 insertions(+), 62 deletions(-)

diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py
index 8b6249ea..f2fdda7b 100644
--- a/tests/post_processing/test_laplace.py
+++ b/tests/post_processing/test_laplace.py
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from torch.utils.data import TensorDataset
+from torch.utils.data import DataLoader, TensorDataset

 from tests._dummies.model import dummy_model
 from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing
@@ -20,12 +20,12 @@ class TestLaplace:
     """Testing the LaplaceApprox 
class.""" def test_training(self): - ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10)) + dl = DataLoader(TensorDataset(torch.randn(16, 1), torch.randn(16, 10)), batch_size=5) la = LaplaceApprox( task="classification", model=dummy_model(1, 10), ) - la.fit(ds) + la.fit(dl) la(torch.randn(1, 1)) la = LaplaceApprox(task="classification") la.set_model(dummy_model(1, 10)) diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index bbe987ca..b2277d94 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -4,6 +4,7 @@ import torch import torchvision.transforms as T from torch import nn +from torch.utils.data import DataLoader from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d @@ -17,14 +18,13 @@ class TestMCBatchNorm: def test_main(self): """Test initialization.""" mc_model = lenet(1, 1, norm=partial(MCBatchNorm2d, num_estimators=2)) - stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False, mc_batch_size=1) + stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False) model = lenet(1, 1, norm=nn.BatchNorm2d) stoch_model = MCBatchNorm( nn.Sequential(model), num_estimators=2, convert=True, - mc_batch_size=1, ) dataset = DummyClassificationDataset( "./", @@ -34,13 +34,14 @@ def test_main(self): num_images=2, transform=T.ToTensor(), ) - stoch_model.fit(dataset=dataset) + dl = DataLoader(dataset, batch_size=1, shuffle=True) + stoch_model.fit(dataloader=dl) stoch_model.train() stoch_model(torch.randn(1, 1, 20, 20)) stoch_model.eval() stoch_model(torch.randn(1, 1, 20, 20)) - stoch_model = MCBatchNorm(num_estimators=2, convert=False, mc_batch_size=1) + stoch_model = MCBatchNorm(num_estimators=2, convert=False) stoch_model.set_model(mc_model) def test_errors(self): @@ -48,14 +49,12 @@ def test_errors(self): model = nn.Identity() with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=0, convert=True) - with pytest.raises(ValueError, match="mc_batch_size must be a positive integer"): - MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=False) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=True) model = lenet(1, 1, norm=nn.BatchNorm2d) - stoch_model = MCBatchNorm(model, num_estimators=4, convert=True, mc_batch_size=1) + stoch_model = MCBatchNorm(model, num_estimators=4, convert=True) dataset = DummyClassificationDataset( "./", num_channels=1, @@ -64,9 +63,10 @@ def test_errors(self): num_images=2, transform=T.ToTensor(), ) + dl = DataLoader(dataset, batch_size=2, shuffle=True) stoch_model.eval() with pytest.raises(RuntimeError): stoch_model(torch.randn(1, 1, 20, 20)) with pytest.raises(ValueError): - stoch_model.fit(dataset=dataset) + stoch_model.fit(dataloader=dl) diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py index fabe77f3..b499efc5 100644 --- a/tests/post_processing/test_scalers.py +++ b/tests/post_processing/test_scalers.py @@ -1,6 +1,7 @@ import pytest import torch from torch import nn, softmax +from torch.utils.data import DataLoader from torch_uncertainty.post_processing import ( MatrixScaler, @@ -26,10 +27,11 @@ def test_fit_biased(self): labels = torch.as_tensor([0.5, 0.5]).repeat(10, 1) calibration_set = list(zip(inputs, labels, strict=True)) + dl = DataLoader(calibration_set, batch_size=10) scaler = 
TemperatureScaler(model=nn.Identity(), init_val=2, lr=1, max_iter=10) assert scaler.temperature[0] == 2.0 - scaler.fit(calibration_set) + scaler.fit(dl) assert scaler.temperature[0] > 10 # best is +inf assert ( torch.sum( @@ -39,7 +41,7 @@ def test_fit_biased(self): ** 2 < 0.001 ) - scaler.fit_predict(calibration_set, progress=False) + scaler.fit_predict(dl, progress=False) def test_errors(self): with pytest.raises(ValueError): diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 0ec24375..b79e7f36 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d from torch_uncertainty.models import deep_ensembles @@ -25,7 +25,6 @@ def __init__( device: torch.device | str, max_epochs: int = 5, use_original_model: bool = True, - batch_size: int = 128, precision: str = "32", model: nn.Module | None = None, ): @@ -45,8 +44,6 @@ def __init__( to 5. use_original_model (bool, optional): Use original model during evaluation. Defaults to True. - batch_size (int, optional): Batch size for the training of ABNN. - Defaults to 128. precision (str, optional): Machine precision for training & eval. Defaults to "32". model (nn.Module | None, optional): Model to use. Defaults to None. @@ -63,7 +60,6 @@ def __init__( num_models=num_models, num_samples=num_samples, base_lr=base_lr, - batch_size=batch_size, ) self.num_classes = num_classes self.alpha = alpha @@ -74,7 +70,6 @@ def __init__( self.use_original_model = use_original_model self.max_epochs = max_epochs - self.batch_size = batch_size self.precision = precision self.device = device @@ -88,10 +83,9 @@ def __init__( weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1 self.weights.append(weight) - def fit(self, dataset: Dataset) -> None: + def fit(self, dataloader: DataLoader) -> None: if self.model is None: raise ValueError("Model must be set before fitting.") - dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) source_model = copy.deepcopy(self.model) _replace_bn_layers(source_model, self.alpha) @@ -119,7 +113,7 @@ def fit(self, dataset: Dataset) -> None: logger=None, enable_model_summary=False, ) - trainer.fit(model=baseline, train_dataloaders=dl) + trainer.fit(model=baseline, train_dataloaders=dataloader) final_models = ( [copy.deepcopy(source_model) for _ in range(self.num_samples)] diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index 5b4ae9a6..7afe050c 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader class PostProcessing(ABC, nn.Module): @@ -14,7 +14,7 @@ def set_model(self, model: nn.Module) -> None: self.model = model @abstractmethod - def fit(self, dataset: Dataset) -> None: + def fit(self, dataloader: DataLoader) -> None: pass @abstractmethod diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index ed5d8e36..54c5c703 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -3,7 
+3,7 @@ import torch from torch import Tensor, nn, optim -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from tqdm import tqdm from torch_uncertainty.post_processing import PostProcessing @@ -47,20 +47,14 @@ def __init__( def fit( self, - calibration_set: Dataset, - batch_size: int = 32, - shuffle: bool = False, - drop_last: bool = False, + dataloader: DataLoader, save_logits: bool = False, progress: bool = True, ) -> None: """Fit the temperature parameters to the calibration data. Args: - calibration_set (Dataset): Calibration dataset. - batch_size (int, optional): Batch size for the calibration dataset. Defaults to 32. - shuffle (bool, optional): Whether to shuffle the calibration dataset. Defaults to False. - drop_last (bool, optional): Whether to drop the last batch if it's smaller than batch_size. Defaults to False. + dataloader (DataLoader): Dataloader with the calibration data. save_logits (bool, optional): Whether to save the logits and labels. Defaults to False. progress (bool, optional): Whether to show a progress bar. @@ -73,9 +67,7 @@ def fit( all_logits = [] all_labels = [] - calibration_dl = DataLoader( - calibration_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) + calibration_dl = dataloader with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): logits = self.model(inputs.to(self.device)) @@ -119,10 +111,11 @@ def _scale(self, logits: Tensor) -> Tensor: def fit_predict( self, - calibration_set: Dataset, + # calibration_set: Dataset, + dataloader: DataLoader, progress: bool = True, ) -> Tensor: - self.fit(calibration_set, save_logits=True, progress=progress) + self.fit(dataloader, save_logits=True, progress=progress) return self(self.logits) @property diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 7a918bd5..eda08d87 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -2,7 +2,7 @@ from typing import Literal from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from .abstract import PostProcessing @@ -23,7 +23,6 @@ def __init__( hessian_struct="kron", pred_type: Literal["glm", "nn"] = "glm", link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit", - batch_size: int = 256, optimize_prior_precision: bool = True, ) -> None: """Laplace approximation for uncertainty estimation. @@ -42,8 +41,6 @@ def __init__( link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): how to approximate the classification link function for the `'glm'`. See the Laplace library for more details. Defaults to "probit". - batch_size (int, optional): batch size for the Laplace approximation. - Defaults to 256. optimize_prior_precision (bool, optional): whether to optimize the prior precision. Defaults to True. 
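To make the new contract concrete, here is a hedged sketch of fitting the Laplace wrapper with a ready-made DataLoader, mirroring the shapes used in `test_laplace.py` (a plain `nn.Linear` stands in for the test suite's `dummy_model`, and the optional Laplace dependency must be installed):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from torch_uncertainty.post_processing import LaplaceApprox

    model = nn.Linear(1, 10)  # stand-in classifier
    dl = DataLoader(TensorDataset(torch.randn(16, 1), torch.randn(16, 10)), batch_size=5)

    la = LaplaceApprox(task="classification", model=model)
    la.fit(dl)  # fit() now consumes the DataLoader directly
    preds = la(torch.randn(1, 1))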
@@ -63,7 +60,6 @@ def __init__(
         self.task = task
         self.weight_subset = weight_subset
         self.hessian_struct = hessian_struct
-        self.batch_size = batch_size
         self.optimize_prior_precision = optimize_prior_precision

         if model is not None:
@@ -78,9 +74,8 @@ def set_model(self, model: nn.Module) -> None:
             hessian_structure=self.hessian_struct,
         )

-    def fit(self, dataset: Dataset) -> None:
-        dl = DataLoader(dataset, batch_size=self.batch_size)
-        self.la.fit(train_loader=dl)
+    def fit(self, dataloader: DataLoader) -> None:
+        self.la.fit(train_loader=dataloader)

         if self.optimize_prior_precision:
             self.la.optimize_prior_precision(method="marglik")

diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index 33dbf35d..d5d467b6 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -3,7 +3,7 @@

 import torch
 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader

 from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d
 from torch_uncertainty.post_processing import PostProcessing
@@ -19,7 +19,6 @@ def __init__(
         model: nn.Module | None = None,
         num_estimators: int = 16,
         convert: bool = True,
-        mc_batch_size: int = 32,
         device: Literal["cpu", "cuda"] | torch.device | None = None,
     ) -> None:
         """Monte Carlo Batch Normalization wrapper.

@@ -28,7 +27,6 @@
             model (nn.Module): model to be converted.
             num_estimators (int): number of estimators.
             convert (bool): whether to convert the model.
-            mc_batch_size (int, optional): Monte Carlo batch size. Defaults to 32.
             device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
                 Defaults to None.

@@ -40,7 +38,6 @@
             batch normalized deep networks. In ICML 2018.
         """
         super().__init__()
-        self.mc_batch_size = mc_batch_size
         self.convert = convert
         self.num_estimators = num_estimators
         self.device = device
@@ -49,7 +46,7 @@ def __init__(
             self._setup_model(model)

     def _setup_model(self, model):
-        _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert)
+        _mcbn_checks(model, self.num_estimators, self.convert)
         self.model = deepcopy(model)  # TODO: Is it necessary?
         self.model = self.model.eval()
         if self.convert:
@@ -61,22 +58,28 @@ def set_model(self, model: nn.Module) -> None:
         self.model = model
         self._setup_model(model)

-    def fit(self, dataset: Dataset) -> None:
+    def fit(self, dataloader: DataLoader) -> None:
         """Fit the model on the dataset.

         Args:
-            dataset (Dataset): dataset to be used for fitting.
+            dataloader (DataLoader): DataLoader with the training dataset.

         Note:
             This method is used to populate the MC BatchNorm layers.
             Use the training dataset.
+
+        Warning:
+            The ``batch_size`` of the DataLoader should be carefully chosen as it
+            will have an impact on the statistics of the MC BatchNorm layers.
+
+        Raises:
+            ValueError: If there are fewer batches than the number of estimators.
""" - self.dl = DataLoader(dataset, batch_size=self.mc_batch_size, shuffle=True) self.counter = 0 self.reset_counters() self.set_accumulate(True) self.eval() - for x, _ in self.dl: + for x, _ in dataloader: self.model(x.to(self.device)) self.raise_counters() if self.counter == self.num_estimators: @@ -162,10 +165,8 @@ def has_mcbn(model: nn.Module) -> bool: return any(isinstance(module, MCBatchNorm2d) for module in model.modules()) -def _mcbn_checks(model, num_estimators, mc_batch_size, convert): +def _mcbn_checks(model, num_estimators, convert): if num_estimators < 1 or not isinstance(num_estimators, int): raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.") - if mc_batch_size < 1 or not isinstance(mc_batch_size, int): - raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.") if not convert and not has_mcbn(model): raise ValueError("model does not contain any MCBatchNorm2d nor is not to be converted.") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index e5469b5c..0d0faeb9 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -381,13 +381,13 @@ def on_test_start(self) -> None: the storage lists for logit plotting and update the batchnorms if needed. """ if self.post_processing is not None: - calibration_dataset = ( - self.trainer.datamodule.val_dataloader().dataset + calibration_dataloader = ( + self.trainer.datamodule.val_dataloader() if self.calibration_set == "val" - else self.trainer.datamodule.test_dataloader()[0].dataset + else self.trainer.datamodule.test_dataloader()[0] ) with torch.inference_mode(False): - self.post_processing.fit(calibration_dataset) + self.post_processing.fit(calibration_dataloader) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] From 93af2c910b2637a7806eed9b61b43641dae93c39 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 17 Mar 2025 16:55:27 +0100 Subject: [PATCH 38/47] :books: Update tutorials --- auto_tutorials_source/tutorial_mc_batch_norm.py | 13 ++++++++----- auto_tutorials_source/tutorial_scaler.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index bb726902..4ec27105 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -26,6 +26,7 @@ from pathlib import Path from torch import nn +from torch.utils.data import DataLoader from torch_uncertainty import TUTrainer from torch_uncertainty.datamodules import MNISTDataModule @@ -84,15 +85,17 @@ # We can now wrap the model in a MCBatchNorm to add stochasticity to the # predictions. We specify that the BatchNorm layers are to be converted to # MCBatchNorm layers, and that we want to use 8 stochastic estimators. -# The amount of stochasticity is controlled by the ``mc_batch_size`` argument. -# The larger the ``mc_batch_size``, the more stochastic the predictions will be. -# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here +# The amount of stochasticity is controlled by the ``batch_size`` parameter. +# of the DataLoader used to train the model. +# The larger the ``batch_size``, the more stochastic the predictions will be. +# The authors suggest 32 as a good value for ``batch_size`` but we use 16 here # to highlight the effect of stochasticity on the predictions. 
routine.model = MCBatchNorm( - routine.model, num_estimators=8, convert=True, mc_batch_size=16 + routine.model, num_estimators=8, convert=True ) -routine.model.fit(datamodule.train) +mc_batch_norm_dl = DataLoader(datamodule.train, batch_size=16, shuffle=True) +routine.model.fit(dataloader=mc_batch_norm_dl) routine = routine.eval() # To avoid prints # %% diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index ceaaa036..a04da64b 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -84,6 +84,7 @@ dataset, [1000, 1000, len(dataset) - 2000] ) test_dataloader = DataLoader(test_dataset, batch_size=32) +calibration_dataloader = DataLoader(cal_dataset, batch_size=32) # Initialize the ECE ece = CalibrationError(task="multiclass", num_classes=100) @@ -114,7 +115,7 @@ # Fit the scaler on the calibration dataset scaled_model = TemperatureScaler(model=model) -scaled_model.fit(calibration_set=cal_dataset) +scaled_model.fit(dataloader=calibration_dataloader) # %% # 6. Iterating Again to Compute the Improved ECE From e93a837fd612bc6dc0d70cddcf5d0aa094c46cba Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 19:05:53 +0100 Subject: [PATCH 39/47] :zap: Bump version --- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2c9d46a0..fedd153e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.4.2" +release = "0.4.3" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 77dd1ee1..fd6ab4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.4.2" +version = "0.4.3" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, From 0c7b0fb893504ddc3c5794bd9364ebdc1609c9fb Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 19:14:38 +0100 Subject: [PATCH 40/47] :bug: Fix rare Fog corruption bug --- torch_uncertainty/transforms/corruption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 3399da6c..1759f38a 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -662,7 +662,7 @@ def forward(self, img: Tensor) -> Tensor: return img _, height, width = img.shape max_val = img.max() - random_height_map_size = int(2 ** (m.ceil(m.log2(max(height, width) - 1)))) + random_height_map_size = int(2 ** (m.ceil(m.log2(max(height, width))))) fog = ( self.mix[0] * plasma_fractal( From 3b5ab2271e1725cf22f23594f3db46175d64f1a5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 23:20:53 +0100 Subject: [PATCH 41/47] :bug: Don't overwrite CUB's root --- torch_uncertainty/datasets/classification/cub.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 38bacf70..7f97c932 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ 
b/torch_uncertainty/datasets/classification/cub.py @@ -55,6 +55,7 @@ def __init__( ) super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) + self.root = Path(root) training_idx = self._load_train_idx() self.attributes, self.uncertainties = self._load_attributes() From 8728459a32f0a1ccd0cbde5e117c4d6ba695783b Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 17 Mar 2025 23:47:43 +0100 Subject: [PATCH 42/47] :bug: Improve Elastic Transform --- torch_uncertainty/transforms/corruption.py | 55 ++++++++++++---------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 1759f38a..55ee3dfe 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -787,12 +787,13 @@ def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img image = np.array(rearrange(img, "c h w -> h w c"), dtype=np.float32) - shape = image.shape - shape_size = shape[:2] + height, width, channels = image.shape + shape_size = height, width + min_shape_size = min(shape_size) # random affine center_square = np.float32(shape_size) // 2 - square_size = min(shape_size) // 3 + square_size = min_shape_size // 3 pts1 = np.float32( [ center_square + square_size, @@ -804,8 +805,8 @@ def forward(self, img: Tensor) -> Tensor: ] ) pts2 = pts1 + self.rng.uniform( - -self.mix[2] * shape_size[0], - self.mix[2] * shape_size[0], + -self.mix[2] * min_shape_size, + self.mix[2] * min_shape_size, size=pts1.shape, ).astype(np.float32) affine_transform = cv2.getAffineTransform(pts1, pts2) @@ -816,17 +817,17 @@ def forward(self, img: Tensor) -> Tensor: borderMode=cv2.BORDER_REFLECT_101, ) - sigma = self.mix[1] * shape_size[0] - ks = min(int((sigma * 3 // 2) * 2 + 1), min(shape_size[:2]) // 2 * 2 - 1) + sigma = self.mix[1] * min_shape_size + ks = min(int((sigma * 3 // 2) * 2 + 1), min_shape_size // 2 * 2 - 1) dx = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape[:2]))), + torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape_size))), kernel_size=ks, sigma=(sigma, sigma), ).squeeze(0, 1) * self.mix[0] - * shape_size[0] + * shape_size[1] ) .numpy() .astype(np.float32)[..., np.newaxis] @@ -834,7 +835,7 @@ def forward(self, img: Tensor) -> Tensor: dy = ( ( gaussian_blur2d( - torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape[:2]))), + torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape_size))), kernel_size=ks, sigma=(sigma, sigma), ).squeeze(0, 1) @@ -845,14 +846,16 @@ def forward(self, img: Tensor) -> Tensor: .astype(np.float32)[..., np.newaxis] ) - x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + x, y, z = np.meshgrid(np.arange(width), np.arange(height), np.arange(channels)) indices = ( np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)), ) img = np.clip( - map_coordinates(image, indices, order=1, mode="reflect").reshape(shape), + map_coordinates(image, indices, order=1, mode="reflect").reshape( + (height, width, channels) + ), 0, 1, ) @@ -947,19 +950,19 @@ def forward(self, img: Tensor) -> Tensor: corruption_transforms = ( - GaussianNoise, - ShotNoise, - ImpulseNoise, - DefocusBlur, - GlassBlur, - MotionBlur, - ZoomBlur, - Snow, - Frost, - Fog, - Brightness, - Contrast, + # GaussianNoise, + # ShotNoise, + # ImpulseNoise, + # DefocusBlur, + # GlassBlur, + # MotionBlur, + # ZoomBlur, + # Snow, + # Frost, + # Fog, + # Brightness, + # Contrast, 
Elastic,
-    Pixelate,
-    JPEGCompression,
+    # Pixelate,
+    # JPEGCompression,
 )

From 96dd89533fc31e28961ecf63dc687c22364731d1 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Tue, 18 Mar 2025 11:11:21 +0100
Subject: [PATCH 43/47] :bug: Remove comments

---
 torch_uncertainty/transforms/corruption.py | 28 +++++++++++-----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py
index 55ee3dfe..ce1273b5 100644
--- a/torch_uncertainty/transforms/corruption.py
+++ b/torch_uncertainty/transforms/corruption.py
@@ -950,19 +950,19 @@ def forward(self, img: Tensor) -> Tensor:

 corruption_transforms = (
-    # GaussianNoise,
-    # ShotNoise,
-    # ImpulseNoise,
-    # DefocusBlur,
-    # GlassBlur,
-    # MotionBlur,
-    # ZoomBlur,
-    # Snow,
-    # Frost,
-    # Fog,
-    # Brightness,
-    # Contrast,
+    GaussianNoise,
+    ShotNoise,
+    ImpulseNoise,
+    DefocusBlur,
+    GlassBlur,
+    MotionBlur,
+    ZoomBlur,
+    Snow,
+    Frost,
+    Fog,
+    Brightness,
+    Contrast,
     Elastic,
-    # Pixelate,
-    # JPEGCompression,
+    Pixelate,
+    JPEGCompression,
 )

From 5bb56c00f19947d1e3006c96348d2d217eaf7983 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Tue, 18 Mar 2025 11:39:35 +0100
Subject: [PATCH 44/47] :hammer: Rework post-processing set

---
 auto_tutorials_source/tutorial_scaler.py      |  4 +--
 .../classification/test_cifar10.py            |  2 +-
 tests/datamodules/test_abstract_datamodule.py | 21 ++++++++++--
 tests/routines/test_classification.py         |  4 +--
 tests/routines/test_segmentation.py           |  4 +--
 .../classification/deep_ensembles.py          |  2 --
 .../baselines/classification/resnet.py        | 10 ++----
 .../baselines/classification/vgg.py           |  4 ---
 .../baselines/classification/wideresnet.py    |  4 ---
 .../baselines/segmentation/deeplab.py         |  4 +--
 torch_uncertainty/datamodules/abstract.py     | 22 ++++++++++++-
 .../datamodules/classification/cifar10.py     |  4 +++
 .../datamodules/classification/cifar100.py    |  4 +++
 .../datamodules/classification/imagenet.py    |  4 +++
 .../datamodules/classification/mnist.py       |  4 +++
 .../classification/tiny_imagenet.py           |  2 ++
 .../post_processing/calibration/scaler.py     |  3 +-
 torch_uncertainty/routines/classification.py  | 33 +++++++------------
 torch_uncertainty/routines/segmentation.py    | 22 ++++++------
 19 files changed, 94 insertions(+), 63 deletions(-)

diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py
index a04da64b..91850a8d 100644
--- a/auto_tutorials_source/tutorial_scaler.py
+++ b/auto_tutorials_source/tutorial_scaler.py
@@ -6,8 +6,8 @@
 of the top-label predictions and the reliability of the underlying neural network.

 This tutorial provides extensive details on how to use the TemperatureScaler
-class, however, this is done automatically in the classification routine when setting
-the `calibration_set` to val or test.
+class; however, this is done automatically in the datamodule when setting
+the `postprocess_set` to val or test.

 Through this tutorial, we also see how to use the datamodules outside any
 Lightning trainers, and how to use TorchUncertainty's models.
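For context, the DataLoader-based scaler workflow that the reworked tutorial relies on looks roughly like this, with synthetic logits borrowed from `test_scalers.py` (an `nn.Identity` model means the inputs are already logits):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader

    from torch_uncertainty.post_processing import TemperatureScaler

    inputs = torch.as_tensor([0.6, 0.4]).repeat(10, 1)  # pre-computed logits
    labels = torch.as_tensor([0.5, 0.5]).repeat(10, 1)  # soft targets
    dl = DataLoader(list(zip(inputs, labels, strict=True)), batch_size=10)

    scaler = TemperatureScaler(model=nn.Identity(), lr=1, max_iter=10)
    scaler.fit(dl)  # fit() takes the calibration DataLoader directly
    calibrated = scaler(inputs)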
diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 51e6a0d6..d583219a 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -10,7 +10,7 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): - dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) + dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16, postprocess_set="test") assert dm.dataset == CIFAR10 assert isinstance(dm.train_transform.transforms[2], Cutout) diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 84a87293..1983a028 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -47,12 +47,29 @@ def test_cv_main(self): def test_errors(self): TUDataModule.__abstractmethods__ = set() - dm = TUDataModule("root", 128, 0.0, 4, True, True) + dm = TUDataModule( + root="root", + batch_size=128, + val_split=0.0, + num_workers=4, + pin_memory=True, + persistent_workers=True, + ) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 0.0, 4, True, True) + cv_dm = CrossValDataModule( + root="root", + train_idx=[0], + val_idx=[1], + datamodule=dm, + batch_size=128, + val_split=0.0, + num_workers=4, + pin_memory=True, + persistent_workers=True, + ) with pytest.raises(NotImplementedError): cv_dm.setup() cv_dm._get_train_data() diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 5164206e..cd724c68 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -360,12 +360,12 @@ def test_classification_failures(self): mixup_params=mixup_params, ) - with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): + with pytest.raises(ValueError, match="num_bins_cal_err must be at least 2, got"): ClassificationRoutine( model=nn.Identity(), num_classes=2, loss=nn.CrossEntropyLoss(), - num_calibration_bins=0, + num_bins_cal_err=0, ) with pytest.raises(ValueError): diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index fab7a27c..2a498a7a 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -99,10 +99,10 @@ def test_segmentation_errors(self): metric_subsampling_rate=-1, ) - with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): + with pytest.raises(ValueError, match="num_bins_cal_err must be at least 2, got"): SegmentationRoutine( model=nn.Identity(), num_classes=2, loss=nn.CrossEntropyLoss(), - num_calibration_bins=0, + num_bins_cal_err=0, ) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index bb2cb96f..bba7683c 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -26,7 +26,6 @@ def __init__( eval_grouping_loss: bool = False, ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, - calibration_set: Literal["val", "test"] = "val", ) -> None: log_path = Path(log_path) @@ -54,6 +53,5 @@ def __init__( eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, - 
calibration_set=calibration_set, ) self.save_hyperparameters() # coverage: ignore diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 8fb208ab..376f5815 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -70,11 +70,10 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, pretrained: bool = False, ) -> None: r"""ResNet backbone baseline for classification providing support for @@ -154,15 +153,13 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. - calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - num_calibration_bins (int, optional): Number of calibration bins. + num_bins_cal_err (int, optional): Number of calibration bins. Defaults to ``15``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. @@ -244,7 +241,6 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 4ba0bc6b..520c6425 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -41,7 +41,6 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, @@ -100,8 +99,6 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. - calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. 
Defaults to @@ -178,7 +175,6 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, eval_grouping_loss=eval_grouping_loss, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index f3d57fee..477b83fd 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -50,7 +50,6 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, @@ -112,8 +111,6 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. - calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. Defaults to @@ -195,6 +192,5 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py index f3e3982c..65bc4630 100644 --- a/torch_uncertainty/baselines/segmentation/deeplab.py +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -30,7 +30,7 @@ def __init__( separable: bool, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, pretrained_backbone: bool = True, ) -> None: params = { @@ -54,6 +54,6 @@ def __init__( format_batch_fn=format_batch_fn, metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index fdaddd14..ed56b8e5 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -13,6 +13,8 @@ else: # coverage: ignore sklearn_installed = False +import logging + from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler @@ -33,8 +35,9 @@ def __init__( num_workers: int, pin_memory: bool, persistent_workers: bool, + postprocess_set: Literal["val", "test"] = "val", ) -> None: - """Abstract DataModule class. + """Abstract DataModule class for TorchUncertainty. This class implements the basic functionality of a DataModule. It includes setters and getters for the datasets, dataloaders, as well as the crossval @@ -47,6 +50,8 @@ def __init__( num_workers (int): Number of workers to use for data loading. pin_memory (bool): Whether to pin memory. persistent_workers (bool): Whether to use persistent workers. + postprocess_set (str): Which split to use as post-processing set to fit the + post-processing method. 
""" super().__init__() @@ -58,6 +63,10 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers + if postprocess_set == "test": + logging.warning("Fitting the calibration method on the test set!") + self.postprocess_set = postprocess_set + @abstractmethod def setup(self, stage: Literal["fit", "test"] | None = None) -> None: pass @@ -99,6 +108,14 @@ def test_dataloader(self) -> list[DataLoader]: """ return [self._data_loader(self.test)] + def postprocess_dataloader(self) -> DataLoader: + r"""Get the calibration dataloader. + + Return: + DataLoader: calibration dataloader. + """ + return self.val_dataloader() if self.postprocess_set == "val" else self.test_dataloader()[0] + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: """Create a dataloader for a given dataset. @@ -155,6 +172,7 @@ def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list datamodule=self, batch_size=self.batch_size, val_split=self.val_split, + postprocess_set=self.postprocess_set, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -176,6 +194,7 @@ def __init__( num_workers: int, pin_memory: bool, persistent_workers: bool, + postprocess_set: Literal["val", "test"] = "val", ) -> None: super().__init__( root=root, @@ -184,6 +203,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, + postprocess_set=postprocess_set, ) self.train_idx = train_idx diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 4aa12bfd..d52211e2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -33,6 +33,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, @@ -53,6 +54,8 @@ def __init__( batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. basic_augment (bool): Whether to apply base augmentations. Defaults to @@ -73,6 +76,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 70224847..6334b10c 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -33,6 +33,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", basic_augment: bool = True, cutout: int | None = None, randaugment: bool = False, @@ -53,6 +54,8 @@ def __init__( batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. 
basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. @@ -72,6 +75,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 157bd8a2..23012f66 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -50,6 +50,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | Path | None = None, + postprocess_set: Literal["val", "test"] = "val", ood_ds: str = "openimage-o", test_alt: str | None = None, procedure: str | None = None, @@ -74,6 +75,8 @@ def __init__( val_split (float or Path): Share of samples to use for validation or path to a yaml file containing a list of validation images ids. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"openimage-o"``. test_alt (str): Which test set to use. Defaults to ``None``. @@ -94,6 +97,7 @@ def __init__( root=Path(root), batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index ed31116a..a1d72de4 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -30,6 +30,7 @@ def __init__( eval_shift: bool = False, ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, @@ -50,6 +51,8 @@ def __init__( notMNIST. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. basic_augment (bool): Whether to apply base augmentations. 
Defaults to @@ -63,6 +66,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index b8aa34e5..f84b972c 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -37,6 +37,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", ood_ds: str = "svhn", interpolation: str = "bilinear", basic_augment: bool = True, @@ -49,6 +50,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 54c5c703..3b1cedec 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -92,7 +92,7 @@ def calib_eval() -> float: @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: - if not self.trained: + if self.model is None or not self.trained: logging.error( "TemperatureScaler has not been trained yet. Returning manually tempered inputs." ) @@ -111,7 +111,6 @@ def _scale(self, logits: Tensor) -> Tensor: def fit_predict( self, - # calibration_set: Dataset, dataloader: DataLoader, progress: bool = True, ) -> Tensor: diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 0d0faeb9..52b8dc46 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -74,8 +74,7 @@ def __init__( eval_grouping_loss: bool = False, ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", post_processing: PostProcessing | None = None, - calibration_set: Literal["val", "test"] = "val", - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, log_plots: bool = False, save_in_csv: bool = False, ) -> None: @@ -111,10 +110,8 @@ def __init__( post_processing (PostProcessing, optional): Post-processing method to train on the calibration set. No post-processing if None. Defaults to ``None``. - calibration_set (str, optional): The post-hoc calibration dataset to - use for the post-processing method. Defaults to ``val``. - num_calibration_bins (int, optional): Number of bins to compute calibration - metrics. Defaults to ``15``. + num_bins_cal_err (int, optional): Number of bins to compute calibration + error metrics. Defaults to ``15``. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. 
Defaults to @@ -150,7 +147,7 @@ def __init__( is_ensemble=is_ensemble, ood_criterion=ood_criterion, eval_grouping_loss=eval_grouping_loss, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, mixup_params=mixup_params, post_processing=post_processing, format_batch_fn=format_batch_fn, @@ -166,11 +163,10 @@ def __init__( self.ood_criterion = ood_criterion self.log_plots = log_plots self.save_in_csv = save_in_csv - self.calibration_set = calibration_set self.binary_cls = num_classes == 1 self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) - self.num_calibration_bins = num_calibration_bins + self.num_bins_cal_err = num_bins_cal_err self.model = model self.loss = loss self.format_batch_fn = format_batch_fn @@ -202,13 +198,13 @@ def _init_metrics(self) -> None: "cls/NLL": CategoricalNLL(), "cal/ECE": CalibrationError( task=task, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), "cal/aECE": CalibrationError( task=task, adaptive=True, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), "sc/AURC": AURC(), @@ -381,13 +377,8 @@ def on_test_start(self) -> None: the storage lists for logit plotting and update the batchnorms if needed. """ if self.post_processing is not None: - calibration_dataloader = ( - self.trainer.datamodule.val_dataloader() - if self.calibration_set == "val" - else self.trainer.datamodule.test_dataloader()[0] - ) with torch.inference_mode(False): - self.post_processing.fit(calibration_dataloader) + self.post_processing.fit(self.trainer.datamodule.postprocess_dataloader()) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] @@ -699,7 +690,7 @@ def _classification_routine_checks( is_ensemble: bool, ood_criterion: str, eval_grouping_loss: bool, - num_calibration_bins: int, + num_bins_cal_err: int, mixup_params: dict | None, post_processing: PostProcessing | None, format_batch_fn: nn.Module | None, @@ -712,7 +703,7 @@ def _classification_routine_checks( is_ensemble (bool): whether the model is an ensemble or a single model. ood_criterion (str): the criterion for the binary OOD detection task. eval_grouping_loss (bool): whether to evaluate the grouping loss. - num_calibration_bins (int): the number of bins for the evaluation of the calibration. + num_bins_cal_err (int): the number of bins for the evaluation of the calibration. mixup_params (dict | None): the dictionary to setup the mixup augmentation. post_processing (PostProcessing | None): the post-processing module. format_batch_fn (nn.Module | None): the function for formatting the batch for ensembles. @@ -758,8 +749,8 @@ def _classification_routine_checks( "attribute to compute the grouping loss." 
         )
 
-    if num_calibration_bins < 2:
-        raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.")
+    if num_bins_cal_err < 2:
+        raise ValueError(f"num_bins_cal_err must be at least 2, got {num_bins_cal_err}.")
 
     if mixup_params is not None and isinstance(format_batch_fn, RepeatTarget):
         raise ValueError(
diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py
index fac59ecb..643d983b 100644
--- a/torch_uncertainty/routines/segmentation.py
+++ b/torch_uncertainty/routines/segmentation.py
@@ -37,7 +37,7 @@ def __init__(
         metric_subsampling_rate: float = 1e-2,
         log_plots: bool = False,
         num_samples_to_plot: int = 3,
-        num_calibration_bins: int = 15,
+        num_bins_cal_err: int = 15,
     ) -> None:
         r"""Routine for training & testing on **segmentation** tasks.
 
@@ -57,8 +57,8 @@ def __init__(
                 metrics. Defaults to ``False``.
             num_samples_to_plot (int, optional): Number of samples to plot in the
                 segmentation results. Defaults to ``3``.
-            num_calibration_bins (int, optional): Number of bins to compute calibration
-                metrics. Defaults to ``15``.
+            num_bins_cal_err (int, optional): Number of bins to compute calibration
+                error metrics. Defaults to ``15``.
 
         Warning:
             You must define :attr:`optim_recipe` if you do not use the CLI.
@@ -72,7 +72,7 @@ def __init__(
         _segmentation_routine_checks(
             num_classes=num_classes,
             metric_subsampling_rate=metric_subsampling_rate,
-            num_calibration_bins=num_calibration_bins,
+            num_bins_cal_err=num_bins_cal_err,
         )
         if eval_shift:
             raise NotImplementedError(
@@ -81,7 +81,7 @@ def __init__(
 
         self.model = model
         self.num_classes = num_classes
-        self.num_calibration_bins = num_calibration_bins
+        self.num_bins_cal_err = num_bins_cal_err
         self.loss = loss
         self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL)
         self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL)
@@ -118,13 +118,13 @@ def _init_metrics(self) -> None:
             "cal/ECE": CalibrationError(
                 task="multiclass",
                 num_classes=self.num_classes,
-                num_bins=self.num_calibration_bins,
+                num_bins=self.num_bins_cal_err,
             ),
             "cal/aECE": CalibrationError(
                 task="multiclass",
                 adaptive=True,
                 num_classes=self.num_classes,
-                num_bins=self.num_calibration_bins,
+                num_bins=self.num_bins_cal_err,
             ),
             "sc/AURC": AURC(),
             "sc/AUGRC": AUGRC(),
@@ -327,14 +327,14 @@ def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
 def _segmentation_routine_checks(
     num_classes: int,
     metric_subsampling_rate: float,
-    num_calibration_bins: int,
+    num_bins_cal_err: int,
 ) -> None:
     """Check the domains of the routine's parameters.
 
     Args:
         num_classes (int): the number of classes in the dataset.
         metric_subsampling_rate (float): the rate of subsampling to compute the metrics.
-        num_calibration_bins (int): the number of bins for the evaluation of the calibration.
+        num_bins_cal_err (int): the number of bins for the evaluation of the calibration.
     """
     if num_classes < 2:
         raise ValueError(f"num_classes must be at least 2, got {num_classes}.")
@@ -344,5 +344,5 @@ def _segmentation_routine_checks(
         f"metric_subsampling_rate must be in the range (0, 1], got {metric_subsampling_rate}."
         )
 
-    if num_calibration_bins < 2:
-        raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.")
+    if num_bins_cal_err < 2:
+        raise ValueError(f"num_bins_cal_err must be at least 2, got {num_bins_cal_err}.")
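
[Editorial usage sketch — not part of the patch series] Patch 44 above moves the
choice of the calibration split onto the datamodule (``postprocess_set``) and
renames the routine argument ``num_calibration_bins`` to ``num_bins_cal_err``.
A minimal sketch of the resulting call sites; the toy classifier and the exact
settings are illustrative assumptions, not code from the patches:

    from torch import nn

    from torch_uncertainty.datamodules import CIFAR100DataModule
    from torch_uncertainty.post_processing import TemperatureScaler
    from torch_uncertainty.routines import ClassificationRoutine

    dm = CIFAR100DataModule(root="./data", batch_size=128, postprocess_set="val")
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100))  # toy classifier
    routine = ClassificationRoutine(
        model=model,
        num_classes=100,
        loss=nn.CrossEntropyLoss(),
        post_processing=TemperatureScaler(),
        num_bins_cal_err=15,  # renamed from num_calibration_bins
    )
    # During testing, on_test_start fits the post-processing method on
    # dm.postprocess_dataloader() instead of picking val/test loaders by hand.

Only the argument and attribute names change here; the ECE/aECE computations
themselves are untouched.
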
From 8df18d00d88450f519b7a4db4a5ddf1be5d9f37f Mon Sep 17 00:00:00 2001
From: Olivier
Date: Tue, 18 Mar 2025 18:08:19 +0100
Subject: [PATCH 45/47] :hammer: Make mc_batch_size explicit again

---
 .../tutorial_mc_batch_norm.py               | 13 ++---
 auto_tutorials_source/tutorial_scaler.py    |  4 +-
 tests/post_processing/test_mc_batch_norm.py | 13 ++---
 .../post_processing/mc_batch_norm.py        | 48 +++++++++++++++----
 4 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py
index 4ec27105..18df5cff 100644
--- a/auto_tutorials_source/tutorial_mc_batch_norm.py
+++ b/auto_tutorials_source/tutorial_mc_batch_norm.py
@@ -26,7 +26,6 @@
 from pathlib import Path
 
 from torch import nn
-from torch.utils.data import DataLoader
 
 from torch_uncertainty import TUTrainer
 from torch_uncertainty.datamodules import MNISTDataModule
@@ -85,17 +84,15 @@
 # We can now wrap the model in a MCBatchNorm to add stochasticity to the
 # predictions. We specify that the BatchNorm layers are to be converted to
 # MCBatchNorm layers, and that we want to use 8 stochastic estimators.
-# The amount of stochasticity is controlled by the ``batch_size`` parameter.
-# of the DataLoader used to train the model.
-# The larger the ``batch_size``, the more stochastic the predictions will be.
-# The authors suggest 32 as a good value for ``batch_size`` but we use 16 here
+# The amount of stochasticity is controlled by the ``mc_batch_size`` argument.
+# The larger the ``mc_batch_size``, the more stochastic the predictions will be.
+# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 16 here
 # to highlight the effect of stochasticity on the predictions.
 routine.model = MCBatchNorm(
     routine.model, num_estimators=8, convert=True, mc_batch_size=16
 )
 
-mc_batch_norm_dl = DataLoader(datamodule.train, batch_size=16, shuffle=True)
-routine.model.fit(dataloader=mc_batch_norm_dl)
+routine.model.fit(dataloader=datamodule.postprocess_dataloader())
 routine = routine.eval()  # To avoid prints
 
 # %%
diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py
index 91850a8d..e8072333 100644
--- a/auto_tutorials_source/tutorial_scaler.py
+++ b/auto_tutorials_source/tutorial_scaler.py
@@ -61,8 +61,8 @@
 dm.prepare_data()
 dm.setup("test")
 
-# Get the full test dataloader (unused in this tutorial)
-dataloader = dm.test_dataloader()[0]
+# Get the full post-processing dataloader (unused in this tutorial)
+dataloader = dm.postprocess_dataloader()
 
 # %%
 # 4. Iterating on the Dataloader and Computing the ECE
diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py
index b2277d94..1d8d552b 100644
--- a/tests/post_processing/test_mc_batch_norm.py
+++ b/tests/post_processing/test_mc_batch_norm.py
@@ -18,13 +18,11 @@ class TestMCBatchNorm:
     def test_main(self):
         """Test initialization."""
         mc_model = lenet(1, 1, norm=partial(MCBatchNorm2d, num_estimators=2))
-        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False)
+        stoch_model = MCBatchNorm(mc_model, num_estimators=2, convert=False, mc_batch_size=1)
 
         model = lenet(1, 1, norm=nn.BatchNorm2d)
         stoch_model = MCBatchNorm(
-            nn.Sequential(model),
-            num_estimators=2,
-            convert=True,
+            nn.Sequential(model), num_estimators=2, convert=True, mc_batch_size=1
         )
         dataset = DummyClassificationDataset(
             "./",
@@ -34,14 +32,13 @@ def test_main(self):
             num_images=2,
             transform=T.ToTensor(),
         )
-        dl = DataLoader(dataset, batch_size=1, shuffle=True)
-        stoch_model.fit(dataloader=dl)
+        stoch_model.fit(dataloader=DataLoader(dataset, batch_size=6, shuffle=True))
         stoch_model.train()
         stoch_model(torch.randn(1, 1, 20, 20))
         stoch_model.eval()
         stoch_model(torch.randn(1, 1, 20, 20))
 
-        stoch_model = MCBatchNorm(num_estimators=2, convert=False)
+        stoch_model = MCBatchNorm(num_estimators=2, convert=False, mc_batch_size=1)
         stoch_model.set_model(mc_model)
 
     def test_errors(self):
@@ -54,7 +51,7 @@ def test_errors(self):
         with pytest.raises(ValueError):
             MCBatchNorm(model, num_estimators=1, convert=True)
         model = lenet(1, 1, norm=nn.BatchNorm2d)
-        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True)
+        stoch_model = MCBatchNorm(model, num_estimators=4, convert=True, mc_batch_size=1)
         dataset = DummyClassificationDataset(
             "./",
             num_channels=1,
diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index d5d467b6..9fb8dae0 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -19,6 +19,7 @@ def __init__(
         model: nn.Module | None = None,
         num_estimators: int = 16,
         convert: bool = True,
+        mc_batch_size: int = 32,
         device: Literal["cpu", "cuda"] | torch.device | None = None,
     ) -> None:
         """Monte Carlo Batch Normalization wrapper.
@@ -26,9 +27,11 @@ def __init__(
         Args:
             model (nn.Module): model to be converted.
             num_estimators (int): number of estimators.
-            convert (bool): whether to convert the model.
+            convert (bool): whether to convert the model. Defaults to ``True``.
+            mc_batch_size (int, optional): Monte Carlo batch size. The smaller the batch
+                size, the higher the variability of the predictions. Defaults to ``32``.
             device (Literal["cpu", "cuda"] | torch.device | None, optional): device.
-                Defaults to None.
+                Defaults to ``None``.
 
         Note:
             This wrapper will be stochastic in eval mode only.
@@ -38,15 +41,16 @@ def __init__(
             batch normalized deep networks. In ICML 2018.
         """
         super().__init__()
-        self.convert = convert
         self.num_estimators = num_estimators
+        self.convert = convert
+        self.mc_batch_size = mc_batch_size
         self.device = device
 
         if model is not None:
             self._setup_model(model)
 
     def _setup_model(self, model):
-        _mcbn_checks(model, self.num_estimators, self.convert)
+        _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert)
         self.model = deepcopy(model)  # TODO: Is it necessary?
         self.model = self.model.eval()
         if self.convert:
@@ -62,11 +66,11 @@ def fit(self, dataloader: DataLoader) -> None:
         """Fit the model on the dataset.
 
         Args:
-            dataloader (DataLoader): DataLoader with the training dataset.
+            dataloader (DataLoader): DataLoader with the post-processing dataset.
 
         Note:
             This method is used to populate the MC BatchNorm layers.
-            Use the training dataset.
+            Use the post-processing dataset.
 
         Warning:
             The ``batch_size`` of the DataLoader should be carefully chosen as it
@@ -75,6 +79,7 @@ def fit(self, dataloader: DataLoader) -> None:
         Raises:
             ValueError: If there are less batches than the number of estimators.
         """
+        dataloader = init_dataloader(dataloader, batch_size=self.mc_batch_size)
         self.counter = 0
         self.reset_counters()
         self.set_accumulate(True)
@@ -101,7 +106,7 @@ def forward(
         if self.training:
             return self.model(inputs)
         if not self.trained:
-            raise RuntimeError("MCBatchNorm has not been trained. Call .fit() first.")
+            raise RuntimeError("MCBatchNorm has not been fit. Call .fit() first.")
         self.reset_counters()
         return torch.cat([self._est_forward(inputs) for _ in range(self.num_estimators)], dim=0)
 
@@ -165,8 +170,35 @@ def has_mcbn(model: nn.Module) -> bool:
     return any(isinstance(module, MCBatchNorm2d) for module in model.modules())
 
 
-def _mcbn_checks(model, num_estimators, convert):
+def init_dataloader(dataloader: DataLoader, batch_size: int) -> DataLoader:
+    """Reinitialize a DataLoader with the chosen batch size.
+
+    It is impossible to change the ``batch_size`` of an already-instantiated dataloader.
+
+    Args:
+        dataloader (DataLoader): the DataLoader to be reinitialized.
+        batch_size (int): the batch size of the new DataLoader.
+    """
+    return DataLoader(
+        dataloader.dataset,
+        batch_size=batch_size,
+        sampler=dataloader.sampler,
+        num_workers=dataloader.num_workers,
+        pin_memory=dataloader.pin_memory,
+        drop_last=dataloader.drop_last,
+        timeout=dataloader.timeout,
+        worker_init_fn=dataloader.worker_init_fn,
+        multiprocessing_context=dataloader.multiprocessing_context,
+        generator=dataloader.generator,
+        prefetch_factor=dataloader.prefetch_factor,
+        persistent_workers=dataloader.persistent_workers,
+    )
+
+
+def _mcbn_checks(model, num_estimators, mc_batch_size, convert):
     if num_estimators < 1 or not isinstance(num_estimators, int):
         raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.")
+    if mc_batch_size < 1 or not isinstance(mc_batch_size, int):
+        raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.")
     if not convert and not has_mcbn(model):
         raise ValueError("model does not contain any MCBatchNorm2d nor is not to be converted.")
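
[Editorial usage sketch — not part of the patch series] With ``mc_batch_size``
stored on the wrapper, ``fit`` re-batches whatever DataLoader it receives via
``init_dataloader``, so callers no longer build a dedicated loader. A sketch
assuming the ``lenet`` helper used in the tests above and an already set-up
``datamodule`` (e.g., the MNISTDataModule from the tutorial):

    from torch import nn

    from torch_uncertainty.models.lenet import lenet
    from torch_uncertainty.post_processing import MCBatchNorm

    model = lenet(1, 10, norm=nn.BatchNorm2d)  # in_channels=1, num_classes=10
    wrapped = MCBatchNorm(model, num_estimators=8, convert=True, mc_batch_size=32)
    wrapped.fit(dataloader=datamodule.postprocess_dataloader())
    wrapped.eval()  # the wrapper is stochastic in eval mode only
    # A forward pass now returns the stacked predictions of all estimators,
    # i.e. a tensor of shape (num_estimators * batch_size, num_classes).

Note that ``init_dataloader`` deliberately carries over the sampler, workers,
and memory settings of the incoming loader and only overrides its batch size.
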
From 81d8c4a818c457b299c64839c5e58bf81c528967 Mon Sep 17 00:00:00 2001
From: alafage
Date: Wed, 19 Mar 2025 14:14:15 +0100
Subject: [PATCH 46/47] :bug: Fix Temperature Scaling tutorial

---
 auto_tutorials_source/tutorial_scaler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py
index e8072333..a5c3d9d4 100644
--- a/auto_tutorials_source/tutorial_scaler.py
+++ b/auto_tutorials_source/tutorial_scaler.py
@@ -57,7 +57,7 @@
 # element if eval_ood is True: the dataloader of in-distribution data and the dataloader
 # of out-of-distribution data. Otherwise, it is a list of 1 element.
 
-dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32)
+dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32, postprocess_set="test")
 
 dm.prepare_data()
 dm.setup("test")
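
[Editorial note] The one-line fix above is needed because the tutorial only runs
``dm.setup("test")``, so the test split is the only one available and the
post-processing set must point at it. For completeness, a hedged sketch of how
the scaler is then fit on that loader; ``model`` stands for the pretrained
network loaded earlier in the tutorial, and the exact ``TemperatureScaler``
signature is an assumption:

    from torch_uncertainty.post_processing import TemperatureScaler

    scaler = TemperatureScaler(model=model)
    scaler.fit(dataloader)  # the dm.postprocess_dataloader() obtained above
    scaled_logits = scaler(images)  # images: a batch drawn from that dataloader
    probs = scaled_logits.softmax(-1)  # temperature-scaled probabilities
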
From de6d239bece82be261009448cd61748127bc9bd2 Mon Sep 17 00:00:00 2001
From: alafage
Date: Wed, 19 Mar 2025 14:16:07 +0100
Subject: [PATCH 47/47] :hammer: No seg plot when `trainer.datamodule is None` in `SegmentationRoutine`

---
 torch_uncertainty/routines/segmentation.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py
index 643d983b..662d5581 100644
--- a/torch_uncertainty/routines/segmentation.py
+++ b/torch_uncertainty/routines/segmentation.py
@@ -283,7 +283,10 @@ def on_test_epoch_end(self) -> None:
                 "Selective Classification/Generalized Risk-Coverage curve",
                 self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0],
             )
-            self.log_segmentation_plots()
+            if self.trainer.datamodule is not None:
+                self.log_segmentation_plots()
+            else:
+                print("No datamodule found, skipping segmentation plots.")
 
     def log_segmentation_plots(self) -> None:
         """Build and log examples of segmentation plots from the test set."""
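
[Editorial note] The final guard covers runs where ``trainer.test`` receives
dataloaders directly, in which case Lightning leaves ``trainer.datamodule`` as
``None``. A sketch of such a run; ``routine`` and ``seg_test_set`` are
illustrative placeholders:

    from torch.utils.data import DataLoader

    from torch_uncertainty import TUTrainer

    trainer = TUTrainer(accelerator="cpu", max_epochs=1)
    # Passing dataloaders directly leaves trainer.datamodule set to None, so
    # the routine now skips log_segmentation_plots() instead of raising.
    trainer.test(routine, dataloaders=DataLoader(seg_test_set, batch_size=8))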