diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 8435ac79..594a30cb 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -36,23 +36,23 @@ jobs: echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" echo "PYTHON_VERSION=$(python -c "import platform; print(platform.python_version())")" >> $GITHUB_ENV - - name: Get changed files - id: changed-files-specific - uses: tj-actions/changed-files@v42 - with: - files: | - auto_tutorials_source/** - data/** - experiments/** - docs/** - *.md - *.yaml - *.yml - LICENSE - .gitignore + # - name: Get changed files + # id: changed-files-specific + # uses: tj-actions/changed-files@v42 + # with: + # files: | + # auto_tutorials_source/** + # data/** + # experiments/** + # docs/** + # *.md + # *.yaml + # *.yml + # LICENSE + # .gitignore - name: Cache folder for TorchUncertainty - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' uses: actions/cache@v4 id: cache-folder with: @@ -61,43 +61,45 @@ jobs: key: torch-uncertainty-${{ runner.os }} - name: Install dependencies - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Check style & format - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m ruff check torch_uncertainty --no-fix --statistics python3 -m ruff format torch_uncertainty --check - name: Test with pytest and compute coverage - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | python3 -m pytest --cov --cov-report xml --durations 10 
--junitxml=junit.xml - name: Upload coverage to Codecov - if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + # if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + if: github.event_name != 'pull_request' || github.base_ref == 'dev' uses: codecov/codecov-action@v4 continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml - flags: cpu,pytest + flags: pytest name: CPU-coverage env_vars: PYTHON_VERSION - name: Upload test results to Codecov - if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + # if: steps.changed-files-specific.outputs.only_changed != 'true' && (github.event_name != 'pull_request' || github.base_ref == 'dev') + if: github.event_name != 'pull_request' || github.base_ref == 'dev' uses: codecov/test-results-action@v1 continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} - flags: cpu,pytest + flags: pytest env_vars: PYTHON_VERSION - name: Test sphinx build without tutorials - if: steps.changed-files-specific.outputs.only_changed != 'true' + # if: steps.changed-files-specific.outputs.only_changed != 'true' run: | cd docs && make clean && make html-noplot diff --git a/README.md b/README.md index e0f735a4..1e476b03 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,9 @@ This package provides a multi-level API, including: - easy-to-use :zap: lightning **uncertainty-aware** training & evaluation routines for **4 tasks**: classification, probabilistic and pointwise regression, and segmentation. - ready-to-train baselines on research datasets, such as ImageNet and CIFAR -- [pretrained weights](https://huggingface.co/torch-uncertainty) for these baselines on ImageNet and CIFAR ( :construction: work in progress :construction: ). 
- **layers**, **models**, **metrics**, & **losses** available for use in your networks - scikit-learn style post-processing methods such as Temperature Scaling. +- transformations, including corruptions resulting in additional "corrupted datasets" available on [HuggingFace](https://huggingface.co/torch-uncertainty) Have a look at the [Reference page](https://torch-uncertainty.github.io/references.html) or the [API reference](https://torch-uncertainty.github.io/api.html) for a more exhaustive list of the implemented methods, datasets, metrics, etc. diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index bb726902..18df5cff 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -86,13 +86,13 @@ # MCBatchNorm layers, and that we want to use 8 stochastic estimators. # The amount of stochasticity is controlled by the ``mc_batch_size`` argument. # The larger the ``mc_batch_size``, the more stochastic the predictions will be. -# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 4 here +# The authors suggest 32 as a good value for ``mc_batch_size`` but we use 16 here # to highlight the effect of stochasticity on the predictions. routine.model = MCBatchNorm( routine.model, num_estimators=8, convert=True, mc_batch_size=16 ) -routine.model.fit(datamodule.train) +routine.model.fit(dataloader=datamodule.postprocess_dataloader()) routine = routine.eval() # To avoid prints # %% diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index ceaaa036..a5c3d9d4 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -6,8 +6,8 @@ of the top-label predictions and the reliability of the underlying neural network. 
This tutorial provides extensive details on how to use the TemperatureScaler -class, however, this is done automatically in the classification routine when setting -the `calibration_set` to val or test. +class; however, this is done automatically in the datamodule when setting +the `postprocess_set` to val or test. Through this tutorial, we also see how to use the datamodules outside any Lightning trainers, and how to use TorchUncertainty's models. @@ -57,12 +57,12 @@ # element if eval_ood is True: the dataloader of in-distribution data and the dataloader # of out-of-distribution data. Otherwise, it is a list of 1 element. -dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32) +dm = CIFAR100DataModule(root="./data", eval_ood=False, batch_size=32, postprocess_set="test") dm.prepare_data() dm.setup("test") -# Get the full test dataloader (unused in this tutorial) -dataloader = dm.test_dataloader()[0] +# Get the full post-processing dataloader (unused in this tutorial) +dataloader = dm.postprocess_dataloader() # %% # 4. Iterating on the Dataloader and Computing the ECE @@ -84,6 +84,7 @@ dataset, [1000, 1000, len(dataset) - 2000] ) test_dataloader = DataLoader(test_dataset, batch_size=32) +calibration_dataloader = DataLoader(cal_dataset, batch_size=32) # Initialize the ECE ece = CalibrationError(task="multiclass", num_classes=100) @@ -114,7 +115,7 @@ # Fit the scaler on the calibration dataset scaled_model = TemperatureScaler(model=model) -scaled_model.fit(calibration_set=cal_dataset) +scaled_model.fit(dataloader=calibration_dataloader) # %% # 6. 
Iterating Again to Compute the Improved ECE diff --git a/docs/source/conf.py b/docs/source/conf.py index 2c9d46a0..fedd153e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.4.2" +release = "0.4.3" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index e404e640..fd6ab4ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.4.2" +version = "0.4.3" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, @@ -46,7 +46,6 @@ experiments = [ "safetensors", ] image = [ - "scikit-image", "kornia", "h5py", "opencv-python", diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 66555c6c..8da6b391 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -2,10 +2,10 @@ import numpy as np import torch -import torchvision.transforms.v2 as T from numpy.typing import ArrayLike from torch.utils.data import DataLoader from torchvision import tv_tensors +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule @@ -52,8 +52,8 @@ def __init__( self.ood_dataset = DummyClassificationDataset self.shift_dataset = DummyClassificationDataset - self.train_transform = T.ToTensor() - self.test_transform = T.ToTensor() + self.train_transform = v2.ToTensor() + self.test_transform = v2.ToTensor() def prepare_data(self) -> None: pass @@ -207,7 +207,7 @@ def __init__( self.dataset = DummySegmentationDataset - self.train_transform = T.ToDtype( + self.train_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, 
tv_tensors.Mask: torch.int64, @@ -215,7 +215,7 @@ def __init__( }, scale=True, ) - self.test_transform = T.ToDtype( + self.test_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, @@ -296,7 +296,7 @@ def __init__( self.dataset = DummPixelRegressionDataset - self.train_transform = T.ToDtype( + self.train_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.float32, @@ -304,7 +304,7 @@ def __init__( }, scale=True, ) - self.test_transform = T.ToDtype( + self.test_transform = v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.float32, diff --git a/tests/datamodules/classification/test_cifar10.py b/tests/datamodules/classification/test_cifar10.py index 51e6a0d6..d583219a 100644 --- a/tests/datamodules/classification/test_cifar10.py +++ b/tests/datamodules/classification/test_cifar10.py @@ -10,7 +10,7 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" def test_cifar10_main(self): - dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16) + dm = CIFAR10DataModule(root="./data/", batch_size=128, cutout=16, postprocess_set="test") assert dm.dataset == CIFAR10 assert isinstance(dm.train_transform.transforms[2], Cutout) diff --git a/tests/datamodules/segmentation/test_muad.py b/tests/datamodules/segmentation/test_muad.py index 97d4f6d0..862206f0 100644 --- a/tests/datamodules/segmentation/test_muad.py +++ b/tests/datamodules/segmentation/test_muad.py @@ -35,7 +35,3 @@ def test_camvid_main(self): dm.setup() dm.train_dataloader() dm.val_dataloader() - - def test_small_muad_accessibility(self): - dataset = MUAD(root="./data/", split="test", version="small", download=True) - assert len(dataset.samples) > 0, "Dataset is not found" diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 84a87293..1983a028 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ 
b/tests/datamodules/test_abstract_datamodule.py @@ -47,12 +47,29 @@ def test_cv_main(self): def test_errors(self): TUDataModule.__abstractmethods__ = set() - dm = TUDataModule("root", 128, 0.0, 4, True, True) + dm = TUDataModule( + root="root", + batch_size=128, + val_split=0.0, + num_workers=4, + pin_memory=True, + persistent_workers=True, + ) ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds - cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 0.0, 4, True, True) + cv_dm = CrossValDataModule( + root="root", + train_idx=[0], + val_idx=[1], + datamodule=dm, + batch_size=128, + val_split=0.0, + num_workers=4, + pin_memory=True, + persistent_workers=True, + ) with pytest.raises(NotImplementedError): cv_dm.setup() cv_dm._get_train_data() diff --git a/tests/post_processing/test_laplace.py b/tests/post_processing/test_laplace.py index 8b6249ea..f2fdda7b 100644 --- a/tests/post_processing/test_laplace.py +++ b/tests/post_processing/test_laplace.py @@ -1,6 +1,6 @@ import torch from torch import nn -from torch.utils.data import TensorDataset +from torch.utils.data import DataLoader, TensorDataset from tests._dummies.model import dummy_model from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing @@ -20,12 +20,12 @@ class TestLaplace: """Testing the LaplaceApprox class.""" def test_training(self): - ds = TensorDataset(torch.randn(16, 1), torch.randn(16, 10)) + dl = DataLoader(TensorDataset(torch.randn(16, 1), torch.randn(16, 10)), batch_size=5) la = LaplaceApprox( task="classification", model=dummy_model(1, 10), ) - la.fit(ds) + la.fit(dl) la(torch.randn(1, 1)) la = LaplaceApprox(task="classification") la.set_model(dummy_model(1, 10)) diff --git a/tests/post_processing/test_mc_batch_norm.py b/tests/post_processing/test_mc_batch_norm.py index bbe987ca..1d8d552b 100644 --- a/tests/post_processing/test_mc_batch_norm.py +++ b/tests/post_processing/test_mc_batch_norm.py @@ -4,6 +4,7 @@ import torch import 
torchvision.transforms as T from torch import nn +from torch.utils.data import DataLoader from tests._dummies.dataset import DummyClassificationDataset from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d @@ -21,10 +22,7 @@ def test_main(self): model = lenet(1, 1, norm=nn.BatchNorm2d) stoch_model = MCBatchNorm( - nn.Sequential(model), - num_estimators=2, - convert=True, - mc_batch_size=1, + nn.Sequential(model), num_estimators=2, convert=True, mc_batch_size=1 ) dataset = DummyClassificationDataset( "./", @@ -34,7 +32,7 @@ def test_main(self): num_images=2, transform=T.ToTensor(), ) - stoch_model.fit(dataset=dataset) + stoch_model.fit(dataloader=DataLoader(dataset, batch_size=6, shuffle=True)) stoch_model.train() stoch_model(torch.randn(1, 1, 20, 20)) stoch_model.eval() @@ -48,8 +46,6 @@ def test_errors(self): model = nn.Identity() with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=0, convert=True) - with pytest.raises(ValueError, match="mc_batch_size must be a positive integer"): - MCBatchNorm(model, num_estimators=1, convert=True, mc_batch_size=-1) with pytest.raises(ValueError): MCBatchNorm(model, num_estimators=1, convert=False) with pytest.raises(ValueError): @@ -64,9 +60,10 @@ def test_errors(self): num_images=2, transform=T.ToTensor(), ) + dl = DataLoader(dataset, batch_size=2, shuffle=True) stoch_model.eval() with pytest.raises(RuntimeError): stoch_model(torch.randn(1, 1, 20, 20)) with pytest.raises(ValueError): - stoch_model.fit(dataset=dataset) + stoch_model.fit(dataloader=dl) diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py index fabe77f3..b499efc5 100644 --- a/tests/post_processing/test_scalers.py +++ b/tests/post_processing/test_scalers.py @@ -1,6 +1,7 @@ import pytest import torch from torch import nn, softmax +from torch.utils.data import DataLoader from torch_uncertainty.post_processing import ( MatrixScaler, @@ -26,10 +27,11 @@ def test_fit_biased(self): labels = 
torch.as_tensor([0.5, 0.5]).repeat(10, 1) calibration_set = list(zip(inputs, labels, strict=True)) + dl = DataLoader(calibration_set, batch_size=10) scaler = TemperatureScaler(model=nn.Identity(), init_val=2, lr=1, max_iter=10) assert scaler.temperature[0] == 2.0 - scaler.fit(calibration_set) + scaler.fit(dl) assert scaler.temperature[0] > 10 # best is +inf assert ( torch.sum( @@ -39,7 +41,7 @@ def test_fit_biased(self): ** 2 < 0.001 ) - scaler.fit_predict(calibration_set, progress=False) + scaler.fit_predict(dl, progress=False) def test_errors(self): with pytest.raises(ValueError): diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 5164206e..cd724c68 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -360,12 +360,12 @@ def test_classification_failures(self): mixup_params=mixup_params, ) - with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): + with pytest.raises(ValueError, match="num_bins_cal_err must be at least 2, got"): ClassificationRoutine( model=nn.Identity(), num_classes=2, loss=nn.CrossEntropyLoss(), - num_calibration_bins=0, + num_bins_cal_err=0, ) with pytest.raises(ValueError): diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index fab7a27c..2a498a7a 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -99,10 +99,10 @@ def test_segmentation_errors(self): metric_subsampling_rate=-1, ) - with pytest.raises(ValueError, match="num_calibration_bins must be at least 2, got"): + with pytest.raises(ValueError, match="num_bins_cal_err must be at least 2, got"): SegmentationRoutine( model=nn.Identity(), num_classes=2, loss=nn.CrossEntropyLoss(), - num_calibration_bins=0, + num_bins_cal_err=0, ) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index ce6a76cc..bf1184a5 100644 --- a/tests/transforms/test_corruption.py +++ 
b/tests/transforms/test_corruption.py @@ -15,6 +15,7 @@ ImpulseNoise, JPEGCompression, MotionBlur, + OriginalGlassBlur, Pixelate, Saturation, ShotNoise, @@ -34,84 +35,113 @@ def test_gaussian_noise(self): _ = GaussianNoise(0.1) inputs = torch.rand(3, 32, 32) transform = GaussianNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = GaussianNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 + print(transform) def test_shot_noise(self): inputs = torch.rand(3, 32, 32) transform = ShotNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = ShotNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 def test_impulse_noise(self): inputs = torch.rand(3, 32, 32) - transform = ImpulseNoise(1) - transform(inputs) + transform = ImpulseNoise(1, black_white=True) + assert transform(inputs).ndim == 3 transform = ImpulseNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + transform = ImpulseNoise(1, black_white=False) + inputs = torch.rand(3, 3, 32, 32) + assert transform(inputs).ndim == 4 def test_speckle_noise(self): inputs = torch.rand(3, 32, 32) transform = SpeckleNoise(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = SpeckleNoise(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + transform = SpeckleNoise(1) + assert transform(inputs).ndim == 4 def test_gaussian_blur(self): inputs = torch.rand(3, 32, 32) transform = GaussianBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = GaussianBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + transform = GaussianBlur(1) + assert transform(inputs).ndim == 4 def test_glass_blur(self): inputs = torch.rand(3, 32, 32) transform = GlassBlur(1) - transform(inputs) 
+ assert transform(inputs).ndim == 3 transform = GlassBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 32, 32) + transform = OriginalGlassBlur(1, seed=1) + assert transform(inputs).ndim == 3 + transform = OriginalGlassBlur(0) + assert transform(inputs).ndim == 3 def test_defocus_blur(self): inputs = torch.rand(3, 32, 32) transform = DefocusBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = DefocusBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 + + inputs = torch.rand(3, 3, 32, 32) + transform = DefocusBlur(1) + assert transform(inputs).ndim == 4 def test_motion_blur(self): inputs = torch.rand(3, 32, 32) transform = MotionBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = MotionBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 - inputs = torch.rand(1, 3, 32, 32) + inputs = torch.rand(3, 3, 32, 32) transform = MotionBlur(1) - transform(inputs) + assert transform(inputs).ndim == 4 def test_zoom_blur(self): inputs = torch.rand(3, 32, 32) transform = ZoomBlur(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = ZoomBlur(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_jpeg_compression(self): inputs = torch.rand(3, 32, 32) transform = JPEGCompression(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = JPEGCompression(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_pixelate(self): inputs = torch.rand(3, 32, 32) transform = Pixelate(1) - transform(inputs) + assert transform(inputs).ndim == 3 transform = Pixelate(0) - transform(inputs) + assert transform(inputs).ndim == 3 def test_frost(self): try: @@ -135,18 +165,11 @@ def test_snow(self): def test_fog(self): inputs = torch.rand(3, 32, 32) - transform = Fog(1, size=32) + transform = Fog(1) transform(inputs) - - with pytest.raises(ValueError, match="Image must be square. 
Got "): - transform(torch.rand(3, 32, 12)) - - transform = Fog(0, size=32) + transform = Fog(0) transform(inputs) - with pytest.raises(ValueError, match="Size must be a power of 2. Got "): - _ = Fog(1, size=15) - def test_brightness(self): inputs = torch.rand(3, 32, 32) transform = Brightness(1) diff --git a/torch_uncertainty/baselines/classification/deep_ensembles.py b/torch_uncertainty/baselines/classification/deep_ensembles.py index bb2cb96f..bba7683c 100644 --- a/torch_uncertainty/baselines/classification/deep_ensembles.py +++ b/torch_uncertainty/baselines/classification/deep_ensembles.py @@ -26,7 +26,6 @@ def __init__( eval_grouping_loss: bool = False, ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, - calibration_set: Literal["val", "test"] = "val", ) -> None: log_path = Path(log_path) @@ -54,6 +53,5 @@ def __init__( eval_grouping_loss=eval_grouping_loss, ood_criterion=ood_criterion, log_plots=log_plots, - calibration_set=calibration_set, ) self.save_hyperparameters() # coverage: ignore diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 8fb208ab..376f5815 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -70,11 +70,10 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, pretrained: bool = False, ) -> None: r"""ResNet backbone baseline for classification providing support for @@ -154,15 +153,13 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. 
- calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. - num_calibration_bins (int, optional): Number of calibration bins. + num_bins_cal_err (int, optional): Number of calibration bins. Defaults to ``15``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. @@ -244,7 +241,6 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 4ba0bc6b..520c6425 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -41,7 +41,6 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, @@ -100,8 +99,6 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. - calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. 
Defaults to @@ -178,7 +175,6 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, eval_grouping_loss=eval_grouping_loss, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index f3d57fee..477b83fd 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -50,7 +50,6 @@ def __init__( ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", log_plots: bool = False, save_in_csv: bool = False, - calibration_set: Literal["val", "test"] = "val", eval_ood: bool = False, eval_shift: bool = False, eval_grouping_loss: bool = False, @@ -112,8 +111,6 @@ def __init__( Defaults to ``False``. save_in_csv (bool, optional): Indicates whether to save the results in a csv file or not. Defaults to ``False``. - calibration_set (Callable, optional): Calibration set. Defaults to - ``None``. eval_ood (bool, optional): Indicates whether to evaluate the OOD detection or not. Defaults to ``False``. eval_shift (bool): Whether to evaluate on shifted data. 
Defaults to @@ -195,6 +192,5 @@ def __init__( ood_criterion=ood_criterion, log_plots=log_plots, save_in_csv=save_in_csv, - calibration_set=calibration_set, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py index f3e3982c..65bc4630 100644 --- a/torch_uncertainty/baselines/segmentation/deeplab.py +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -30,7 +30,7 @@ def __init__( separable: bool, metric_subsampling_rate: float = 1e-2, log_plots: bool = False, - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, pretrained_backbone: bool = True, ) -> None: params = { @@ -54,6 +54,6 @@ def __init__( format_batch_fn=format_batch_fn, metric_subsampling_rate=metric_subsampling_rate, log_plots=log_plots, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index fdaddd14..ed56b8e5 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -13,6 +13,8 @@ else: # coverage: ignore sklearn_installed = False +import logging + from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler @@ -33,8 +35,9 @@ def __init__( num_workers: int, pin_memory: bool, persistent_workers: bool, + postprocess_set: Literal["val", "test"] = "val", ) -> None: - """Abstract DataModule class. + """Abstract DataModule class for TorchUncertainty. This class implements the basic functionality of a DataModule. It includes setters and getters for the datasets, dataloaders, as well as the crossval @@ -47,6 +50,8 @@ def __init__( num_workers (int): Number of workers to use for data loading. pin_memory (bool): Whether to pin memory. persistent_workers (bool): Whether to use persistent workers. 
+ postprocess_set (str): Which split to use as post-processing set to fit the + post-processing method. """ super().__init__() @@ -58,6 +63,10 @@ def __init__( self.pin_memory = pin_memory self.persistent_workers = persistent_workers + if postprocess_set == "test": + logging.warning("Fitting the calibration method on the test set!") + self.postprocess_set = postprocess_set + @abstractmethod def setup(self, stage: Literal["fit", "test"] | None = None) -> None: pass @@ -99,6 +108,14 @@ def test_dataloader(self) -> list[DataLoader]: """ return [self._data_loader(self.test)] + def postprocess_dataloader(self) -> DataLoader: + r"""Get the post-processing dataloader. + + Return: + DataLoader: validation or first test dataloader, depending on ``postprocess_set``. + """ + return self.val_dataloader() if self.postprocess_set == "val" else self.test_dataloader()[0] + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: """Create a dataloader for a given dataset. @@ -155,6 +172,7 @@ def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list datamodule=self, batch_size=self.batch_size, val_split=self.val_split, + postprocess_set=self.postprocess_set, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -176,6 +194,7 @@ def __init__( num_workers: int, pin_memory: bool, persistent_workers: bool, + postprocess_set: Literal["val", "test"] = "val", ) -> None: super().__init__( root=root, @@ -184,6 +203,7 @@ def __init__( num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, + postprocess_set=postprocess_set, ) self.train_idx = train_idx diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 1e1441c2..d52211e2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -2,12 +2,13 @@ from typing import Literal import numpy as np -import 
torchvision.transforms as T +import torch from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import TUDataModule from torch_uncertainty.datasets import AggregatedDataset @@ -32,6 +33,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, @@ -52,6 +54,8 @@ def __init__( batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. basic_augment (bool): Whether to apply base augmentations. 
Defaults to @@ -72,6 +76,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -100,10 +105,10 @@ def __init__( ) if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -116,25 +121,21 @@ def __init__( else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), basic_transform, main_transform, - T.Normalize( - self.mean, - self.std, - ), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Normalize( - self.mean, - self.std, - ), + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 11d0a7fa..6334b10c 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -3,12 +3,12 @@ import numpy as np import torch -import torchvision.transforms as T from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets import AggregatedDataset @@ -33,6 +33,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", basic_augment: bool = True, 
cutout: int | None = None, randaugment: bool = False, @@ -53,6 +54,8 @@ def __init__( batch_size (int): Number of samples per batch. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. basic_augment (bool): Whether to apply base augmentations. Defaults to ``True``. cutout (int): Size of cutout to apply to images. Defaults to ``None``. @@ -72,6 +75,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -94,10 +98,10 @@ def __init__( ) if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(32, padding=4), - T.RandomHorizontalFlip(), + v2.RandomCrop(32, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -106,25 +110,26 @@ def __init__( if cutout: main_transform = Cutout(cutout) elif randaugment: - main_transform = T.RandAugment(num_ops=2, magnitude=20) + main_transform = v2.RandAugment(num_ops=2, magnitude=20) elif auto_augment: main_transform = rand_augment_transform(auto_augment, {}) else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), basic_transform, main_transform, - T.ConvertImageDtype(torch.float32), - T.Normalize(mean=self.mean, std=self.std), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 
24fd5c6d..23012f66 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -2,13 +2,14 @@ from pathlib import Path from typing import Literal -import torchvision.transforms as T +import torch import yaml from timm.data.auto_augment import rand_augment_transform from timm.data.mixup import Mixup from torch import nn from torch.utils.data import DataLoader, Subset from torchvision.datasets import DTD, SVHN, ImageNet, INaturalist +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( @@ -49,6 +50,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | Path | None = None, + postprocess_set: Literal["val", "test"] = "val", ood_ds: str = "openimage-o", test_alt: str | None = None, procedure: str | None = None, @@ -73,6 +75,8 @@ def __init__( val_split (float or Path): Share of samples to use for validation or path to a yaml file containing a list of validation images ids. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. ood_ds (str): Which out-of-distribution dataset to use. Defaults to ``"openimage-o"``. test_alt (str): Which test set to use. Defaults to ``None``. 
@@ -93,6 +97,7 @@ def __init__( root=Path(root), batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -137,10 +142,10 @@ def __init__( self.procedure = procedure if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomResizedCrop(train_size, interpolation=self.interpolation), - T.RandomHorizontalFlip(), + v2.RandomResizedCrop(train_size, interpolation=self.interpolation), + v2.RandomHorizontalFlip(), ] ) else: @@ -153,7 +158,7 @@ def __init__( main_transform = nn.Identity() elif self.procedure == "ViT": train_size = 224 - main_transform = T.Compose( + main_transform = v2.Compose( [ Mixup(mixup_alpha=0.2, cutmix_alpha=1.0), rand_augment_transform("rand-m9-n2-mstd0.5", {}), @@ -165,21 +170,23 @@ def __init__( else: raise ValueError("The procedure is unknown") - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Resize(256, interpolation=self.interpolation), - T.CenterCrop(224), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.Resize(256, interpolation=self.interpolation), + v2.CenterCrop(224), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index f6879c1a..a1d72de4 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -1,10 +1,11 @@ from pathlib import Path from typing import Literal -import torchvision.transforms as T +import torch from 
torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST, FashionMNIST +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import MNISTC, NotMNIST @@ -29,6 +30,7 @@ def __init__( eval_shift: bool = False, ood_ds: Literal["fashion", "notMNIST"] = "fashion", val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", num_workers: int = 1, basic_augment: bool = True, cutout: int | None = None, @@ -49,6 +51,8 @@ def __init__( notMNIST. val_split (float): Share of samples to use for validation. Defaults to ``0.0``. + postprocess_set (str, optional): The post-hoc calibration dataset to + use for the post-processing method. Defaults to ``val``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. basic_augment (bool): Whether to apply base augmentations. Defaults to @@ -62,6 +66,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -82,32 +87,35 @@ def __init__( self.shift_dataset = MNISTC self.shift_severity = 1 - basic_transform = T.RandomCrop(28, padding=4) if basic_augment else nn.Identity() + basic_transform = v2.RandomCrop(28, padding=4) if basic_augment else nn.Identity() main_transform = Cutout(cutout) if cutout else nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.CenterCrop(28), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), + 
v2.Normalize(mean=self.mean, std=self.std), ] ) if self.eval_ood: # NotMNIST has 3 channels - self.ood_transform = T.Compose( + self.ood_transform = v2.Compose( [ - T.ToTensor(), - T.Grayscale(num_output_channels=1), - T.CenterCrop(28), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.Grayscale(num_output_channels=1), + v2.CenterCrop(28), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index bf95159e..f84b972c 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -2,12 +2,13 @@ from typing import Literal import numpy as np -import torchvision.transforms as T +import torch from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN +from torchvision.transforms import v2 from torch_uncertainty.datamodules import TUDataModule from torch_uncertainty.datasets.classification import ( @@ -36,6 +37,7 @@ def __init__( eval_shift: bool = False, shift_severity: int = 1, val_split: float | None = None, + postprocess_set: Literal["val", "test"] = "val", ood_ds: str = "svhn", interpolation: str = "bilinear", basic_augment: bool = True, @@ -48,6 +50,7 @@ def __init__( root=root, batch_size=batch_size, val_split=val_split, + postprocess_set=postprocess_set, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=persistent_workers, @@ -71,10 +74,10 @@ def __init__( raise ValueError(f"OOD dataset {ood_ds} not supported for TinyImageNet.") self.shift_dataset = TinyImageNetC if basic_augment: - basic_transform = T.Compose( + basic_transform = v2.Compose( [ - T.RandomCrop(64, padding=4), - T.RandomHorizontalFlip(), + 
v2.RandomCrop(64, padding=4), + v2.RandomHorizontalFlip(), ] ) else: @@ -85,20 +88,22 @@ def __init__( else: main_transform = nn.Identity() - self.train_transform = T.Compose( + self.train_transform = v2.Compose( [ - T.ToTensor(), + v2.ToImage(), basic_transform, main_transform, - T.Normalize(mean=self.mean, std=self.std), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) - self.test_transform = T.Compose( + self.test_transform = v2.Compose( [ - T.ToTensor(), - T.Resize(64, interpolation=self.interpolation), - T.Normalize(mean=self.mean, std=self.std), + v2.ToImage(), + v2.Resize(64, interpolation=self.interpolation), + v2.ToDtype(dtype=torch.float32, scale=True), + v2.Normalize(mean=self.mean, std=self.std), ] ) diff --git a/torch_uncertainty/datamodules/segmentation/camvid.py b/torch_uncertainty/datamodules/segmentation/camvid.py index 08fd432a..66d774b8 100644 --- a/torch_uncertainty/datamodules/segmentation/camvid.py +++ b/torch_uncertainty/datamodules/segmentation/camvid.py @@ -125,7 +125,6 @@ def __init__( self.train_transform = v2.Compose( [ - v2.ToImage(), basic_transform, v2.ToDtype( dtype={ @@ -140,7 +139,6 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.ToImage(), v2.Resize(size=self.eval_size, antialias=True), v2.ToDtype( dtype={ diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 38bacf70..7f97c932 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -55,6 +55,7 @@ def __init__( ) super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) + self.root = Path(root) training_idx = self._load_train_idx() self.attributes, self.uncertainties = self._load_attributes() diff --git a/torch_uncertainty/datasets/corrupted.py b/torch_uncertainty/datasets/corrupted.py index dbdff04a..8c8a4197 100644 --- a/torch_uncertainty/datasets/corrupted.py +++ 
b/torch_uncertainty/datasets/corrupted.py @@ -1,11 +1,12 @@ +import re from copy import deepcopy from pathlib import Path -from PIL import Image from torch import nn from torchvision.datasets import VisionDataset +from torchvision.datasets.folder import default_loader from torchvision.transforms import ToPILImage, ToTensor -from tqdm.auto import tqdm +from tqdm import tqdm, trange from tqdm.contrib.logging import logging_redirect_tqdm from torch_uncertainty.transforms.corruption import corruption_transforms @@ -16,52 +17,101 @@ def __init__( self, core_dataset: VisionDataset, shift_severity: int, + generate: bool = False, on_the_fly: bool = False, ) -> None: + """Generate the corrupted version of any VisionDataset. + + Args: + core_dataset (VisionDataset): dataset to be corrupted. + shift_severity (int): intensity of the corruption. Should be in [1, 5]. + generate (bool): Equivalent of the download attributes of the dataset. If ``True``, + generate a new dataset with all the corrupted images. Defaults to ``False``. + on_the_fly (bool): Generate the corrupted version of the dataset on the fly, without + saving the images on disk. This is discouraged since the experiment won't be fully + reproducible. Defaults to ``False``. + + Note: + The corrupted dataset will use `transforms` of :attr:`core_dataset`. + """ super().__init__() self.core_dataset = core_dataset if shift_severity <= 0: - raise ValueError(f"Severity must be greater than 0. Got {shift_severity}.") + raise ValueError(f"Severity must be strictly greater than 0. 
Got {shift_severity}.") + if not generate and on_the_fly: + raise ValueError("generate must be True if on_the_fly is True.") + self.shift_severity = shift_severity self.core_length = len(core_dataset) + self.generate = generate self.on_the_fly = on_the_fly + self.transforms = deepcopy(core_dataset.transforms) - self.target_transforms = deepcopy(core_dataset.target_transform) self.core_dataset.transform = None + self.core_dataset.transforms = None self.core_dataset.target_transform = None - self.root = Path(core_dataset.root) - dataset_name = str(type(core_dataset)).split(".")[-1][:-2].lower() - self.root /= dataset_name + "_corrupted" - self.root /= f"severity_{self.shift_severity}" - self.root.mkdir(parents=True) - - if not on_the_fly: + dataset_name = str(type(core_dataset)).split(".")[-1][:-2] + self.root = Path(core_dataset.root) / (dataset_name + "-C") + + if hasattr(self.core_dataset, "targets"): + self.targets = self.core_dataset.targets + elif hasattr(self.core_dataset, "labels"): + self.targets = self.core_dataset.labels + elif hasattr(self.core_dataset, "_labels"): + self.targets = self.core_dataset._labels + else: + raise ValueError("The dataset should implement either targets, labels, or _labels.") + + self.targets = self.targets * len(corruption_transforms) + + if not generate: + paths = sorted(self.root.glob(f"**/{self.shift_severity}/*.jpg"), key=lambda x: x.stem) + self.samples = list(zip(paths, self.targets, strict=False)) + if len(paths) != 15 * self.core_length: + raise ValueError( + "The corrupted dataset is not complete. Download it from HuggingFace or set generate=True." 
+ ) + + if generate and not on_the_fly: + self.root.mkdir(parents=True, exist_ok=True) self.to_tensor = ToTensor() self.to_pil = ToPILImage() self.samples = [] - self.targets = self.core_dataset.targets * 10 - self.prepare_data() - def prepare_data(self): + self._generate_data() + + def _generate_data(self): + """Generate the corrupted data.""" with logging_redirect_tqdm(): - for corruption in tqdm(corruption_transforms): - corruption_name = corruption.__name__.lower() - (self.root / corruption_name).mkdir(parents=True) - self.save_corruption(self.root / corruption_name, corruption(self.shift_severity)) + pbar = tqdm(corruption_transforms) + for corruption in pbar: + corruption_name = re.sub(r"([a-z])([A-Z])", r"\1_\2", corruption.__name__).lower() + pbar.set_description(f"Processing {corruption_name}") + (self.root / corruption_name / f"{self.shift_severity}").mkdir( + parents=True, exist_ok=True + ) + self._save_corruption( + self.root / corruption_name / f"{self.shift_severity}", + corruption(self.shift_severity), + ) + + def _save_corruption(self, root: Path, corruption: nn.Module) -> None: + """Save all images with the given corruption on the disk. - def save_corruption(self, root: Path, corruption: nn.Module) -> None: - for i in range(self.core_length): + Args: + root (Path): The path where to save the images. + corruption (nn.Module): The corruption module to apply on the images. 
+ """ + for i in trange(self.core_length, leave=False): img, tgt = self.core_dataset[i] - if isinstance(img, str | Path): - img = Image.open(img).convert("RGB") img = corruption(self.to_tensor(img)) - self.to_pil(img).save(root / f"{i}.png") - self.samples.append(root / f"{i}.png") + self.to_pil(img).save(root / f"{i}.jpg") + self.samples.append((root / f"{i}.jpg", tgt)) self.targets.append(tgt) def __len__(self): - """The length of the corrupted dataset.""" + """Get the length of the corrupted dataset.""" return len(self.core_dataset) * len(corruption_transforms) def __getitem__(self, idx: int): @@ -76,26 +126,13 @@ def __getitem__(self, idx: int): img, target = self.core_dataset[idx] img = corrupt(img) - if self.transform is not None: - img = self.transform(img) + if self.transforms is not None: + img, target = self.transforms(img, target) - if self.target_transform is not None: - target = self.target_transform(target) return img, target - img, target = self.core_dataset[idx] - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - + path, target = self.samples[idx] + img = default_loader(path) + if self.transforms is not None: + img, target = self.transforms(img, target) return img, target - - -if __name__ == "__main__": - from torchvision.datasets import CIFAR10 - - dataset = CIFAR10(root="data", download=True) - corrupted_dataset = CorruptedDataset(dataset, shift_severity=1) - print(len(corrupted_dataset)) diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index cc93c3d9..162be185 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -244,7 +244,7 @@ def _make_dataset(self, path: Path) -> None: f"target_type must be one of ['semantic', 'depth']. Got {self.target_type}." 
) - def _download(self, split: str) -> None: + def _download(self, split: str) -> None: # coverage: ignore """Download and extract the chosen split of the dataset.""" if self.version == "small": filename = f"{split}.zip" diff --git a/torch_uncertainty/post_processing/abnn.py b/torch_uncertainty/post_processing/abnn.py index 0ec24375..b79e7f36 100644 --- a/torch_uncertainty/post_processing/abnn.py +++ b/torch_uncertainty/post_processing/abnn.py @@ -2,7 +2,7 @@ import torch from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from torch_uncertainty.layers.bayesian.abnn import BatchNormAdapter2d from torch_uncertainty.models import deep_ensembles @@ -25,7 +25,6 @@ def __init__( device: torch.device | str, max_epochs: int = 5, use_original_model: bool = True, - batch_size: int = 128, precision: str = "32", model: nn.Module | None = None, ): @@ -45,8 +44,6 @@ def __init__( to 5. use_original_model (bool, optional): Use original model during evaluation. Defaults to True. - batch_size (int, optional): Batch size for the training of ABNN. - Defaults to 128. precision (str, optional): Machine precision for training & eval. Defaults to "32". model (nn.Module | None, optional): Model to use. Defaults to None. 
@@ -63,7 +60,6 @@ def __init__( num_models=num_models, num_samples=num_samples, base_lr=base_lr, - batch_size=batch_size, ) self.num_classes = num_classes self.alpha = alpha @@ -74,7 +70,6 @@ def __init__( self.use_original_model = use_original_model self.max_epochs = max_epochs - self.batch_size = batch_size self.precision = precision self.device = device @@ -88,10 +83,9 @@ def __init__( weight[torch.randperm(num_classes)[:num_rp_classes]] += random_prior - 1 self.weights.append(weight) - def fit(self, dataset: Dataset) -> None: + def fit(self, dataloader: DataLoader) -> None: if self.model is None: raise ValueError("Model must be set before fitting.") - dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) source_model = copy.deepcopy(self.model) _replace_bn_layers(source_model, self.alpha) @@ -119,7 +113,7 @@ def fit(self, dataset: Dataset) -> None: logger=None, enable_model_summary=False, ) - trainer.fit(model=baseline, train_dataloaders=dl) + trainer.fit(model=baseline, train_dataloaders=dataloader) final_models = ( [copy.deepcopy(source_model) for _ in range(self.num_samples)] diff --git a/torch_uncertainty/post_processing/abstract.py b/torch_uncertainty/post_processing/abstract.py index 9c7908cc..7afe050c 100644 --- a/torch_uncertainty/post_processing/abstract.py +++ b/torch_uncertainty/post_processing/abstract.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from torch import Tensor, nn -from torch.utils.data import Dataset +from torch.utils.data import DataLoader class PostProcessing(ABC, nn.Module): @@ -14,12 +14,12 @@ def set_model(self, model: nn.Module) -> None: self.model = model @abstractmethod - def fit(self, dataset: Dataset) -> None: + def fit(self, dataloader: DataLoader) -> None: pass @abstractmethod def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: pass diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index a5398df2..3b1cedec 100644 --- 
a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn, optim -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from tqdm import tqdm from torch_uncertainty.post_processing import PostProcessing @@ -47,14 +47,14 @@ def __init__( def fit( self, - calibration_set: Dataset, + dataloader: DataLoader, save_logits: bool = False, progress: bool = True, ) -> None: """Fit the temperature parameters to the calibration data. Args: - calibration_set (Dataset): Calibration dataset. + dataloader (DataLoader): Dataloader with the calibration data. save_logits (bool, optional): Whether to save the logits and labels. Defaults to False. progress (bool, optional): Whether to show a progress bar. @@ -65,16 +65,16 @@ def fit( "Cannot fit a Scaler method without model. Call .set_model(model) first." ) - logits_list = [] - labels_list = [] - calibration_dl = DataLoader(calibration_set, batch_size=32, shuffle=False, drop_last=False) + all_logits = [] + all_labels = [] + calibration_dl = dataloader with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): logits = self.model(inputs.to(self.device)) - logits_list.append(logits) - labels_list.append(labels) - all_logits = torch.cat(logits_list).detach().to(self.device) - all_labels = torch.cat(labels_list).detach().to(self.device) + all_logits.append(logits) + all_labels.append(labels) + all_logits = torch.cat(all_logits).to(self.device) + all_labels = torch.cat(all_labels).to(self.device) optimizer = optim.LBFGS(self.temperature, lr=self.lr, max_iter=self.max_iter) @@ -87,12 +87,12 @@ def calib_eval() -> float: optimizer.step(calib_eval) self.trained = True if save_logits: - self.logits = logits - self.labels = labels + self.logits = all_logits + self.labels = all_labels @torch.no_grad() def forward(self, inputs: Tensor) -> Tensor: - if not 
self.trained: + if self.model is None or not self.trained: logging.error( "TemperatureScaler has not been trained yet. Returning manually tempered inputs." ) @@ -111,10 +111,10 @@ def _scale(self, logits: Tensor) -> Tensor: def fit_predict( self, - calibration_set: Dataset, + dataloader: DataLoader, progress: bool = True, ) -> Tensor: - self.fit(calibration_set, save_logits=True, progress=progress) + self.fit(dataloader, save_logits=True, progress=progress) return self(self.logits) @property diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index 582a4bd8..eda08d87 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -2,7 +2,7 @@ from typing import Literal from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from .abstract import PostProcessing @@ -23,7 +23,6 @@ def __init__( hessian_struct="kron", pred_type: Literal["glm", "nn"] = "glm", link_approx: Literal["mc", "probit", "bridge", "bridge_norm"] = "probit", - batch_size: int = 256, optimize_prior_precision: bool = True, ) -> None: """Laplace approximation for uncertainty estimation. @@ -42,8 +41,6 @@ def __init__( link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional): how to approximate the classification link function for the `'glm'`. See the Laplace library for more details. Defaults to "probit". - batch_size (int, optional): batch size for the Laplace approximation. - Defaults to 256. optimize_prior_precision (bool, optional): whether to optimize the prior precision. Defaults to True. @@ -51,7 +48,7 @@ def __init__( Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. """ super().__init__() - if not laplace_installed: # coverage: ignore + if not laplace_installed: raise ImportError( "The laplace-torch library is not installed. 
Please install" "torch_uncertainty with the all option:" @@ -63,7 +60,6 @@ def __init__( self.task = task self.weight_subset = weight_subset self.hessian_struct = hessian_struct - self.batch_size = batch_size self.optimize_prior_precision = optimize_prior_precision if model is not None: @@ -78,14 +74,13 @@ def set_model(self, model: nn.Module) -> None: hessian_structure=self.hessian_struct, ) - def fit(self, dataset: Dataset) -> None: - dl = DataLoader(dataset, batch_size=self.batch_size) - self.la.fit(train_loader=dl) + def fit(self, dataloader: DataLoader) -> None: + self.la.fit(train_loader=dataloader) if self.optimize_prior_precision: self.la.optimize_prior_precision(method="marglik") def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: - return self.la(x, pred_type=self.pred_type, link_approx=self.link_approx) + return self.la(inputs, pred_type=self.pred_type, link_approx=self.link_approx) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index 0a1d250d..9fb8dae0 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.post_processing import PostProcessing @@ -27,10 +27,11 @@ def __init__( Args: model (nn.Module): model to be converted. num_estimators (int): number of estimators. - convert (bool): whether to convert the model. - mc_batch_size (int, optional): Monte Carlo batch size. Defaults to 32. + convert (bool): whether to convert the model. Defaults to ``True``. + mc_batch_size (int, optional): Monte Carlo batch size. The smaller the more variability + in the predictions. Defaults to 32. device (Literal["cpu", "cuda"] | torch.device | None, optional): device. - Defaults to None. 
+ Defaults to ``None``. Note: This wrapper will be stochastic in eval mode only. @@ -40,9 +41,9 @@ def __init__( batch normalized deep networks. In ICML 2018. """ super().__init__() - self.mc_batch_size = mc_batch_size - self.convert = convert self.num_estimators = num_estimators + self.convert = convert + self.mc_batch_size = mc_batch_size self.device = device if model is not None: @@ -50,7 +51,7 @@ def __init__( def _setup_model(self, model): _mcbn_checks(model, self.num_estimators, self.mc_batch_size, self.convert) - self.model = deepcopy(model) # Is it necessary? + self.model = deepcopy(model) # TODO: Is it necessary? self.model = self.model.eval() if self.convert: self._convert() @@ -61,22 +62,29 @@ def set_model(self, model: nn.Module) -> None: self.model = model self._setup_model(model) - def fit(self, dataset: Dataset) -> None: + def fit(self, dataloader: DataLoader) -> None: """Fit the model on the dataset. Args: - dataset (Dataset): dataset to be used for fitting. + dataloader (DataLoader): DataLoader with the post-processing dataset. Note: This method is used to populate the MC BatchNorm layers. - Use the training dataset. + Use the post-processing dataset. + + Warning: + The ``batch_size`` of the DataLoader should be carefully chosen as it + will have an impact on the statistics of the MC BatchNorm layers. + + Raises: + ValueError: If there are less batches than the number of estimators. 
""" - self.dl = DataLoader(dataset, batch_size=self.mc_batch_size, shuffle=True) + dataloader = init_dataloader(dataloader, batch_size=self.mc_batch_size) self.counter = 0 self.reset_counters() self.set_accumulate(True) self.eval() - for x, _ in self.dl: + for x, _ in dataloader: self.model(x.to(self.device)) self.raise_counters() if self.counter == self.num_estimators: @@ -93,14 +101,14 @@ def _est_forward(self, x: Tensor) -> Tensor: def forward( self, - x: Tensor, + inputs: Tensor, ) -> Tensor: if self.training: - return self.model(x) + return self.model(inputs) if not self.trained: - raise RuntimeError("MCBatchNorm has not been trained. Call .fit() first.") + raise RuntimeError("MCBatchNorm has not been fit. Call .fit() first.") self.reset_counters() - return torch.cat([self._est_forward(x) for _ in range(self.num_estimators)], dim=0) + return torch.cat([self._est_forward(inputs) for _ in range(self.num_estimators)], dim=0) def _convert(self) -> None: """Convert all BatchNorm2d layers to MCBatchNorm2d layers.""" @@ -162,6 +170,31 @@ def has_mcbn(model: nn.Module) -> bool: return any(isinstance(module, MCBatchNorm2d) for module in model.modules()) +def init_dataloader(dataloader: DataLoader, batch_size: int): + """Reinitialize dataloader with the chosen batch size. + + It is impossible to change the ``batch_size`` of an already-instantiated dataloader. + + Args: + dataloader (DataLoader): the dataloader to be reinitialized with + batch_size (int): the given batch_size. 
+ """ + return DataLoader( + dataloader.dataset, + batch_size=batch_size, + sampler=dataloader.sampler, + num_workers=dataloader.num_workers, + pin_memory=dataloader.pin_memory, + drop_last=dataloader.drop_last, + timeout=dataloader.timeout, + worker_init_fn=dataloader.worker_init_fn, + multiprocessing_context=dataloader.multiprocessing_context, + generator=dataloader.generator, + prefetch_factor=dataloader.prefetch_factor, + persistent_workers=dataloader.persistent_workers, + ) + + def _mcbn_checks(model, num_estimators, mc_batch_size, convert): if num_estimators < 1 or not isinstance(num_estimators, int): raise ValueError(f"num_estimators must be a positive integer, got {num_estimators}.") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index e5469b5c..52b8dc46 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -74,8 +74,7 @@ def __init__( eval_grouping_loss: bool = False, ood_criterion: Literal["msp", "logit", "energy", "entropy", "mi", "vr"] = "msp", post_processing: PostProcessing | None = None, - calibration_set: Literal["val", "test"] = "val", - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, log_plots: bool = False, save_in_csv: bool = False, ) -> None: @@ -111,10 +110,8 @@ def __init__( post_processing (PostProcessing, optional): Post-processing method to train on the calibration set. No post-processing if None. Defaults to ``None``. - calibration_set (str, optional): The post-hoc calibration dataset to - use for the post-processing method. Defaults to ``val``. - num_calibration_bins (int, optional): Number of bins to compute calibration - metrics. Defaults to ``15``. + num_bins_cal_err (int, optional): Number of bins to compute calibration + error metrics. Defaults to ``15``. log_plots (bool, optional): Indicates whether to log plots from metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. 
Defaults to @@ -150,7 +147,7 @@ def __init__( is_ensemble=is_ensemble, ood_criterion=ood_criterion, eval_grouping_loss=eval_grouping_loss, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, mixup_params=mixup_params, post_processing=post_processing, format_batch_fn=format_batch_fn, @@ -166,11 +163,10 @@ def __init__( self.ood_criterion = ood_criterion self.log_plots = log_plots self.save_in_csv = save_in_csv - self.calibration_set = calibration_set self.binary_cls = num_classes == 1 self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) - self.num_calibration_bins = num_calibration_bins + self.num_bins_cal_err = num_bins_cal_err self.model = model self.loss = loss self.format_batch_fn = format_batch_fn @@ -202,13 +198,13 @@ def _init_metrics(self) -> None: "cls/NLL": CategoricalNLL(), "cal/ECE": CalibrationError( task=task, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), "cal/aECE": CalibrationError( task=task, adaptive=True, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, num_classes=self.num_classes, ), "sc/AURC": AURC(), @@ -381,13 +377,8 @@ def on_test_start(self) -> None: the storage lists for logit plotting and update the batchnorms if needed. 
""" if self.post_processing is not None: - calibration_dataset = ( - self.trainer.datamodule.val_dataloader().dataset - if self.calibration_set == "val" - else self.trainer.datamodule.test_dataloader()[0].dataset - ) with torch.inference_mode(False): - self.post_processing.fit(calibration_dataset) + self.post_processing.fit(self.trainer.datamodule.postprocess_dataloader()) if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): self.id_logit_storage = [] @@ -699,7 +690,7 @@ def _classification_routine_checks( is_ensemble: bool, ood_criterion: str, eval_grouping_loss: bool, - num_calibration_bins: int, + num_bins_cal_err: int, mixup_params: dict | None, post_processing: PostProcessing | None, format_batch_fn: nn.Module | None, @@ -712,7 +703,7 @@ def _classification_routine_checks( is_ensemble (bool): whether the model is an ensemble or a single model. ood_criterion (str): the criterion for the binary OOD detection task. eval_grouping_loss (bool): whether to evaluate the grouping loss. - num_calibration_bins (int): the number of bins for the evaluation of the calibration. + num_bins_cal_err (int): the number of bins for the evaluation of the calibration. mixup_params (dict | None): the dictionary to setup the mixup augmentation. post_processing (PostProcessing | None): the post-processing module. format_batch_fn (nn.Module | None): the function for formatting the batch for ensembles. @@ -758,8 +749,8 @@ def _classification_routine_checks( "attribute to compute the grouping loss." 
) - if num_calibration_bins < 2: - raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.") + if num_bins_cal_err < 2: + raise ValueError(f"num_bins_cal_err must be at least 2, got {num_bins_cal_err}.") if mixup_params is not None and isinstance(format_batch_fn, RepeatTarget): raise ValueError( diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index fac59ecb..662d5581 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -37,7 +37,7 @@ def __init__( metric_subsampling_rate: float = 1e-2, log_plots: bool = False, num_samples_to_plot: int = 3, - num_calibration_bins: int = 15, + num_bins_cal_err: int = 15, ) -> None: r"""Routine for training & testing on **segmentation** tasks. @@ -57,8 +57,8 @@ def __init__( metrics. Defaults to ``False``. num_samples_to_plot (int, optional): Number of samples to plot in the segmentation results. Defaults to ``3``. - num_calibration_bins (int, optional): Number of bins to compute calibration - metrics. Defaults to ``15``. + num_bins_cal_err (int, optional): Number of bins to compute calibration + error metrics. Defaults to ``15``. Warning: You must define :attr:`optim_recipe` if you do not use the CLI. 
@@ -72,7 +72,7 @@ def __init__( _segmentation_routine_checks( num_classes=num_classes, metric_subsampling_rate=metric_subsampling_rate, - num_calibration_bins=num_calibration_bins, + num_bins_cal_err=num_bins_cal_err, ) if eval_shift: raise NotImplementedError( @@ -81,7 +81,7 @@ def __init__( self.model = model self.num_classes = num_classes - self.num_calibration_bins = num_calibration_bins + self.num_bins_cal_err = num_bins_cal_err self.loss = loss self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) @@ -118,13 +118,13 @@ def _init_metrics(self) -> None: "cal/ECE": CalibrationError( task="multiclass", num_classes=self.num_classes, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, ), "cal/aECE": CalibrationError( task="multiclass", adaptive=True, num_classes=self.num_classes, - num_bins=self.num_calibration_bins, + num_bins=self.num_bins_cal_err, ), "sc/AURC": AURC(), "sc/AUGRC": AUGRC(), @@ -283,7 +283,10 @@ def on_test_epoch_end(self) -> None: "Selective Classification/Generalized Risk-Coverage curve", self.test_sbsmpl_seg_metrics["sc/AUGRC"].plot()[0], ) - self.log_segmentation_plots() + if self.trainer.datamodule is not None: + self.log_segmentation_plots() + else: + print("No datamodule found, skipping segmentation plots.") def log_segmentation_plots(self) -> None: """Build and log examples of segmentation plots from the test set.""" @@ -327,14 +330,14 @@ def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: def _segmentation_routine_checks( num_classes: int, metric_subsampling_rate: float, - num_calibration_bins: int, + num_bins_cal_err: int, ) -> None: """Check the domains of the routine's parameters. Args: num_classes (int): the number of classes in the dataset. metric_subsampling_rate (float): the rate of subsampling to compute the metrics. - num_calibration_bins (int): the number of bins for the evaluation of the calibration. 
+ num_bins_cal_err (int): the number of bins for the evaluation of the calibration. """ if num_classes < 2: raise ValueError(f"num_classes must be at least 2, got {num_classes}.") @@ -344,5 +347,5 @@ def _segmentation_routine_checks( f"metric_subsampling_rate must be in the range (0, 1], got {metric_subsampling_rate}." ) - if num_calibration_bins < 2: - raise ValueError(f"num_calibration_bins must be at least 2, got {num_calibration_bins}.") + if num_bins_cal_err < 2: + raise ValueError(f"num_bins_cal_err must be at least 2, got {num_bins_cal_err}.") diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 0b28c73a..ce1273b5 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -1,8 +1,29 @@ -"""Adapted from https://github.com/hendrycks/robustness.""" +"""These corruptive transformations are mostly PyTorch portings of the originals provided by +Dan Hendrycks and Thomas Dietterich in "Benchmarking neural network robustness to common +corruptions and perturbations" published at ICLR 2019 through their GitHub repository +https://github.com/hendrycks/robustness. + +However, please note that these transforms have been rewritten with more modern tools to improve +their efficiency as well as reduce the number of dependencies. As a result, some parameters had +to be modified to remain as close as possible to the original transforms. + +The authors of the library advise avoiding using the stochastic transforms to generate your dataset +to avoid reproducibility issues. It may be preferable to first check if the corrupted dataset is +available on TorchUncertainty's Hugging Face https://huggingface.co/torch-uncertainty. File an +issue if you would like one specific and missing dataset to be published on this page. 
+ +In most of the cases, we have chosen to follow the hyperparameters used for ImageNet-C, which +differ from those of TinyImageNet-C, CIFAR-C or even the Inception version of ImageNet-C. However, +this may not be entirely suitable in the case of datasets with much smaller or bigger images. +""" from importlib import util from io import BytesIO +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.distributions import Categorical + if util.find_spec("cv2"): import cv2 @@ -10,18 +31,13 @@ else: # coverage: ignore cv2_installed = False +import math as m + import numpy as np import torch +from kornia.augmentation import RandomSaltAndPepperNoise from PIL import Image -if util.find_spec("skimage"): - from skimage.filters import gaussian - from skimage.util import random_noise - - skimage_installed = True -else: # coverage: ignore - skimage_installed = False - if util.find_spec("scipy"): from scipy.ndimage import map_coordinates from scipy.ndimage import zoom as scizoom @@ -40,7 +56,8 @@ ) if util.find_spec("kornia"): - from kornia.filters import motion_blur + from kornia.color import rgb_to_grayscale + from kornia.filters import filter2d, gaussian_blur2d, motion_blur kornia_installed = True else: # coverage: ignore @@ -76,8 +93,10 @@ class TUCorruption(nn.Module): + batched: bool = False + def __init__(self, severity: int) -> None: - """Base class for corruptions.""" + """Base class for corruption transforms.""" super().__init__() if not (0 <= severity <= 5): raise ValueError("Severity must be between 0 and 5.") @@ -91,146 +110,341 @@ def __repr__(self) -> str: class GaussianNoise(TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: - """Add Gaussian noise to an image. + """Apply a Gaussian noise corruption to tensor images. Args: severity (int): Severity level of the corruption. 
""" super().__init__(severity) - self.scale = [0, 0.04, 0.06, 0.08, 0.09, 0.10][severity] + self.scale = [0.08, 0.12, 0.18, 0.26, 0.38][severity - 1] def forward(self, img: Tensor) -> Tensor: + """Apply Gaussian noise on an input image. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W) + """ if self.severity == 0: return img return torch.clamp(torch.normal(img, self.scale), 0, 1) class ShotNoise(TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: - """Add shot noise to an image. + """Apply a shot (Poisson) noise corruption to tensor images. Args: severity (int): Severity level of the corruption. """ super().__init__(severity) - self.scale = [500, 250, 100, 75, 50][severity - 1] + self.scale = [60, 25, 12, 5, 3][severity - 1] def forward(self, img: Tensor): + """Apply Poisson noise on an input image. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W) + """ if self.severity == 0: return img return torch.clamp(torch.poisson(img * self.scale) / self.scale, 0, 1) class ImpulseNoise(TUCorruption): - def __init__(self, severity: int) -> None: - """Add impulse noise to an image. + batchable = True + + def __init__(self, severity: int, black_white: bool = False) -> None: + """Apply an impulse (channel-independent Salt & Pepper) noise corruption to unbatched + tensor images. Args: severity (int): Severity level of the corruption. + black_white (bool): If black and white, set all pixel channel values to 0 or 1. + Defaults to ``False`` (as in the original paper). 
""" super().__init__(severity) - if not skimage_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.scale = [0, 0.01, 0.02, 0.03, 0.05, 0.07][severity] + self.aug = RandomSaltAndPepperNoise( + amount=[0.03, 0.06, 0.09, 0.17, 0.27][severity - 1], + salt_vs_pepper=0.5, + p=1, + same_on_batch=False, + ) + self.black_white = black_white def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clamp( - torch.as_tensor(random_noise(img, mode="s&p", amount=self.scale)), - torch.zeros(1), - torch.ones(1), + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + channels = img.shape[1] + if not self.black_white: + img = rearrange(img, "b c ... -> (b c) 1 ...") + img = torch.clamp( + input=self.aug(img), + min=torch.zeros(1), + max=torch.ones(1), ) + if not self.black_white: + img = rearrange(img, "(b c) 1 ... -> b c ... ", c=channels) + + if no_batch: + img = img.squeeze(0) + return img.squeeze(0) if self.black_white else img.squeeze(1) + + +def disk(radius: int, alias_blur: float = 0.1, dtype=torch.float32): + """Generate a Gaussian disk of shape (1, radius, radius) for filtering.""" + if radius <= 8: + size = torch.arange(-8, 8 + 1) + ksize = (3, 3) + else: # coverage: ignore + size = torch.arange(-radius, radius + 1) + ksize = (5, 5) + xs, ys = torch.meshgrid(size, size, indexing="xy") + + aliased_disk = ((xs**2 + ys**2) <= radius**2).to(dtype=dtype) + aliased_disk /= aliased_disk.sum() + return gaussian_blur2d( + aliased_disk.unsqueeze(0).unsqueeze(0), kernel_size=ksize, sigma=(alias_blur, alias_blur) + ).squeeze(0) class DefocusBlur(TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: - """Add defocus blur to an image. + """Apply a defocus blur corruption to unbatched tensor images. Args: severity (int): Severity level of the corruption. 
""" super().__init__(severity) - if not cv2_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.radius = [3, 4, 6, 8, 10][severity - 1] - self.alias_blur = [0.1, 0.5, 0.5, 0.5, 0.5][severity - 1] + radius = [3, 4, 6, 8, 10][severity - 1] + alias_blur = [0.1, 0.5, 0.5, 0.5, 0.5][severity - 1] + self.disk = disk(radius, alias_blur=alias_blur) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = img.numpy() - channels = [ - torch.as_tensor( - cv2.filter2D( - img[ch, :, :], - -1, - disk(self.radius, alias_blur=self.alias_blur), - ) + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = torch.clamp(filter2d(img, kernel=self.disk), 0, 1) + if no_batch: + out = out.squeeze(0) + return out + + +def generate_offset_distribution(max_delta, iterations): + """Symmetrized version of the glass blur swapping algorithm. + + The original implementation is sequential and extremely long on large images. This version should + be statistically equivalent. The sketch of proof will be provided in TorchUncertainty's paper. 
+ """ + interval_length = 2 * max_delta + 1 + diagram_size = 12 * max_delta # sufficient for a proper density estimation + tab = torch.zeros((diagram_size, diagram_size), dtype=torch.float32) + tab[0, max_delta] = 1 + for pivot, t in enumerate(range(1, diagram_size)): + # the pivot gets 1/interval_length of all the accessible previous densities + for i in range(-max_delta, max_delta + 1): + if 0 <= pivot + i < diagram_size: + tab[t, pivot] += tab[t - 1, pivot + i] + + # the other values keep (interval_length - 1) / interval_length of their previous densities + # and 1/interval_length the value of the pivot + for i in range(-max_delta, max_delta + 1): + if i != 0 and 0 <= pivot + i < diagram_size: + tab[t, pivot + i] += (interval_length - 1) * tab[t - 1, pivot + i] + tab[ + t - 1, pivot + ] + tab[t, :] /= interval_length + density = torch.diag(tab, -max_delta - 1) + + # reducing distribution dimension + idx = torch.clamp(density, 1e-4).argmin() + density = density[:idx] + + padded_density = F.pad(density, (len(density) - 2 * max_delta - 1, 0)) + sym_density = 1 / 2 * padded_density + 1 / 2 * padded_density.flip(-1) + + # Convolve the density in lieu of iterating + sym_density = sym_density.unsqueeze(0).unsqueeze(0) + sym_density_iter = sym_density.clone() + for _ in range(iterations - 1): + sym_density_iter = F.conv1d( + sym_density_iter, torch.flip(sym_density, (-1,)), padding=sym_density.shape[-1] // 2 + ) + return Categorical(probs=sym_density_iter.squeeze(0, 1)) + + +class GlassBlur(TUCorruption): + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a glass blur corruption to unbatched tensor images. + + Faster implementation using a symmetrized offset distribution. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. 
+ + Note: + The hyperparameters have been adapted to output images qualitatively calibrated with + the original implementation despite the changes in implementation that increase the + power of the transformation. This is notably due to discarding the correlation between + the offsets to simplify the derivation. + """ + super().__init__(severity) + if not kornia_installed: + raise ImportError( + "Please install torch_uncertainty with the image option:" + """pip install -U "torch_uncertainty[image]".""" ) - for ch in range(3) - ] - return torch.clamp(torch.stack(channels), 0, 1) + sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = int(sigma * 6 // 2 * 2 + 1) + iterations = [1, 2, 3, 2, 3][severity - 1] + max_delta = [1, 1, 1, 2, 3][severity - 1] + self.max_delta = max_delta + self.offset_dist = generate_offset_distribution(max_delta, iterations) -class GlassBlur(TUCorruption): # TODO: batch - def __init__(self, severity: int) -> None: + def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img + + img = gaussian_blur2d( + img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma + ).squeeze(0) + + img = rearrange(img, "c h w -> h w c") + height, width, _ = img.shape + max_d = self.max_delta + + valid_h = height - max_d + valid_w = width - max_d + + # Generate random offsets + rand_offsets = ( + self.offset_dist.sample(sample_shape=(valid_h, valid_w, 2)) + - self.offset_dist.param_shape[0] // 2 + ) + + # Create base indices + hs = repeat(torch.arange(max_d, height, device=img.device)[:valid_h], "h -> h w", w=valid_w) + ws = repeat(torch.arange(max_d, width, device=img.device)[:valid_w], "w -> h w", h=valid_h) + + dy = rand_offsets[..., 0] + dx = rand_offsets[..., 1] + hs_prime = (hs + dy).clamp(0, height - 1) + ws_prime = (ws + dx).clamp(0, width - 1) + + flat_idx = hs.flatten(), ws.flatten() + flat_idx_prime = hs_prime.flatten(), ws_prime.flatten() + + tmp = img[flat_idx].clone() + 
img[flat_idx] = img[flat_idx_prime] + img[flat_idx_prime] = tmp + + img = rearrange(img, "h w c -> 1 c h w") # Back to BCHW + img = gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=self.sigma).squeeze(0) + return torch.clamp(img, 0, 1) + + +class OriginalGlassBlur(TUCorruption): + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a glass blur corruption to unbatched tensor images. + + Original, likely incorrect and very slow implementation. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. + """ super().__init__(severity) - if not skimage_installed or not cv2_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.sigma = [0.05, 0.25, 0.4, 0.25, 0.4][severity - 1] - self.max_delta = 1 - self.iterations = [1, 1, 1, 2, 2][severity - 1] + sigma = [0.7, 0.9, 1, 1.1, 1.5][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = int(sigma * 4 // 2 * 2 + 1) + self.iterations = [2, 1, 3, 2, 2][severity - 1] + self.max_delta = [1, 2, 2, 3, 4][severity - 1] + + if seed is None: + self.rng = None + else: + self.rng = torch.Generator(device="cpu").manual_seed(seed) def forward(self, img: Tensor) -> Tensor: + if self.severity == 0: + return img img_size = img.shape - img = torch.as_tensor(gaussian(img, sigma=self.sigma)) - for _ in range(self.iterations): - for h in range(img_size[0] - self.max_delta, self.max_delta, -1): - for w in range(img_size[1] - self.max_delta, self.max_delta, -1): - dx, dy = torch.randint(-self.max_delta, self.max_delta, size=(2,)) - h_prime, w_prime = h + dy, w + dx - img[h, w], img[h_prime, w_prime] = ( - img[h_prime, w_prime], - img[h, w], - ) - return torch.clamp(torch.as_tensor(gaussian(img, sigma=self.sigma)), 0, 1) + img = rearrange( + gaussian_blur2d(img.unsqueeze(0), kernel_size=self.kernel_size, sigma=self.sigma), + "1 c h w 
-> h w c", + ) + rands = torch.randint( + -self.max_delta, + self.max_delta, + size=(self.iterations, img_size[1] - self.max_delta, img_size[2] - self.max_delta, 2), + generator=self.rng, + ) -def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): - if radius <= 8: - size = np.arange(-8, 8 + 1) - ksize = (3, 3) - else: # coverage: ignore - size = np.arange(-radius, radius + 1) - ksize = (5, 5) - xs, ys = np.meshgrid(size, size) - aliased_disk = np.array((xs**2 + ys**2) <= radius**2, dtype=dtype) - aliased_disk /= np.sum(aliased_disk) - return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) + for iteration in range(self.iterations): + for i, h in enumerate(range(img_size[1] - self.max_delta, self.max_delta, -1)): + for j, w in enumerate(range(img_size[2] - self.max_delta, self.max_delta, -1)): + dx, dy = rands[iteration, i, j, :] + h_prime, w_prime = h + dy, w + dx + img[h, w, :], img[h_prime, w_prime, :] = img[h_prime, w_prime, :], img[h, w, :] + + return torch.clamp( + gaussian_blur2d( + rearrange(img, "h w c -> 1 c h w"), kernel_size=self.kernel_size, sigma=self.sigma + ).squeeze(0), + 0, + 1, + ) class MotionBlur(TUCorruption): - def __init__(self, severity: int) -> None: - """Apply a motion blur corruption on the image. + batchable = True + + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a motion blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. Note: Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency with with `Wand` we changed the transform to a simpler motion blur and kept the values of - sigma as the new half kernel sizes. + sigma as the new kernel radius sizes. 
""" super().__init__(severity) - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) self.radius = [3, 5, 8, 12, 15][severity - 1] if not kornia_installed: @@ -255,24 +469,32 @@ def forward(self, img: Tensor) -> Tensor: def clipped_zoom(img, zoom_factor): - h = img.shape[0] + h, w = img.shape[:2] # ceil crop height(= crop width) - ch = int(np.ceil(h / zoom_factor)) + ceil_crop_height = int(np.ceil(h / zoom_factor)) + left_crop_width = int(np.ceil(w / zoom_factor)) - top = (h - ch) // 2 + top = (h - ceil_crop_height) // 2 + left = (w - left_crop_width) // 2 img = scizoom( - img[top : top + ch, top : top + ch], + img[top : top + ceil_crop_height, left : left + left_crop_width], (zoom_factor, zoom_factor, 1), order=1, ) # trim off any extra pixels trim_top = (img.shape[0] - h) // 2 + trim_left = (img.shape[1] - w) // 2 - return img[trim_top : trim_top + h, trim_top : trim_top + h] + return img[trim_top : trim_top + h, trim_left : trim_left + w] class ZoomBlur(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a zoom blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) self.zooms = [ np.arange(1, 1.11, 0.01), @@ -291,32 +513,36 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - img = img.permute(1, 2, 0).numpy() + img = rearrange(img, "c h w -> h w c").numpy() out = np.zeros_like(img) for zoom_factor in self.zooms: out += clipped_zoom(img, zoom_factor) img = (img + out) / (len(self.zooms) + 1) - return torch.clamp(torch.as_tensor(img).permute(2, 0, 1), 0, 1) + return torch.clamp(rearrange(torch.as_tensor(img), "h w c -> c h w"), 0, 1) class Snow(TUCorruption): - def __init__(self, severity: int) -> None: - """Apply a snow effect on the image. + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a snow effect on unbatched tensor images. 
+ + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. Note: The transformation has been slightly modified, see MotionBlur for details. """ super().__init__(severity) self.mix = [ - (0.1, 0.3, 3, 0.5, 4, 0.8), - (0.2, 0.3, 2, 0.5, 4, 0.7), - (0.55, 0.3, 4, 0.9, 8, 0.7), - (0.55, 0.3, 4.5, 0.85, 8, 0.65), - (0.55, 0.3, 2.5, 0.85, 12, 0.55), + (0.1, 3, 0.5, 4, 0.8), + (0.2, 2, 0.5, 4, 0.7), + (0.55, 4, 0.9, 8, 0.7), + (0.55, 4.5, 0.85, 8, 0.65), + (0.55, 2.5, 0.85, 12, 0.55), ][severity - 1] - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) - if not kornia_installed: + if not kornia_installed or not scipy_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -325,52 +551,51 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - _, height, width = img.shape - x = img.numpy() - snow_layer = self.rng.normal(size=x.shape[1:], loc=self.mix[0], scale=self.mix[1])[ + snow_layer = self.rng.normal(size=img.shape[1:], loc=self.mix[0], scale=0.3)[ ..., np.newaxis ] - snow_layer = clipped_zoom(snow_layer, self.mix[2]) - snow_layer[snow_layer < self.mix[3]] = 0 + snow_layer = clipped_zoom(snow_layer, self.mix[1]) + snow_layer[snow_layer < self.mix[2]] = 0 snow_layer = np.clip(snow_layer.squeeze(), 0, 1) - snow_layer = ( - motion_blur( - torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), - kernel_size=self.mix[4] * 2 + 1, - angle=self.rng.uniform(-135, -45), - direction=0, - ) - .squeeze(0) - .numpy() - ) + snow_layer = motion_blur( + torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), + kernel_size=self.mix[3] * 2 + 1, + angle=self.rng.uniform(-135, -45), + direction=0, + ).squeeze(0) - x = self.mix[5] * x + (1 - self.mix[5]) * np.maximum( - x, - cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape(1, height, width) * 1.5 - + 
0.5, + x = self.mix[4] * img + (1 - self.mix[4]) * torch.maximum( + img, + rgb_to_grayscale(img) * 1.5 + 0.5, ) - return torch.clamp(torch.as_tensor(x + snow_layer + np.rot90(snow_layer, k=2)), 0, 1) + + return torch.clamp(x + snow_layer + snow_layer.flip(dims=(1, 2)), 0, 1) class Frost(TUCorruption): - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a frost corruption effect on unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. + """ super().__init__(severity) - self.rng = np.random.default_rng() - self.mix = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1] + self.rng = np.random.default_rng(seed) + self.mix = [(1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75)][severity - 1] self.frost_ds = FrostImages("./data", download=True, transform=ToTensor()) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - _, height, width = img.shape - frost_img = RandomResizedCrop((height, width))( - self.frost_ds[self.rng.integers(low=0, high=4)] + frost_img = RandomResizedCrop(img.shape[1:])( + self.frost_ds[self.rng.integers(low=0, high=5)] ) return torch.clamp(self.mix[0] * img + self.mix[1] * frost_img, 0, 1) -def plasma_fractal(height, width, wibbledecay=3): +def plasma_fractal(height, width, rng, wibbledecay): """Generate a heightmap using diamond-square algorithm. Return square 2d array, side length 'mapsize', of floats in range 0-1. 'mapsize' must be a power of two. 
@@ -379,7 +604,6 @@ def plasma_fractal(height, width, wibbledecay=3): maparray[0, 0] = 0 stepsize = height wibble = 100 - rng = np.random.default_rng() def wibbledmean(array): return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape) @@ -422,35 +646,50 @@ def filldiamonds(): class Fog(TUCorruption): - def __init__(self, severity: int, size: int = 256) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply a fog corruption effect on unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. + """ super().__init__(severity) - if (size & (size - 1) == 0) and size != 0: - self.size = size - self.resize = Resize((size, size), InterpolationMode.BICUBIC) - else: - raise ValueError(f"Size must be a power of 2. Got {size}.") self.mix = [(1.5, 2), (2, 2), (2.5, 1.7), (2.5, 1.5), (3, 1.4)][severity - 1] + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img _, height, width = img.shape - if height != width: - raise ValueError(f"Image must be square. Got {height}x{width}.") - img = self.resize(img) max_val = img.max() + random_height_map_size = int(2 ** (m.ceil(m.log2(max(height, width))))) fog = ( self.mix[0] - * plasma_fractal(height=height, width=width, wibbledecay=self.mix[1])[:height, :width] + * plasma_fractal( + height=random_height_map_size, + width=random_height_map_size, + wibbledecay=self.mix[1], + rng=self.rng, + )[:height, :width] ) - final = torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) - return Resize((height, width), InterpolationMode.BICUBIC)(final) + return torch.clamp((img + fog) * max_val / (max_val + self.mix[0]), 0, 1) class Brightness(IBrightness, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: + """Apply a brightness corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. 
+ + Note: + The values have been changed to better reflect the magnitude of the original + transformation replaced with the more principled torchvision adjust_brightness. + """ TUCorruption.__init__(self, severity) - self.level = [1.1, 1.2, 1.3, 1.4, 1.5][severity - 1] + self.level = [1.3, 1.6, 1.9, 2.2, 2.5][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -459,7 +698,14 @@ def forward(self, img: Tensor) -> Tensor: class Contrast(IContrast, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: + """Apply a contrast corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ TUCorruption.__init__(self, severity) self.level = [0.4, 0.3, 0.2, 0.1, 0.05][severity - 1] @@ -471,25 +717,37 @@ def forward(self, img: Tensor) -> Tensor | Image.Image: class Pixelate(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a pixelation corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) - self.quality = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1] + self.quality = [0.6, 0.5, 0.4, 0.3, 0.25][severity - 1] + self.to_pil = ToPILImage() + self.to_tensor = ToTensor() def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img _, height, width = img.shape - img = ToPILImage()(img) + img = self.to_pil(img) img = Resize( (int(height * self.quality), int(width * self.quality)), InterpolationMode.BOX, )(img) - return ToTensor()(Resize((height, width), InterpolationMode.BOX)(img)) + return self.to_tensor(Resize((height, width), InterpolationMode.BOX)(img)) class JPEGCompression(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a JPEG compression corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. 
+ """ super().__init__(severity) - self.quality = [80, 65, 58, 50, 40][severity - 1] + self.quality = [25, 18, 15, 10, 7][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: @@ -500,34 +758,42 @@ def forward(self, img: Tensor) -> Tensor: class Elastic(TUCorruption): - def __init__(self, severity: int) -> None: + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply an elastic corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. + + Note: + mix[0][1] has been changed to 0.5 to avoid errors when dealing with small images. + """ super().__init__(severity) if not cv2_installed or not scipy_installed: raise ImportError( "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" ) - # The following pertubation values are based on the original repo but - # are quite strange, notably for the severities 3 and 4 self.mix = [ - (2, 0.7, 0.1), + (2, 0.5, 0.1), (2, 0.08, 0.2), (0.05, 0.01, 0.02), (0.07, 0.01, 0.02), (0.12, 0.01, 0.02), ][severity - 1] - self.rng = np.random.default_rng() + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - image = np.array(img.permute(1, 2, 0), dtype=np.float32) - shape = image.shape - shape_size = shape[:2] + image = np.array(rearrange(img, "c h w -> h w c"), dtype=np.float32) + height, width, channels = image.shape + shape_size = height, width + min_shape_size = min(shape_size) # random affine center_square = np.float32(shape_size) // 2 - square_size = min(shape_size) // 3 + square_size = min_shape_size // 3 pts1 = np.float32( [ center_square + square_size, @@ -539,8 +805,8 @@ def forward(self, img: Tensor) -> Tensor: ] ) pts2 = pts1 + self.rng.uniform( - -self.mix[2] * shape_size[0], - self.mix[2] * shape_size[0], + -self.mix[2] * min_shape_size, + self.mix[2] * min_shape_size, 
size=pts1.shape, ).astype(np.float32) affine_transform = cv2.getAffineTransform(pts1, pts2) @@ -551,83 +817,131 @@ def forward(self, img: Tensor) -> Tensor: borderMode=cv2.BORDER_REFLECT_101, ) + sigma = self.mix[1] * min_shape_size + ks = min(int((sigma * 3 // 2) * 2 + 1), min_shape_size // 2 * 2 - 1) dx = ( - gaussian( - self.rng.uniform(-1, 1, size=shape[:2]), - self.mix[1] * shape_size[0], - mode="reflect", - truncate=3, + ( + gaussian_blur2d( + torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape_size))), + kernel_size=ks, + sigma=(sigma, sigma), + ).squeeze(0, 1) + * self.mix[0] + * shape_size[1] ) - * self.mix[0] - * shape_size[0] - ).astype(np.float32) + .numpy() + .astype(np.float32)[..., np.newaxis] + ) dy = ( - gaussian( - self.rng.uniform(-1, 1, size=shape[:2]), - self.mix[1] * shape_size[0], - mode="reflect", - truncate=3, + ( + gaussian_blur2d( + torch.as_tensor(self.rng.uniform(-1, 1, size=(1, 1, *shape_size))), + kernel_size=ks, + sigma=(sigma, sigma), + ).squeeze(0, 1) + * self.mix[0] + * shape_size[0] ) - * self.mix[0] - * shape_size[0] - ).astype(np.float32) - dx, dy = dx[..., np.newaxis], dy[..., np.newaxis] + .numpy() + .astype(np.float32)[..., np.newaxis] + ) - x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])) + x, y, z = np.meshgrid(np.arange(width), np.arange(height), np.arange(channels)) indices = ( np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1)), ) img = np.clip( - map_coordinates(image, indices, order=1, mode="reflect").reshape(shape), + map_coordinates(image, indices, order=1, mode="reflect").reshape( + (height, width, channels) + ), 0, 1, ) - return torch.as_tensor(img).permute(2, 0, 1) + return rearrange(torch.as_tensor(img), "h w c -> c h w") + + +# Additional corruption transforms class SpeckleNoise(TUCorruption): - def __init__(self, severity: int) -> None: + batchable = True + + def __init__(self, severity: int, seed: int | None = None) -> None: + """Apply 
speckle noise to tensor images. + + Args: + severity (int): Severity level of the corruption. + seed (int | None): Optional seed for the rng. + """ super().__init__(severity) - self.scale = [0.06, 0.1, 0.12, 0.16, 0.2][severity - 1] - self.rng = np.random.default_rng() + self.scale = [0.15, 0.2, 0.35, 0.45, 0.6][severity - 1] + self.rng = np.random.default_rng(seed) def forward(self, img: Tensor) -> Tensor: + """Apply speckle noise on images. + + Args: + img (Tensor): A potentially batched image of shape (C, H, W) or (B, C, H, W). + """ if self.severity == 0: return img return torch.clamp( - img + img * self.rng.normal(img, self.scale), + img * self.rng.normal(1, self.scale, size=img.shape), 0, 1, ) class GaussianBlur(TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: + """Apply a Gaussian blur corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. + """ super().__init__(severity) - if not skimage_installed: + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" ) - self.sigma = [0.4, 0.6, 0.7, 0.8, 1.0][severity - 1] + sigma = [1, 2, 3, 4, 6][severity - 1] + self.sigma = (sigma, sigma) + self.kernel_size = int(sigma // 2 * 2 * 4 + 1) def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - return torch.clamp( - torch.as_tensor(gaussian(img, sigma=self.sigma)), - min=0, - max=1, + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = torch.clamp( + gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=self.sigma), + 0, + 1, ) + if no_batch: + out = out.squeeze(0) + return out class Saturation(ISaturation, TUCorruption): + batchable = True + def __init__(self, severity: int) -> None: + """Apply a saturation corruption to unbatched tensor images. + + Args: + severity (int): Severity level of the corruption. 
+ """ TUCorruption.__init__(self, severity) self.severity = severity - self.level = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] + self.level = [0.8, 0.6, 0.4, 0.2, 0.1][severity - 1] def forward(self, img: Tensor) -> Tensor: if self.severity == 0: