diff --git a/.gitignore b/.gitignore index 9ad40cf4..ca0ee960 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Custom .vscode/ +.itea/ data/ logs/ lightning_logs/ diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 55f6b3c6..70f5cf8e 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_samples: 16 num_classes: 10 diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 0c7989ab..3f8b63c2 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -38,7 +38,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_classes: 10 loss: CrossEntropyLoss diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index c5398a87..c387ef47 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} save_schedule: - 20 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 363461c6..2ea72001 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} momentum: 0.99 num_classes: 10 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 2274bdb5..09d7d506 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 19 cycle_length: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index ddff0067..e33d954f 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 10 cycle_length: 5 diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index 7a967edc..73304283 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -47,6 +47,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.eval_shift = True dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 0dc88033..40168429 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -15,6 +15,8 @@ class TUDataModule(ABC, LightningDataModule): val: Dataset test: Dataset + shift_severity = 1 + def __init__( self, root: str | Path, diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 3ee9c8a3..1e1441c2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -228,8 +228,7 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 2717cc12..11d0a7fa 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -210,8 +210,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or + CIFAR-100C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d436fd16..24fd5c6d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -289,8 +289,8 @@ def test_dataloader(self) -> list[DataLoader]: """Get the test dataloaders for ImageNet. Return: - list[DataLoader]: ImageNet test set (in distribution data) and - Textures test split (out-of-distribution data). + list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split + (out-of-distribution data), and/or ImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a49fc168..f6879c1a 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,10 +171,12 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data) and FashionMNIST test split - (out-of-distribution data). + distribution data), FashionMNIST or NotMNIST test split + (out-of-distribution data), and/or MNISTC (shifted data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 3c3c7ec4..bf95159e 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -237,8 +237,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, OOD data, and/or + TinyImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..8487f93f 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -606,9 +606,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() - shift_severity = self.trainer.test_dataloaders[ - 2 if self.eval_ood else 1 - ].dataset.shift_severity + shift_severity = self.trainer.datamodule.shift_severity tmp_metrics["shift/shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics)