From 2522a31c04c73158118f15a884c5d9a2d12f30b2 Mon Sep 17 00:00:00 2001 From: Anton Date: Tue, 4 Mar 2025 22:48:22 +0000 Subject: [PATCH 1/7] :bug: Fix MNIST test dataloader for shifted data --- torch_uncertainty/datamodules/classification/mnist.py | 4 +++- torch_uncertainty/datasets/classification/mnist_c.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a49fc168..d50421bf 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,10 +171,12 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data) and FashionMNIST test split + distribution data), MNISTC (shifted data), and FashionMNIST test split (out-of-distribution data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index 6d8086cb..aeffaa29 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -23,6 +23,7 @@ class MNISTC(VisionDataset): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``mnistc_subsets``. + shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. @@ -70,6 +71,7 @@ def __init__( target_transform: Callable | None = None, split: Literal["train", "test"] = "test", subset: str = "all", + shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -90,6 +92,12 @@ def __init__( raise ValueError(f"The subset '{subset}' does not exist in MNIST-C.") self.subset = subset + self.shift_severity = shift_severity + if shift_severity not in list(range(1, 6)): + raise ValueError( + "Corruptions shift_severity should be chosen between 1 and 5 " "included." + ) + if split not in ["train", "test"]: raise ValueError(f"The split '{split}' should be either 'train' or 'test'.") self.split = split From 3b6c7c62b1116efbf8c0f463f70537b7e0a75bcc Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 11:55:15 +0100 Subject: [PATCH 2/7] :bug: Fast track config fixes Co-authored-by: Anton --- .gitignore | 1 + experiments/classification/mnist/configs/bayesian_lenet.yaml | 1 - experiments/classification/mnist/configs/lenet.yaml | 1 - .../classification/mnist/configs/lenet_checkpoint_ensemble.yaml | 1 - experiments/classification/mnist/configs/lenet_ema.yaml | 1 - experiments/classification/mnist/configs/lenet_swa.yaml | 1 - experiments/classification/mnist/configs/lenet_swag.yaml | 1 - 7 files changed, 1 insertion(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 9ad40cf4..ca0ee960 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Custom .vscode/ +.itea/ data/ logs/ lightning_logs/ diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 55f6b3c6..70f5cf8e 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_samples: 16 num_classes: 10 diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 0c7989ab..3f8b63c2 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -38,7 +38,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_classes: 10 loss: CrossEntropyLoss diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index c5398a87..c387ef47 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} save_schedule: - 20 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 363461c6..2ea72001 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} momentum: 0.99 num_classes: 10 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 2274bdb5..09d7d506 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 19 cycle_length: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index ddff0067..e33d954f 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 10 cycle_length: 5 From 11801da484c0dc3651f7dfb740e5df2c52a88afc Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:06:27 +0100 Subject: [PATCH 3/7] :fire: Remove shift-severity in MNISTC --- torch_uncertainty/datasets/classification/mnist_c.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index aeffaa29..e61071d8 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -71,7 +71,6 @@ def __init__( target_transform: Callable | None = None, split: Literal["train", "test"] = "test", subset: str = "all", - shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -92,12 +91,6 @@ def __init__( raise ValueError(f"The subset '{subset}' does not exist in MNIST-C.") self.subset = subset - self.shift_severity = shift_severity - if shift_severity not in list(range(1, 6)): - raise ValueError( - "Corruptions shift_severity should be chosen between 1 and 5 " "included." - ) - if split not in ["train", "test"]: raise ValueError(f"The split '{split}' should be either 'train' or 'test'.") self.split = split From 16cfb44d8286f95921869f7f3b63bf3a6ac185c0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:06:55 +0100 Subject: [PATCH 4/7] :book: Update test dataloaders docstring --- torch_uncertainty/datamodules/classification/cifar10.py | 3 +-- torch_uncertainty/datamodules/classification/cifar100.py | 4 ++-- torch_uncertainty/datamodules/classification/imagenet.py | 4 ++-- torch_uncertainty/datamodules/classification/mnist.py | 4 ++-- torch_uncertainty/datamodules/classification/tiny_imagenet.py | 4 ++-- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 3ee9c8a3..1e1441c2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -228,8 +228,7 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 2717cc12..11d0a7fa 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -210,8 +210,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or + CIFAR-100C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d436fd16..24fd5c6d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -289,8 +289,8 @@ def test_dataloader(self) -> list[DataLoader]: """Get the test dataloaders for ImageNet. Return: - list[DataLoader]: ImageNet test set (in distribution data) and - Textures test split (out-of-distribution data). + list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split + (out-of-distribution data), and/or ImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index d50421bf..f6879c1a 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,8 +171,8 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data), MNISTC (shifted data), and FashionMNIST test split - (out-of-distribution data). + distribution data), FashionMNIST or NotMNIST test split + (out-of-distribution data), and/or MNISTC (shifted data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 3c3c7ec4..bf95159e 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -237,8 +237,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, OOD data, and/or + TinyImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: From 9d761f1c699486221ab88da369bf430ef2fd8831 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:08:09 +0100 Subject: [PATCH 5/7] :hammer: use datamodule's shift-severity --- torch_uncertainty/datamodules/abstract.py | 2 ++ torch_uncertainty/routines/classification.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 0dc88033..40168429 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -15,6 +15,8 @@ class TUDataModule(ABC, LightningDataModule): val: Dataset test: Dataset + shift_severity = 1 + def __init__( self, root: str | Path, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..8487f93f 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -606,9 +606,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() - shift_severity = self.trainer.test_dataloaders[ - 2 if self.eval_ood else 1 - ].dataset.shift_severity + shift_severity = self.trainer.datamodule.shift_severity tmp_metrics["shift/shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) From 3a186b835b554bf431cfd46462b4dd3ae970800a Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 13:05:52 +0100 Subject: [PATCH 6/7] :white_check_mark: Fix coverage --- tests/datamodules/classification/test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index 7a967edc..73304283 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -47,6 +47,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.eval_shift = True dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() From d4d09f5806f59ac713da4e2387a65c4a13505688 Mon Sep 17 00:00:00 2001 From: Olivier Laurent <62881275+o-laurent@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:16:14 +0100 Subject: [PATCH 7/7] :fire: Remove documentation for removed argument --- torch_uncertainty/datasets/classification/mnist_c.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index e61071d8..6d8086cb 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -23,7 +23,6 @@ class MNISTC(VisionDataset): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``mnistc_subsets``. - shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False.