Skip to content

🐛 Fix MNIST test dataloader for shifted data #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Custom
.vscode/
.itea/
data/
logs/
lightning_logs/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
num_samples: 16
num_classes: 10
Expand Down
1 change: 0 additions & 1 deletion experiments/classification/mnist/configs/lenet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
num_classes: 10
loss: CrossEntropyLoss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
save_schedule:
- 20
Expand Down
1 change: 0 additions & 1 deletion experiments/classification/mnist/configs/lenet_ema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
momentum: 0.99
num_classes: 10
Expand Down
1 change: 0 additions & 1 deletion experiments/classification/mnist/configs/lenet_swa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
cycle_start: 19
cycle_length: 5
Expand Down
1 change: 0 additions & 1 deletion experiments/classification/mnist/configs/lenet_swag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ model:
norm: torch.nn.Identity
groups: 1
dropout_rate: 0
last_layer_dropout: false
layer_args: {}
cycle_start: 10
cycle_length: 5
Expand Down
1 change: 1 addition & 0 deletions tests/datamodules/classification/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_mnist_cutout(self):
dm.setup("other")

dm.eval_ood = True
dm.eval_shift = True
dm.ood_transform = dm.test_transform
dm.val_split = 0.1
dm.prepare_data()
Expand Down
2 changes: 2 additions & 0 deletions torch_uncertainty/datamodules/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class TUDataModule(ABC, LightningDataModule):
val: Dataset
test: Dataset

shift_severity = 1

def __init__(
self,
root: str | Path,
Expand Down
3 changes: 1 addition & 2 deletions torch_uncertainty/datamodules/classification/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,7 @@ def test_dataloader(self) -> list[DataLoader]:
r"""Get test dataloaders.

Return:
list[DataLoader]: test set for in distribution data
and out-of-distribution data.
list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data.
"""
dataloader = [self._data_loader(self.test)]
if self.eval_ood:
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/classification/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ def test_dataloader(self) -> list[DataLoader]:
r"""Get test dataloaders.

Return:
list[DataLoader]: test set for in distribution data
and out-of-distribution data.
list[DataLoader]: test set for in distribution data, SVHN data, and/or
CIFAR-100C data.
"""
dataloader = [self._data_loader(self.test)]
if self.eval_ood:
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/classification/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def test_dataloader(self) -> list[DataLoader]:
"""Get the test dataloaders for ImageNet.

Return:
list[DataLoader]: ImageNet test set (in distribution data) and
Textures test split (out-of-distribution data).
list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split
(out-of-distribution data), and/or ImageNetC data.
"""
dataloader = [self._data_loader(self.test)]
if self.eval_ood:
Expand Down
6 changes: 4 additions & 2 deletions torch_uncertainty/datamodules/classification/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ def test_dataloader(self) -> list[DataLoader]:

Return:
list[DataLoader]: Dataloaders of the MNIST test set (in
distribution data) and FashionMNIST test split
(out-of-distribution data).
distribution data), FashionMNIST or NotMNIST test split
(out-of-distribution data), and/or MNISTC (shifted data).
"""
dataloader = [self._data_loader(self.test)]
if self.eval_ood:
dataloader.append(self._data_loader(self.ood))
if self.eval_shift:
dataloader.append(self._data_loader(self.shift))
return dataloader
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/classification/tiny_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def test_dataloader(self) -> list[DataLoader]:
r"""Get test dataloaders for TinyImageNet.

Return:
list[DataLoader]: test set for in distribution data
and out-of-distribution data.
list[DataLoader]: test set for in distribution data, OOD data, and/or
TinyImageNetC data.
"""
dataloader = [self._data_loader(self.test)]
if self.eval_ood:
Expand Down
4 changes: 1 addition & 3 deletions torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,7 @@ def on_test_epoch_end(self) -> None:

if self.eval_shift:
tmp_metrics = self.test_shift_metrics.compute()
shift_severity = self.trainer.test_dataloaders[
2 if self.eval_ood else 1
].dataset.shift_severity
shift_severity = self.trainer.datamodule.shift_severity
tmp_metrics["shift/shift_severity"] = shift_severity
self.log_dict(tmp_metrics, sync_dist=True)
result_dict.update(tmp_metrics)
Expand Down