diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index f9fc48f3..e1b5845e 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -4,6 +4,7 @@ on: branches: - main pull_request: + types: [opened, reopened, ready_for_review, synchronize] branches: - main schedule: @@ -40,7 +41,7 @@ jobs: - name: Install dependencies run: | - python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Sphinx build diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 1c74dd5d..8435ac79 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -7,7 +7,6 @@ on: - dev pull_request: branches: - - main - dev schedule: - cron: "42 7 * * 0" @@ -64,7 +63,7 @@ jobs: - name: Install dependencies if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Check style & format diff --git a/.gitignore b/.gitignore index 9ad40cf4..65bc7436 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Custom .vscode/ +.idea/ data/ logs/ lightning_logs/ @@ -41,6 +42,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.setup_done # PyInstaller # Usually these files are written by a python script from a template diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..1fd3f3d8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,43 @@ +FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + LC_ALL=C.UTF-8 \ + LANG=C.UTF-8 + +# Install Git, OpenSSH Server, and OpenGL (PyTorch's base 
image already includes Conda and Pip) +RUN apt-get update && apt-get install -y \ + git \ + openssh-server \ + libgl1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +# Create an empty README.md file and an empty torch_uncertainty module to satisfy flit +RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__init__.py + +# Copy dependency file +COPY pyproject.toml /workspace/ + +# Install all dependencies +RUN pip install --no-cache-dir -e ".[all]" + +# Always activate Conda when opening a new terminal +RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc + +# Configure SSH server +RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ + echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \ + echo "AuthorizedKeysFile .ssh/authorized_keys" >> /etc/ssh/sshd_config + +# Expose port 8888 for TensorBoard and Jupyter Notebook and port 22 for SSH +EXPOSE 8888 22 + +# Entrypoint script (runs every time the container starts) +COPY docker/entrypoint.sh /usr/local/bin/ +RUN chmod +x /usr/local/bin/entrypoint.sh +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] + +# Note that if /workspace/ is a mounted volume, any files copied to /workspace/ during the build will be overwritten by the mounted volume +# This is why we copy the entrypoint script to /usr/local/bin/ instead of /workspace/ diff --git a/README.md b/README.md index 52b0ed8e..e0f735a4 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ pip install torch-uncertainty The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). +### :whale: Docker image for contributors + +For contributors who want to run experiments on cloud GPU instances, we provide a pre-built Docker image that includes all necessary dependencies and configurations and the Dockerfile for building your custom Docker images.
+This allows you to quickly launch an experiment-ready container with minimal setup. Please refer to [DOCKER.md](docker/DOCKER.md) for further details. + ## :racehorse: Quickstart We make a quickstart available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). @@ -94,4 +99,4 @@ The following projects use TorchUncertainty: - *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) -**If you are using TorchUncertainty in your project, please let us know, we will add your project to this list!** +**If you are using TorchUncertainty in your project, please let us know, and we will add your project to this list!** diff --git a/docker/DOCKER.md b/docker/DOCKER.md new file mode 100644 index 00000000..35adbb81 --- /dev/null +++ b/docker/DOCKER.md @@ -0,0 +1,71 @@ +# :whale: Docker image for contributors + +### Pre-built Docker image +1. To pull the pre-built image from Docker Hub, simply run: + ```bash + docker pull docker.io/tonyzamyatin/torch-uncertainty:latest + ``` + + This image includes: + - PyTorch with CUDA support + - OpenGL (for visualization tasks) + - Git, OpenSSH, and all Python dependencies + + Check out the [registry on Docker Hub](https://hub.docker.com/repository/docker/tonyzamyatin/torch-uncertainty/general) for all available images. + +2. To start a container using this image, set up the necessary environment variables and run: + ```bash + docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \ + -e VM_SSH_PUBLIC_KEY="your-public-key" \ + -e GITHUB_SSH_PRIVATE_KEY="your-github-key" \ + -e GITHUB_USER="your-github-username" \ + -e GIT_USER_EMAIL="your-git-email" \ + -e GIT_USER_NAME="your-git-name" \ + docker.io/tonyzamyatin/torch-uncertainty + ``` + + Optionally, you can also set `-e USE_COMPACT_SHELL_PROMPT="true"` + to make the VM's shell prompts compact and colorized.
+ + **Note:** Some cloud providers offer templates, in which you can preconfigure + in advance which Docker image to pull and which environment variables to set. + In this case, the provider will pull the image, set all environment variables, + and start the container for you. + +3. Once your cloud provider has deployed the VM, it will display the host address and SSH port. + You can connect to the container via SSH using: + ```bash + ssh -i /path/to/private_key root@<VM_HOST> -p <VM_PORT> + ``` + + Replace `<VM_HOST>` and `<VM_PORT>` with the values provided by your cloud provider, + and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`. + +4. The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard. + + **Note:** The `/workspace` directory is mounted from your local machine or cloud storage, + so changes persist across container restarts. + If using a cloud provider, ensure your network volume is correctly attached to avoid losing data. + +### Modifying and publishing custom Docker image + +If you want to make changes to the Dockerfile, follow these steps: +1. Edit the Dockerfile to fit your needs. + +2. Build the modified image: + ``` + docker build -t my-custom-image . + ``` + +3. Push to a Docker registry (if you want to use it on another VM): + ``` + docker tag my-custom-image mydockerhubuser/my-custom-image:tag + docker push mydockerhubuser/my-custom-image:tag + ``` + +4. Pull the custom image onto your VM: + ``` + docker pull mydockerhubuser/my-custom-image + ``` + +5. Run the container using the same docker run command with the new image name. diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 00000000..1c1af5e5 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,131 @@ +#!/bin/bash +set -e # Exit immediately if a command fails + +echo "🚀 Starting container..." + +# Ensure SSH directory exists and has correct permissions +if [ ! -d /root/.ssh ]; then + echo "📂 Creating SSH directory..."
+ mkdir -p /root/.ssh && chmod 700 /root/.ssh +fi + +# Ensure the VM's public SSH key is added for authentication +if [ -z "$VM_SSH_PUBLIC_KEY" ]; then + echo "❌ Error: Please set the VM_SSH_PUBLIC_KEY environment variable." + exit 1 +fi +if [ ! -f /root/.ssh/authorized_keys ] || ! grep -q "$VM_SSH_PUBLIC_KEY" /root/.ssh/authorized_keys; then + echo "🔑 Adding VM SSH public key..." + echo "$VM_SSH_PUBLIC_KEY" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys +fi + +# Ensure GitHub SSH private key is set up for authentication +if [ -z "$GITHUB_SSH_PRIVATE_KEY" ]; then + echo "❌ Error: Please set the GITHUB_SSH_PRIVATE_KEY environment variable." + exit 1 +fi +if [ ! -f /root/.ssh/github_rsa ]; then + echo "🔐 Adding GitHub SSH private key..." + echo "$GITHUB_SSH_PRIVATE_KEY" > /root/.ssh/github_rsa && chmod 600 /root/.ssh/github_rsa +fi + +# Configure SSH client for GitHub authentication +if [ ! -f /root/.ssh/config ] || ! grep -q 'Host github.com' /root/.ssh/config; then + echo "⚙️ Configuring SSH client for GitHub authentication..." + cat <<EOF > /root/.ssh/config +Host github.com + User git + IdentityFile /root/.ssh/github_rsa +EOF + chmod 600 /root/.ssh/config +fi + +# Add GitHub to known hosts (to avoid SSH verification prompts) +echo "📌 Ensuring GitHub is a known host..." +ssh-keygen -F github.com > /dev/null 2>&1 || ssh-keyscan github.com >> /root/.ssh/known_hosts + +# Start SSH agent and add GitHub private key (if not already added) +if ! pgrep -x "ssh-agent" > /dev/null; then + echo "🕵️ Starting SSH agent..." + eval "$(ssh-agent -s)" +fi +if ssh-add -l | grep -q github_rsa; then + echo "✅ GitHub SSH key already added." +else + echo "🔑 Adding GitHub SSH key to agent..."
+ ssh-add /root/.ssh/github_rsa +fi + +# Set Git user name and email (if provided) +if [ -n "$GIT_USER_NAME" ]; then + echo "👤 Setting Git username: $GIT_USER_NAME" + git config --global user.name "$GIT_USER_NAME" +fi +if [ -n "$GIT_USER_EMAIL" ]; then + echo "📧 Setting Git email: $GIT_USER_EMAIL" + git config --global user.email "$GIT_USER_EMAIL" +fi + +# Ensure first-time setup runs only once +if [ ! -f /workspace/.setup_done ]; then + echo "🛠️ Running first-time setup..." + + # Ensure GitHub username is set + if [ -z "$GITHUB_USER" ]; then + echo "❌ Error: Please set the GITHUB_USER environment variable." + exit 1 + fi + + # Clone GitHub repo if not already cloned + if [ ! -d "/workspace/.git" ]; then + echo "📦 Cloning repository: $GITHUB_USER/torch-uncertainty..." + git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace + fi + + # Mark setup as completed + touch /workspace/.setup_done + echo "✅ First-time setup complete!" +else + echo "⏩ Skipping first-time setup (already done)." +fi + +# Apply compact shell prompt customization (if enabled) +if [ -n "$USE_COMPACT_SHELL_PROMPT" ]; then + echo "🎨 Applying compact shell prompt customization..." + echo 'force_color_prompt=yes' >> /root/.bashrc + echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc + echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc + echo ' test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"' >> /root/.bashrc + echo ' alias ls="ls --color=auto"' >> /root/.bashrc + echo ' alias grep="grep --color=auto"' >> /root/.bashrc + echo ' alias fgrep="fgrep --color=auto"' >> /root/.bashrc + echo ' alias egrep="egrep --color=auto"' >> /root/.bashrc + echo 'fi' >> /root/.bashrc +fi + +# Ensure /workspace is in PYTHONPATH +if ! echo "$PYTHONPATH" | grep -q "/workspace"; then + echo "📌 Adding /workspace to PYTHONPATH" + export PYTHONPATH="/workspace:$PYTHONPATH" +else + echo "✅ PYTHONPATH is already correctly set." 
+fi + +# Check if torch_uncertainty is installed in editable mode +if pip show torch_uncertainty | grep -q "Editable project location: /workspace"; then + echo "✅ torch_uncertainty is already installed in editable mode. 🎉" +else + echo "🔄 Reinstalling torch_uncertainty in editable mode..." + pip uninstall -y torch-uncertainty + pip install -e /workspace + echo "✅ torch_uncertainty is now installed in editable mode! 🚀" +fi + +# Activate pre-commit hooks (if enabled) +echo "🔗 Activating pre-commit hooks..." +pre-commit install + +# Ensure SSH server is started +echo "🔑 Starting SSH server..." +mkdir -p /run/sshd && chmod 755 /run/sshd +/usr/sbin/sshd -D diff --git a/docs/source/api.rst b/docs/source/api.rst index 16f1f1c2..0165ad49 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -201,6 +201,7 @@ Functions :toctree: generated/ :nosignatures: + batch_ensemble deep_ensembles mc_dropout @@ -212,6 +213,7 @@ Classes :nosignatures: :template: class.rst + BatchEnsemble CheckpointEnsemble EMA MCDropout diff --git a/docs/source/conf.py b/docs/source/conf.py index 2fe7b638..2c9d46a0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.4.1" +release = "0.4.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml index 55f6b3c6..70f5cf8e 100644 --- a/experiments/classification/mnist/configs/bayesian_lenet.yaml +++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_samples: 16 num_classes: 10 diff --git 
a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml index 0c7989ab..3f8b63c2 100644 --- a/experiments/classification/mnist/configs/lenet.yaml +++ b/experiments/classification/mnist/configs/lenet.yaml @@ -38,7 +38,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} num_classes: 10 loss: CrossEntropyLoss diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml new file mode 100644 index 00000000..d385b100 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -0,0 +1,68 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: batch_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # BatchEnsemble + class_path: torch_uncertainty.models.lenet.batchensemble_lenet + init_args: + in_channels: 1 + num_classes: 10 + num_estimators: 5 + activation: torch.nn.ReLU + norm: torch.nn.BatchNorm2d + groups: 1 + dropout_rate: 0 + repeat_training_inputs: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + 
eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml index c5398a87..354b9bf7 100644 --- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} save_schedule: - 20 @@ -67,7 +66,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml new file mode 100644 index 00000000..1d47b782 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -0,0 +1,78 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: deep_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # DeepEnsemble + 
class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles + init_args: + models: + # LeNet + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + # last_layer_dropout: false + layer_args: {} + num_estimators: 5 + task: classification + probabilistic: false + reset_model_parameters: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml index 363461c6..d453df0b 100644 --- a/experiments/classification/mnist/configs/lenet_ema.yaml +++ b/experiments/classification/mnist/configs/lenet_ema.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} momentum: 0.99 num_classes: 10 @@ -55,7 +54,9 @@ optimizer: weight_decay: 5e-4 nesterov: true lr_scheduler: - milestones: - - 25 - - 50 - gamma: 0.1 + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml b/experiments/classification/mnist/configs/lenet_swa.yaml index 2274bdb5..09d7d506 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - 
last_layer_dropout: false layer_args: {} cycle_start: 19 cycle_length: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index ddff0067..e33d954f 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 10 cycle_length: 5 diff --git a/pyproject.toml b/pyproject.toml index f3edcd96..e404e640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.4.1" +version = "0.4.2" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, @@ -29,6 +29,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ "timm", @@ -39,24 +40,27 @@ dependencies = [ ] [project.optional-dependencies] +experiments = [ + "tensorboard", + "huggingface-hub", + "safetensors", +] image = [ "scikit-image", + "kornia", "h5py", "opencv-python", - "Wand", ] tabular = ["pandas"] dev = [ - "scikit-learn", - "huggingface-hub", - "torch_uncertainty[image]", - "ruff==0.7.4", + "torch_uncertainty[experiments,image]", + "ruff==0.9.9", "pytest-cov", "pre-commit", "pre-commit-hooks", ] docs = [ - "sphinx<6", + "sphinx<7", "tu_sphinx_theme", "sphinx-copybutton", "sphinx-gallery", @@ -64,11 +68,11 @@ docs = [ "sphinx-codeautolink", ] all = [ - "torch_uncertainty[dev,docs,image,tabular]", + "torch_uncertainty[dev,docs,tabular]", + "scikit-learn", "laplace-torch", - "glest==0.0.1a1", "scipy", - "tensorboard", + "glest==0.0.1a1", ] [project.urls] diff --git a/tests/datamodules/classification/test_mnist.py 
b/tests/datamodules/classification/test_mnist.py index 7a967edc..f6ab4f8e 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -37,6 +37,7 @@ def test_mnist_cutout(self): dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.setup("fit") dm.setup("test") dm.train_dataloader() @@ -47,6 +48,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.eval_shift = True dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() diff --git a/tests/layers/test_batch.py b/tests/layers/test_batch.py index bb7ca6e9..75e54485 100644 --- a/tests/layers/test_batch.py +++ b/tests/layers/test_batch.py @@ -33,6 +33,17 @@ def test_linear_one_estimator_no_bias(self, feat_input: torch.Tensor): out = layer(feat_input) assert out.shape == torch.Size([4, 2]) + def test_convert_from_linear(self, feat_input: torch.Tensor): + linear = torch.nn.Linear(6, 3) + layer = BatchLinear.from_linear(linear, num_estimators=2) + assert layer.linear.weight.shape == torch.Size([3, 6]) + assert layer.linear.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(feat_input) + assert out.shape == torch.Size([4, 3]) + class TestBatchConv2d: """Testing the BatchConv2d layer class.""" @@ -47,3 +58,14 @@ def test_conv_two_estimators(self, img_input: torch.Tensor): layer = BatchConv2d(6, 2, num_estimators=2, kernel_size=1) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) + + def test_convert_from_conv2d(self, img_input: torch.Tensor): + conv = torch.nn.Conv2d(6, 3, 1) + layer = BatchConv2d.from_conv2d(conv, num_estimators=2) + assert layer.conv.weight.shape == torch.Size([3, 6, 1, 1]) + assert layer.conv.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == 
torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(img_input) + assert out.shape == torch.Size([5, 3, 3, 3]) diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py index cfcee746..e30e7cd3 100644 --- a/tests/layers/test_packed.py +++ b/tests/layers/test_packed.py @@ -196,7 +196,23 @@ def test_linear_einsum_implementation( assert out.shape == torch.Size([1, 2, 3, 4, 2]) def test_linear_extend(self): - _ = PackedConv2d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1) + layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy") + assert layer.weight.shape == torch.Size([4, 3, 1]) + assert layer.bias.shape == torch.Size([4]) + layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full") + assert layer.weight.shape == torch.Size([2, 2, 3]) + assert layer.bias.shape == torch.Size([4]) + # with first=True + layer = PackedLinear( + 5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy", first=True + ) + assert layer.weight.shape == torch.Size([4, 5, 1]) + assert layer.bias.shape == torch.Size([4]) + layer = PackedLinear( + 5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full", first=True + ) + assert layer.weight.shape == torch.Size([1, 4, 5]) + assert layer.bias.shape == torch.Size([4]) def test_linear_failures(self): with pytest.raises(ValueError): @@ -220,7 +236,7 @@ def test_linear_failures(self): with pytest.raises(ValueError): _ = PackedLinear(5, 2, alpha=1, num_estimators=1, gamma=-1, rearrange=True) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): _ = PackedLinear( 5, 2, diff --git a/tests/metrics/test_sparsification.py b/tests/metrics/test_sparsification.py index e89df5fe..114183b0 100644 --- a/tests/metrics/test_sparsification.py +++ b/tests/metrics/test_sparsification.py @@ -33,3 +33,9 @@ def test_plot(self) -> None: assert ax.get_xlabel() == "Rejection Rate (%)" assert ax.get_ylabel() == "Error Rate (%)" 
plt.close(fig) + + def test_compute_nan(self) -> None: + probs = torch.Tensor([[0.1, 0.9]]) + targets = torch.Tensor([1]).long() + metric = AUSE() + assert torch.isnan(metric(probs, targets)).all() diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index 8519ffdf..c6a08180 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet +from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet class TestLeNet: @@ -18,6 +18,7 @@ def test_main(self): model.eval() model(torch.randn(1, 1, 20, 20)) + batchensemble_lenet(1, 1) packed_lenet(1, 1) bayesian_lenet(1, 1) bayesian_lenet( diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py new file mode 100644 index 00000000..3ec42082 --- /dev/null +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -0,0 +1,82 @@ +import pytest +import torch +from torch import nn + +from torch_uncertainty.layers import BatchConv2d, BatchLinear +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble, batch_ensemble + + +@pytest.fixture() +def img_input() -> torch.Tensor: + return torch.rand((5, 6, 3, 3)) + + +# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) +class _DummyModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.conv = nn.Conv2d(in_features, out_features, 3) + self.fc = nn.Linear(out_features, out_features) + + def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class _DummyBEModel(nn.Module): + def __init__(self, in_features, out_features, num_estimators): + super().__init__() + self.conv = BatchConv2d(in_features, out_features, 3, num_estimators) + self.fc = BatchLinear(out_features, out_features, num_estimators=num_estimators) + 
+ def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class TestBatchEnsembleModel: + def test_convert_layers(self): + in_features = 6 + out_features = 4 + num_estimators = 3 + + model = _DummyModel(in_features, out_features) + wrapped_model = batch_ensemble(model, num_estimators, convert_layers=True) + assert wrapped_model.num_estimators == num_estimators + assert isinstance(wrapped_model.model.conv, BatchConv2d) + assert isinstance(wrapped_model.model.fc, BatchLinear) + + def test_forward_pass(self, img_input): + batch_size = img_input.size(0) + in_features = img_input.size(1) + out_features = 4 + num_estimators = 3 + model = _DummyBEModel(in_features, out_features, num_estimators) + # with repeat_training_inputs=False + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=False) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (img_input.size(0), out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # with repeat_training_inputs=True + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=True) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + + def test_errors(self): + with pytest.raises(ValueError): + BatchEnsemble(_DummyBEModel(10, 5, 1), 0) + with pytest.raises(ValueError): + BatchEnsemble(_DummyModel(10, 5), 1) + with pytest.raises(ValueError): + BatchEnsemble(nn.Identity(), 2, convert_layers=True) diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py index eda1356e..fabe77f3 100644 --- 
a/tests/post_processing/test_scalers.py +++ b/tests/post_processing/test_scalers.py @@ -57,6 +57,12 @@ def test_errors(self): with pytest.raises(ValueError): scaler.set_temperature(val=-1) + scaler = TemperatureScaler( + model=None, + ) + with pytest.raises(ValueError, match="Cannot fit a Scaler method without model."): + scaler.fit(None) + class TestVectorScaler: """Testing the VectorScaler class.""" diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 1ee28555..ce6a76cc 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -88,6 +88,10 @@ def test_motion_blur(self): transform = MotionBlur(0) transform(inputs) + inputs = torch.rand(1, 3, 32, 32) + transform = MotionBlur(1) + transform(inputs) + def test_zoom_blur(self): inputs = torch.rand(3, 32, 32) transform = ZoomBlur(1) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 0dc88033..fdaddd14 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -1,10 +1,18 @@ from abc import ABC, abstractmethod +from importlib import util from pathlib import Path from typing import Literal from lightning.pytorch.core import LightningDataModule from numpy.typing import ArrayLike -from sklearn.model_selection import StratifiedKFold + +if util.find_spec("sklearn"): + from sklearn.model_selection import StratifiedKFold + + sklearn_installed = True +else: # coverage: ignore + sklearn_installed = False + from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler @@ -15,6 +23,8 @@ class TUDataModule(ABC, LightningDataModule): val: Dataset test: Dataset + shift_severity = 1 + def __init__( self, root: str | Path, @@ -120,7 +130,15 @@ def _get_train_targets(self) -> ArrayLike: raise NotImplementedError def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list: + if not sklearn_installed: 
+ raise ImportError( + "Please install torch_uncertainty with the image option" + "to use crossval:" + """pip install -U "torch_uncertainty[image]".""" + ) + self.setup("fit") + skf = StratifiedKFold(n_splits) cv_dm = [] diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 3ee9c8a3..1e1441c2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -228,8 +228,7 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 2717cc12..11d0a7fa 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -210,8 +210,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or + CIFAR-100C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d436fd16..24fd5c6d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -289,8 +289,8 @@ def test_dataloader(self) -> list[DataLoader]: """Get the test dataloaders for ImageNet. Return: - list[DataLoader]: ImageNet test set (in distribution data) and - Textures test split (out-of-distribution data). 
+ list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split + (out-of-distribution data), and/or ImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a49fc168..f6879c1a 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,10 +171,12 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data) and FashionMNIST test split - (out-of-distribution data). + distribution data), FashionMNIST or NotMNIST test split + (out-of-distribution data), and/or MNISTC (shifted data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 3c3c7ec4..bf95159e 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -237,8 +237,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, OOD data, and/or + TinyImageNetC data. 
""" dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index ebafe7f8..b9020b8c 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -111,7 +111,7 @@ def __init__( if shift_severity not in list(range(1, 6)): raise ValueError( - "Corruptions shift_severity should be chosen between 1 and 5 " "included." + "Corruptions shift_severity should be chosen between 1 and 5 included." ) samples, labels = self.make_dataset(self.root, self.subset, self.shift_severity) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index f8354a0b..4cbbecb1 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -27,7 +27,7 @@ class CIFAR10H(CIFAR10): """ h_test_list = ["cifar-10h-probs.npy", "7b41f73eee90fdefc73bfc820ab29ba8"] - h_url = "https://github.com/jcpeterson/cifar-10h/raw/master/data/" "cifar10h-probs.npy" + h_url = "https://github.com/jcpeterson/cifar-10h/raw/master/data/cifar10h-probs.npy" def __init__( self, @@ -39,7 +39,7 @@ def __init__( ) -> None: if train: raise ValueError("CIFAR10H does not support training data.") - print("WARNING: CIFAR10H cannot be used with Classification routines " "for now.") + print("WARNING: CIFAR10H cannot be used within Classification routines for now.") super().__init__( Path(root), train=False, @@ -53,7 +53,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." 
) self.targets = list(torch.as_tensor(np.load(self.root / self.h_test_list[0]))) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_n.py b/torch_uncertainty/datasets/classification/cifar/cifar_n.py index 6f6f8c95..e8193d68 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_n.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_n.py @@ -61,7 +61,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) @@ -112,7 +112,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 079d20ec..38bacf70 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -51,7 +51,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." 
) super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 722955cf..4e5ed41a 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -52,7 +52,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. " "You can use download=True to download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index e0b28ae0..ff3e364b 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -60,7 +60,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/fractals.py b/torch_uncertainty/datasets/fractals.py index 329a2086..8153f2a0 100644 --- a/torch_uncertainty/datasets/fractals.py +++ b/torch_uncertainty/datasets/fractals.py @@ -40,7 +40,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." 
) super().__init__(self.root, transform=transform, target_transform=target_transform) diff --git a/torch_uncertainty/datasets/frost.py b/torch_uncertainty/datasets/frost.py index dbde6d89..6aa069bb 100644 --- a/torch_uncertainty/datasets/frost.py +++ b/torch_uncertainty/datasets/frost.py @@ -44,7 +44,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 73ea2ba7..cc93c3d9 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -232,7 +232,7 @@ def _make_dataset(self, path: Path) -> None: """ if "depth" in path.name: raise NotImplementedError( - "Depth mode is not implemented yet. Raise an issue " "if you need it." + "Depth mode is not implemented yet. Raise an issue if you need it." 
) self.samples = sorted((path / "leftImg8bit/").glob("**/*")) if self.target_type == "semantic": diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index a2c50bc3..502b9a1e 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -191,7 +191,7 @@ def download(self) -> None: logging.info("Files already downloaded and verified") return if self.url is None: - raise ValueError(f"The dataset {self.dataset_name} is not available for " "download.") + raise ValueError(f"The dataset {self.dataset_name} is not available for download.") download_root = self.root / self.root_appendix / self.dataset_name if self.dataset_name == "boston": download_url( diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 937b83f7..ec9b8e4c 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -133,7 +133,7 @@ def __init__( """ if split not in ["train", "val", "test", None]: raise ValueError( - f"Unknown split '{split}'. " "Supported splits are ['train', 'val', 'test', None]" + f"Unknown split '{split}'. Supported splits are ['train', 'val', 'test', None]" ) super().__init__(root, transforms, None, None) @@ -153,7 +153,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. " "You can use download=True to download it" + "Dataset not found or corrupted. 
You can use download=True to download it" ) # get filenames for split diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index dde72139..af4f31ef 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -1,6 +1,7 @@ import math import torch +from einops import repeat from torch import Tensor, nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair @@ -79,11 +80,18 @@ def __init__( :math:`H_{out} = \text{out_features}`. Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`. + It is advised to ensure that `batch_size` is divisible by :attr:`num_estimators` when + calling :func:`forward()`, so each estimator receives the same number of examples. + In a BatchEnsemble architecture, the input is typically **repeated** `num_estimators` + times along the batch dimension. Incorrect batch size may lead to unexpected results. + + To simplify batch handling, wrap your model with `torch_uncertainty.wrappers.BatchEnsemble`, + which automatically repeats the batch before passing it through the network. + Examples: >>> # With three estimators - >>> m = LinearBE(20, 30, 3) + >>> m = BatchLinear(20, 30, 3) >>> input = torch.randn(8, 20) >>> output = m(input) >>> print(output.size()) @@ -110,6 +118,30 @@ def __init__( self.register_parameter("bias", None) self.reset_parameters() + @classmethod + def from_linear(cls, linear: nn.Linear, num_estimators: int) -> "BatchLinear": + r"""Create a BatchEnsemble-style Linear layer from an existing Linear layer. + + Args: + linear (nn.Linear): The Linear layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchLinear: The converted BatchEnsemble-style Linear layer. 
+ + Example: + >>> linear = nn.Linear(20, 30) + >>> be_linear = BatchLinear.from_linear(linear, num_estimators=3) + """ + return cls( + in_features=linear.in_features, + out_features=linear.out_features, + num_estimators=num_estimators, + bias=linear.bias is not None, + device=linear.weight.device, + dtype=linear.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -125,16 +157,12 @@ def forward(self, inputs: Tensor) -> Tensor: ) extra = batch_size % self.num_estimators - r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) + r_group = repeat(self.r_group, "m h -> (m b) h", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = torch.repeat_interleave( - self.bias, - examples_per_estimator, - dim=0, - ) + bias = repeat(self.bias, "m h -> (m b) h", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None @@ -143,7 +171,7 @@ def forward(self, inputs: Tensor) -> Tensor: def extra_repr(self) -> str: return ( - f"in_features={ self.in_features}," + f"in_features={self.in_features}," f" out_features={self.out_features}," f" num_estimators={self.num_estimators}," f" bias={self.bias is not None}" @@ -273,12 +301,16 @@ def __init__( {\text{stride}[1]} + 1\right\rfloor Warning: - Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`. 
+ Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`. + In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators` + times along the first axis. Incorrect batch size may lead to unexpected results. + To simplify batch handling, wrap your model with `BatchEnsembleWrapper`, which automatically + repeats the batch before passing it through the network. See `BatchEnsembleWrapper` for details. Examples: >>> # With square kernels, four estimators and equal stride - >>> m = Conv2dBE(3, 32, 3, 4, stride=1) + >>> m = BatchConv2d(3, 32, 3, 4, stride=1) >>> input = torch.randn(8, 3, 16, 16) >>> output = m(input) >>> print(output.size()) @@ -315,6 +347,38 @@ def __init__( self.reset_parameters() + @classmethod + def from_conv2d(cls, conv2d: nn.Conv2d, num_estimators: int) -> "BatchConv2d": + r"""Create a BatchEnsemble-style Conv2d layer from an existing Conv2d layer. + + Args: + conv2d (nn.Conv2d): The Conv2d layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchConv2d: The converted BatchEnsemble-style Conv2d layer. + + Warning: + All parameters of the original Conv2d layer will be discarded. 
+ + Example: + >>> conv2d = nn.Conv2d(3, 32, kernel_size=3) + >>> be_conv2d = BatchConv2d.from_conv2d(conv2d, num_estimators=3) + """ + return cls( + in_channels=conv2d.in_channels, + out_channels=conv2d.out_channels, + kernel_size=conv2d.kernel_size, + stride=conv2d.stride, + padding=conv2d.padding, + dilation=conv2d.dilation, + groups=conv2d.groups, + bias=conv2d.bias is not None, + num_estimators=num_estimators, + device=conv2d.weight.device, + dtype=conv2d.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -328,50 +392,13 @@ def forward(self, inputs: Tensor) -> Tensor: examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators - r_group = ( - torch.repeat_interleave( - self.r_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.r_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = ( - torch.repeat_interleave( - self.s_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.s_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # + r_group = repeat(self.r_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = ( - torch.repeat_interleave( - self.bias, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.bias.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - + bias = repeat(self.bias, "m h -> (m b) h 1 1", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None diff 
--git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index e328aa39..f1277ed4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -89,7 +89,7 @@ def __init__( if transposed: raise NotImplementedError( - "Bayesian transposed convolution not implemented yet. Raise an" " issue if needed." + "Bayesian transposed convolution not implemented yet. Raise an issue if needed." ) self.in_channels = in_channels @@ -164,7 +164,7 @@ def sample(self) -> tuple[Tensor, Tensor | None]: return weight, bias def extra_repr(self) -> str: # coverage: ignore - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}" + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): diff --git a/torch_uncertainty/layers/filter_response_norm.py b/torch_uncertainty/layers/filter_response_norm.py index ef306a8f..0f22717d 100644 --- a/torch_uncertainty/layers/filter_response_norm.py +++ b/torch_uncertainty/layers/filter_response_norm.py @@ -23,13 +23,13 @@ def __init__( super().__init__() if dimension < 1 or not isinstance(dimension, int): raise ValueError( - "dimension should be an integer greater or equal than 1. " f"got {dimension}." + f"dimension should be an integer greater or equal than 1. Got {dimension}." ) self.dimension = dimension if num_channels < 1 or not isinstance(num_channels, int): raise ValueError( - "num_channels should be an integer greater or equal than 1. " f"got {num_channels}." + f"num_channels should be an integer greater or equal than 1. Got {num_channels}." 
) shape = (1, num_channels) + (1,) * dimension self.eps = eps diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index c962531e..a5ab40f9 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -79,15 +79,15 @@ def packed_in_projection( emb_q // num_groups, emb_v // num_groups, ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == ( - emb_q, - ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" - assert b_k is None or b_k.shape == ( - emb_q, - ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" - assert b_v is None or b_v.shape == ( - emb_q, - ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + assert b_q is None or b_q.shape == (emb_q,), ( + f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" + ) + assert b_k is None or b_k.shape == (emb_q,), ( + f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + ) + assert b_v is None or b_v.shape == (emb_q,), ( + f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + ) return ( packed_linear(q, w_q, num_groups, implementation, b_q), @@ -324,47 +324,47 @@ def packed_multi_head_attention_forward( # noqa: D417 # longer causal. 
is_causal = False - assert ( - embed_dim == embed_dim_to_check - ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) if isinstance(embed_dim, Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert head_dim * num_heads == embed_dim, ( + f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + ) if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert ( - key.shape[:2] == value.shape[:2] - ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) else: - assert ( - key.shape == value.shape - ), f"key shape {key.shape} does not match value shape {value.shape}" + assert key.shape == value.shape, ( + f"key shape {key.shape} does not match value shape {value.shape}" + ) # # compute in-projection # if not use_separate_proj_weight: - assert ( - in_proj_weight is not None - ), "use_separate_proj_weight is False but in_proj_weight is None" + assert in_proj_weight is not None, ( + "use_separate_proj_weight is False but in_proj_weight is None" + ) q, k, v = packed_in_projection_packed( q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias ) else: - assert ( - q_proj_weight is not None - ), "use_separate_proj_weight is True but q_proj_weight is None" - assert ( - k_proj_weight is not None - ), "use_separate_proj_weight is True but k_proj_weight is None" - assert ( - v_proj_weight is not None - ), 
"use_separate_proj_weight is True but v_proj_weight is None" + assert q_proj_weight is not None, ( + "use_separate_proj_weight is True but q_proj_weight is None" + ) + assert k_proj_weight is not None, ( + "use_separate_proj_weight is True but k_proj_weight is None" + ) + assert v_proj_weight is not None, ( + "use_separate_proj_weight is True but v_proj_weight is None" + ) if in_proj_bias is None: b_q = b_k = b_v = None else: diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 71537b82..66150367 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -33,11 +33,9 @@ def check_packed_parameters_consistency(alpha: float, gamma: int, num_estimators if num_estimators is None: raise ValueError("You must specify the value of the arg. `num_estimators`") if not isinstance(num_estimators, int): - raise TypeError( - "Attribute `num_estimators` should be an int, not " f"{type(num_estimators)}" - ) + raise TypeError(f"Attribute `num_estimators` should be an int, not {type(num_estimators)}") if num_estimators <= 0: - raise ValueError("Attribute `num_estimators` should be >= 1, not " f"{num_estimators}") + raise ValueError(f"Attribute `num_estimators` should be >= 1, not {num_estimators}") class PackedLinear(nn.Module): @@ -73,8 +71,12 @@ def __init__( network. Defaults to ``False``. last (bool, optional): Whether this is the last layer of the network. Defaults to ``False``. - implementation (str, optional): The implementation to use. Defaults - to ``"legacy"``. + implementation (str, optional): The implementation to use. Available implementations: + + - ``"legacy"`` (default): The legacy implementation of the linear layer. + - ``"sparse"``: The sparse implementation of the linear layer. + - ``"full"``: The full implementation of the linear layer. + - ``"einsum"``: The einsum implementation of the linear layer. 
rearrange (bool, optional): Rearrange the input and outputs for compatibility with previous and later layers. Defaults to ``True``. device (torch.device, optional): The device to use for the layer's @@ -103,6 +105,13 @@ def __init__( default. """ check_packed_parameters_consistency(alpha, gamma, num_estimators) + + if implementation not in ["legacy", "sparse", "full", "einsum"]: + raise ValueError( + f"Unknown implementation: {implementation} for PackedLinear. " + "Available implementations are: 'legacy', 'sparse', 'full', 'einsum'" + ) + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -121,16 +130,10 @@ def __init__( # fix if not divisible by groups if extended_in_features % actual_groups: extended_in_features += num_estimators - extended_in_features % (actual_groups) - if extended_out_features % actual_groups: - extended_out_features += num_estimators - extended_out_features % (actual_groups) - - # FIXME: This is a temporary check - assert implementation in [ - "legacy", - "sparse", - "full", - "einsum", - ], f"Unknown implementation: {implementation} for PackedLinear" + if extended_out_features % num_estimators * gamma: + extended_out_features += num_estimators - extended_out_features % ( + num_estimators * gamma + ) if self.implementation == "legacy": self.weight = nn.Parameter( @@ -697,9 +700,9 @@ def __init__( self.dropout = dropout self.batch_first = batch_first self.head_dim = self.embed_dim // self.num_heads - assert ( - self.head_dim * self.num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert self.head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) self.num_estimators = num_estimators self.alpha = alpha diff --git a/torch_uncertainty/losses/bayesian.py b/torch_uncertainty/losses/bayesian.py index 3e0e632e..4af6d77b 100644 --- a/torch_uncertainty/losses/bayesian.py +++ b/torch_uncertainty/losses/bayesian.py @@ -98,12 +98,12 @@ def set_model(self, model:
nn.Module | None) -> None: def _elbo_loss_checks(inner_loss: nn.Module, kl_weight: float, num_samples: int) -> None: if isinstance(inner_loss, type): - raise TypeError("The inner_loss should be an instance of a class." f"Got {inner_loss}.") + raise TypeError(f"The inner_loss should be an instance of a class. Got {inner_loss}.") if kl_weight < 0: raise ValueError(f"The KL weight should be non-negative. Got {kl_weight}.") if num_samples < 1: - raise ValueError("The number of samples should not be lower than 1." f"Got {num_samples}.") + raise ValueError(f"The number of samples should not be lower than 1. Got {num_samples}.") if not isinstance(num_samples, int): - raise TypeError("The number of samples should be an integer. " f"Got {type(num_samples)}.") + raise TypeError(f"The number of samples should be an integer. Got {type(num_samples)}.") diff --git a/torch_uncertainty/losses/classification.py b/torch_uncertainty/losses/classification.py index 82c48280..96769338 100644 --- a/torch_uncertainty/losses/classification.py +++ b/torch_uncertainty/losses/classification.py @@ -31,12 +31,12 @@ def __init__( if reg_weight is not None and (reg_weight < 0): raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}."
) self.reg_weight = reg_weight if annealing_step is not None and (annealing_step <= 0): - raise ValueError("The annealing step should be positive, but got " f"{annealing_step}.") + raise ValueError(f"The annealing step should be positive, but got {annealing_step}.") self.annealing_step = annealing_step if reduction not in ("none", "mean", "sum") and reduction is not None: @@ -178,11 +178,11 @@ def __init__( self.reduction = reduction if eps < 0: - raise ValueError("The epsilon value should be non-negative, but got " f"{eps}.") + raise ValueError(f"The epsilon value should be non-negative, but got {eps}.") self.eps = eps if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." ) self.reg_weight = reg_weight @@ -233,7 +233,7 @@ def __init__( self.reduction = reduction if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." ) self.reg_weight = reg_weight @@ -287,7 +287,7 @@ def __init__( if gamma < 0: raise ValueError( - "The gamma term of the focal loss should be non-negative, but got " f"{gamma}." + f"The gamma term of the focal loss should be non-negative, but got {gamma}." ) self.gamma = gamma diff --git a/torch_uncertainty/losses/regression.py b/torch_uncertainty/losses/regression.py index 888de286..479f7ff5 100644 --- a/torch_uncertainty/losses/regression.py +++ b/torch_uncertainty/losses/regression.py @@ -67,7 +67,7 @@ def __init__(self, reg_weight: float, reduction: str | None = "mean") -> None: if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." 
) self.reg_weight = reg_weight @@ -114,7 +114,7 @@ def __init__(self, beta: float = 0.5, reduction: str | None = "mean") -> None: super().__init__() if beta < 0 or beta > 1: - raise ValueError("The beta parameter should be in range [0, 1], but got " f"{beta}.") + raise ValueError(f"The beta parameter should be in range [0, 1], but got {beta}.") self.beta = beta self.nll_loss = nn.GaussianNLLLoss(reduction="none") if reduction not in ("none", "mean", "sum"): diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py index 3bcc47fe..667a0343 100644 --- a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py +++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py @@ -92,9 +92,9 @@ def _ace_compute( class BinaryAdaptiveCalibrationError(Metric): r"""Adaptive Top-label Calibration Error for binary tasks.""" - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False confidences: list[Tensor] accuracies: list[Tensor] diff --git a/torch_uncertainty/metrics/classification/brier_score.py b/torch_uncertainty/metrics/classification/brier_score.py index d474c6aa..24045a60 100644 --- a/torch_uncertainty/metrics/classification/brier_score.py +++ b/torch_uncertainty/metrics/classification/brier_score.py @@ -8,9 +8,9 @@ class BrierScore(Metric): - is_differentiable: bool = True - higher_is_better: bool | None = False - full_state_update: bool = False + is_differentiable = True + higher_is_better = False + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/disagreement.py b/torch_uncertainty/metrics/classification/disagreement.py index 6dd74215..843b7ff5 100644 --- a/torch_uncertainty/metrics/classification/disagreement.py +++ 
b/torch_uncertainty/metrics/classification/disagreement.py @@ -8,9 +8,9 @@ class Disagreement(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/entropy.py b/torch_uncertainty/metrics/classification/entropy.py index ef5df817..400e3968 100644 --- a/torch_uncertainty/metrics/classification/entropy.py +++ b/torch_uncertainty/metrics/classification/entropy.py @@ -6,9 +6,9 @@ class Entropy(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/fpr.py b/torch_uncertainty/metrics/classification/fpr.py index 53e3e779..ceaa5f73 100644 --- a/torch_uncertainty/metrics/classification/fpr.py +++ b/torch_uncertainty/metrics/classification/fpr.py @@ -6,9 +6,9 @@ class FPRx(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False conf: list[Tensor] targets: list[Tensor] @@ -35,7 +35,7 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: self.add_state("targets", [], dist_reduce_fx="cat") rank_zero_warn( - f"Metric `FPR{int(recall_level*100)}` will save all targets and predictions" + f"Metric `FPR{int(recall_level * 100)}` will save all targets and predictions" " in buffer. For large datasets this may lead to large memory" " footprint." 
) diff --git a/torch_uncertainty/metrics/classification/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py index 137ee264..3cef6b57 100644 --- a/torch_uncertainty/metrics/classification/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -26,9 +26,9 @@ def fit(self, probs: Tensor, targets: Tensor, features: Tensor) -> "GLEstimator" class GroupingLoss(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py index 5d9ef56d..8bd3b0cc 100644 --- a/torch_uncertainty/metrics/classification/mean_iou.py +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -4,9 +4,9 @@ class MeanIntersectionOverUnion(MulticlassStatScores): - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False + is_differentiable = False + higher_is_better = True + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/mutual_information.py b/torch_uncertainty/metrics/classification/mutual_information.py index c1224eec..8e1d5691 100644 --- a/torch_uncertainty/metrics/classification/mutual_information.py +++ b/torch_uncertainty/metrics/classification/mutual_information.py @@ -6,9 +6,9 @@ class MutualInformation(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 6f76ef1f..e2682313 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ 
b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -11,9 +11,9 @@ class AURC(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] @@ -44,8 +44,8 @@ def __init__(self, **kwargs) -> None: kwargs: Additional keyword arguments. Reference: - Geifman & El-Yaniv. "Selective classification for deep neural - networks." In NeurIPS, 2017. + Geifman & El-Yaniv. "Selective classification for deep neural networks." In NeurIPS, + 2017. """ super().__init__(**kwargs) self.add_state("scores", default=[], dist_reduce_fx="cat") @@ -87,6 +87,7 @@ def compute(self) -> Tensor: num_samples = error_rates.size(0) if num_samples < 2: return torch.tensor([float("nan")], device=self.device) + # There is no error rate associated to 0 coverage: starting at 1 cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples return _auc_compute(cov, error_rates) / (1 - 1 / num_samples) @@ -115,7 +116,7 @@ def plot( error_rates = self.partial_compute().cpu().flip(0) num_samples = error_rates.size(0) - x = torch.arange(num_samples) / num_samples + x = torch.arange(1, num_samples + 1) / num_samples aurc = _auc_compute(x, error_rates).cpu().item() # reduce plot size @@ -273,9 +274,9 @@ def plot( class CovAtxRisk(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] @@ -347,9 +348,9 @@ def __init__(self, **kwargs) -> None: class RiskAtxCov(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] diff --git 
a/torch_uncertainty/metrics/classification/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py index ff1ffd73..9a244734 100644 --- a/torch_uncertainty/metrics/classification/variation_ratio.py +++ b/torch_uncertainty/metrics/classification/variation_ratio.py @@ -8,9 +8,9 @@ class VariationRatio(Metric): - full_state_update: bool = False - is_differentiable: bool = True - higher_is_better: bool = False + full_state_update = False + is_differentiable = True + higher_is_better = False def __init__( self, diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/sparsification.py index 8977b7e0..cee99884 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/sparsification.py @@ -1,35 +1,25 @@ -from importlib import util - import matplotlib.pyplot as plt -import numpy as np import torch - -if util.find_spec("sklearn"): - from sklearn.metrics import auc - - sklearn_installed = True -else: # coverage: ignore - sklearn_installed = False - from torch import Tensor from torchmetrics.metric import Metric +from torchmetrics.utilities.compute import _auc_compute from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.plot import _AX_TYPE class AUSE(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False - plot_lower_bound: float = 0.0 - plot_upper_bound: float = 100.0 - plot_legend_name: str = "Sparsification Curves" + is_differentiable = False + higher_is_better = False + full_state_update = False + plot_lower_bound = 0.0 + plot_upper_bound = 100.0 + plot_legend_name = "Sparsification Curves" scores: list[Tensor] errors: list[Tensor] def __init__(self, **kwargs) -> None: - r"""The Area Under the Sparsification Error curve (AUSE) metric to estimate + r"""The Area Under the Sparsification Error curve (AUSE) metric to evaluate the quality of the uncertainty estimates, i.e., how much they coincide with the true 
errors. @@ -56,9 +46,6 @@ def __init__(self, **kwargs) -> None: self.add_state("scores", default=[], dist_reduce_fx="cat") self.add_state("errors", default=[], dist_reduce_fx="cat") - if not sklearn_installed: - raise ImportError("Please install scikit-learn to use AUSE.") - def update(self, scores: Tensor, errors: Tensor) -> None: """Store the scores and their associated errors for later computation. @@ -72,6 +59,9 @@ def update(self, scores: Tensor, errors: Tensor) -> None: def partial_compute(self) -> tuple[Tensor, Tensor]: scores = dim_zero_cat(self.scores) errors = dim_zero_cat(self.errors) + if scores.shape[0] < 2: + nan = torch.tensor([float("nan")], device=self.device) + return nan, nan error_rates = _ause_rejection_rate_compute(scores, errors) optimal_error_rates = _ause_rejection_rate_compute(errors, errors) return error_rates.cpu(), optimal_error_rates.cpu() @@ -84,10 +74,12 @@ def compute(self) -> Tensor: Tensor: The AUSE. """ error_rates, optimal_error_rates = self.partial_compute() + if torch.isnan(error_rates[0]).item(): + return torch.tensor([float("nan")], device=self.device) num_samples = error_rates.size(0) - x = np.arange(1, num_samples + 1) / num_samples - y = (error_rates - optimal_error_rates).numpy() - return torch.tensor([auc(x, y)]) + x = torch.arange(0, num_samples, device=self.device) / num_samples + y = error_rates - optimal_error_rates + return torch.tensor([_auc_compute(x, y)]) def plot( self, @@ -114,12 +106,12 @@ def plot( # Computation of AUSEC error_rates, optimal_error_rates = self.partial_compute() num_samples = error_rates.size(0) - x = np.arange(num_samples) / num_samples - y = (error_rates - optimal_error_rates).numpy() + x = torch.arange(num_samples) / num_samples + y = error_rates - optimal_error_rates - ausec = auc(x, y) + ausec = _auc_compute(x, y).cpu().item() - rejection_rates = (np.arange(num_samples) / num_samples) * 100 + rejection_rates = torch.arange(num_samples) / num_samples * 100 ax.plot( rejection_rates, diff 
--git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 4b8964bc..a3faf1f4 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -5,9 +5,11 @@ STEP_UPDATE_MODEL, SWA, SWAG, + BatchEnsemble, CheckpointEnsemble, MCDropout, StochasticModel, + batch_ensemble, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 4b76c9e6..b31ba133 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -9,6 +9,7 @@ from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models import StochasticModel +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ -119,6 +120,32 @@ def lenet( ) +def batchensemble_lenet( + in_channels: int, + num_classes: int, + num_estimators: int = 4, + activation: Callable = F.relu, + norm: type[nn.Module] = nn.BatchNorm2d, + groups: int = 1, + dropout_rate: float = 0.0, + repeat_training_inputs: bool = False, +) -> _LeNet: + model = lenet( + in_channels=in_channels, + num_classes=num_classes, + activation=activation, + norm=norm, + groups=groups, + dropout_rate=dropout_rate, + ) + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=True, + ) + + def packed_lenet( in_channels: int, num_classes: int, diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index 75f37e66..fb4ff50c 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .batch_ensemble import BatchEnsemble, batch_ensemble from .checkpoint_ensemble import ( CheckpointEnsemble, ) diff --git 
a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py new file mode 100644 index 00000000..99132fec --- /dev/null +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -0,0 +1,153 @@ +import torch +from einops import repeat +from torch import nn + +from torch_uncertainty.layers import BatchConv2d, BatchLinear + + +class BatchEnsemble(nn.Module): + def __init__( + self, + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, + ) -> None: + """Wrap a BatchEnsemble model to ensure correct batch replication. + + In a BatchEnsemble architecture, each estimator operates on a **sub-batch** + of the input. This means that the input batch must be **repeated** + :attr:`num_estimators` times before being processed. + + This wrapper automatically **duplicates the input batch** along the first axis, + ensuring that each estimator receives the correct data format. + + Args: + model (nn.Module): The BatchEnsemble model. + num_estimators (int): Number of ensemble members. + repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. + If ``True``, the input batch is repeated during both training and evaluation. If ``False``, + the input batch is repeated only during evaluation. Default is ``False``. + convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. + If ``True``, the wrapper will convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their + BatchEnsemble counterparts. Default is ``False``. + + Raises: + ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the + end of initialization. + ValueError: If ``num_estimators`` is less than or equal to ``0``. + ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are + found in the model. 
+ + Warning: + If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` + layers in the model to their BatchEnsemble counterparts. If the model contains other types of + layers, the conversion won't happen for these layers. If you don't have any ``nn.Linear`` or ``nn.Conv2d`` + layers in the model, the wrapper will raise an error during conversion. + + Warning: + If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` + for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when + initializing the routine. + + Example: + >>> model = nn.Sequential( + ... nn.Linear(10, 5), + ... nn.ReLU(), + ... nn.Linear(5, 2) + ... ) + >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) + >>> model + BatchEnsemble( + (model): Sequential( + (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) + (1): ReLU() + (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) + ) + ) + """ + super().__init__() + self.model = model + self.num_estimators = num_estimators + self.repeat_training_inputs = repeat_training_inputs + + if convert_layers: + self._convert_layers() + + filtered_modules = [ + module + for module in self.model.modules() + if isinstance(module, BatchLinear | BatchConv2d) + ] + _batch_ensemble_checks(filtered_modules, num_estimators) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and pass it through the model.""" + if not self.training or self.repeat_training_inputs: + x = repeat(x, "b ... 
-> (m b) ...", m=self.num_estimators) + return self.model(x) + + def _convert_layers(self) -> None: + """Convert the model's layers to BatchEnsemble layers.""" + no_valid_layers = True + for name, layer in self.model.named_modules(): + if isinstance(layer, nn.Linear): + setattr( + self.model, + name, + BatchLinear.from_linear(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + elif isinstance(layer, nn.Conv2d): + setattr( + self.model, + name, + BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators), + ) + no_valid_layers = False + if no_valid_layers: + raise ValueError( + "No valid layers found in the model. " + "Please use `nn.Linear` or `nn.Conv2d` layers to apply BatchEnsemble." + ) + + +def _batch_ensemble_checks(filtered_modules, num_estimators): + """Check that the model contains BatchEnsemble layers and that ``num_estimators`` is positive.""" + if len(filtered_modules) == 0: + raise ValueError( + "No BatchEnsemble layers found in the model. " + "Please use `BatchLinear` or `BatchConv2d` layers in your model " + "or set `convert_layers=True` when initializing the wrapper." + ) + if num_estimators <= 0: + raise ValueError("`num_estimators` must be greater than 0.") + + +def batch_ensemble( + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, +) -> BatchEnsemble: + """BatchEnsemble wrapper for a model. + + Args: + model (nn.Module): model to wrap + num_estimators (int): number of ensemble members + repeat_training_inputs (bool, optional): whether to repeat the input batch during training. + If ``True``, the input batch is repeated during both training and evaluation. If ``False``, + the input batch is repeated only during evaluation. Default is ``False``. + convert_layers (bool, optional): whether to convert the model's layers to BatchEnsemble layers. + If ``True``, the wrapper will convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their + BatchEnsemble counterparts. 
Default is ``False``. + + Returns: + BatchEnsemble: BatchEnsemble wrapper for the model + """ + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=convert_layers, + ) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 9a1a9ef6..a5398df2 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -60,6 +60,11 @@ def fit( progress (bool, optional): Whether to show a progress bar. Defaults to True. """ + if self.model is None: + raise ValueError( + "Cannot fit a Scaler method without model. Call .set_model(model) first." + ) + logits_list = [] labels_list = [] calibration_dl = DataLoader(calibration_set, batch_size=32, shuffle=False, drop_last=False) @@ -89,7 +94,7 @@ def calib_eval() -> float: def forward(self, inputs: Tensor) -> Tensor: if not self.trained: logging.error( - "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." + "TemperatureScaler has not been trained yet. Returning manually tempered inputs." 
) return self._scale(self.model(inputs)) diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py index 5357d954..0a1d250d 100644 --- a/torch_uncertainty/post_processing/mc_batch_norm.py +++ b/torch_uncertainty/post_processing/mc_batch_norm.py @@ -168,4 +168,4 @@ def _mcbn_checks(model, num_estimators, mc_batch_size, convert): if mc_batch_size < 1 or not isinstance(mc_batch_size, int): raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.") if not convert and not has_mcbn(model): - raise ValueError("model does not contain any MCBatchNorm2d nor is not to be " "converted.") + raise ValueError("model does not contain any MCBatchNorm2d nor is not to be converted.") diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..e5469b5c 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -123,6 +123,21 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the Lightning CLI. + Warning: + When using an ensemble model, you must: + 1. Set :attr:`is_ensemble` to ``True``. + 2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`. + 3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, + where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes. + + For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`. 
+ + Note: + If :attr:`eval_ood` is ``True``, we perform a binary classification and update the + OOD-related metrics twice: + - once during the test on ID values where the given binary label is 0 (for ID) + - once during the test on OOD values where the given binary label is 1 (for OOD) + Note: :attr:`optim_recipe` can be anything that can be returned by :meth:`LightningModule.configure_optimizers()`. Find more details @@ -475,7 +490,7 @@ def test_step( """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] @@ -606,9 +621,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() - shift_severity = self.trainer.test_dataloaders[ - 2 if self.eval_ood else 1 - ].dataset.shift_severity + shift_severity = self.trainer.datamodule.shift_severity tmp_metrics["shift/shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) @@ -719,7 +732,7 @@ def _classification_routine_checks( if not is_ensemble and ood_criterion in ["mi", "vr"]: raise ValueError( - "You cannot use mutual information or variation ratio with a single" " model." + "You cannot use mutual information or variation ratio with a single model." ) if is_ensemble and eval_grouping_loss: @@ -729,12 +742,12 @@ def _classification_routine_checks( if num_classes < 1: raise ValueError( - "The number of classes must be a positive integer >= 1." f"Got {num_classes}." + f"The number of classes must be a positive integer >= 1. Got {num_classes}." 
) if eval_grouping_loss and not hasattr(model, "feats_forward"): raise ValueError( - "Your model must have a `feats_forward` method to compute the " "grouping loss." + "Your model must have a `feats_forward` method to compute the grouping loss." ) if eval_grouping_loss and not ( diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index 3fa450f2..92b9b9da 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -86,7 +86,7 @@ def __init__( _depth_routine_checks(output_dim, num_image_plot, log_plots) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue if needed." ) self.model = model @@ -294,7 +294,7 @@ def test_step( """ if dataloader_idx != 0: raise NotImplementedError( - "Depth OOD detection not implemented yet. Raise an issue " "if needed." + "Depth OOD detection not implemented yet. Raise an issue if needed." ) inputs, targets = batch if self.one_dim_depth: diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 07a43bb8..88514d39 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -76,7 +76,7 @@ def __init__( _regression_routine_checks(output_dim) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue if needed." ) self.model = model @@ -271,7 +271,7 @@ def test_step( """ if dataloader_idx != 0: raise NotImplementedError( - "Regression OOD detection not implemented yet. Raise an issue " "if needed." + "Regression OOD detection not implemented yet. Raise an issue if needed." 
) inputs, targets = batch diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index ba08e367..fac59ecb 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -76,7 +76,7 @@ def __init__( ) if eval_shift: raise NotImplementedError( - "Distribution shift evaluation not implemented yet. Raise an issue " "if needed." + "Distribution shift evaluation not implemented yet. Raise an issue if needed." ) self.model = model diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 1426992a..19e0edab 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +from einops import rearrange, repeat from torch import Tensor, nn @@ -21,7 +21,7 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat(self.num_repeats, *[1] * (targets.ndim - 1)) + return inputs, repeat(targets, "b ... -> (m b) ...", m=self.num_repeats) class MIMOBatchFormat(nn.Module): @@ -79,6 +79,6 @@ def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: [torch.index_select(targets, dim=0, index=indices) for indices in shuffle_indices], dim=0, ) - inputs = rearrange(inputs, "m b c h w -> (m b) c h w", m=self.num_estimators) + inputs = rearrange(inputs, "m b ... 
-> (m b) ...", m=self.num_estimators) targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators) return inputs, targets diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index ffa5f491..0b28c73a 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -1,6 +1,5 @@ """Adapted from https://github.com/hendrycks/robustness.""" -import ctypes from importlib import util from io import BytesIO @@ -40,24 +39,12 @@ ToTensor, ) -if util.find_spec("wand"): - from wand.api import library as wandlibrary - from wand.image import Image as WandImage +if util.find_spec("kornia"): + from kornia.filters import motion_blur - wandlibrary.MagickMotionBlurImage.argtypes = ( - ctypes.c_void_p, # wand - ctypes.c_double, # radius - ctypes.c_double, # sigma - ctypes.c_double, - ) # angle - - class MotionImage(WandImage): - def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): - wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) - - wand_installed = True + kornia_installed = True else: # coverage: ignore - wand_installed = False + kornia_installed = False from torch_uncertainty.datasets import FrostImages @@ -66,24 +53,24 @@ def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): from .image import Saturation as ISaturation __all__ = [ - "GaussianNoise", - "ShotNoise", - "ImpulseNoise", - "DefocusBlur", - "GlassBlur", - "MotionBlur", - "ZoomBlur", - "Snow", - "Frost", - "Fog", "Brightness", "Contrast", + "DefocusBlur", "Elastic", - "Pixelate", - "JPEGCompression", + "Fog", + "Frost", "GaussianBlur", - "SpeckleNoise", + "GaussianNoise", + "GlassBlur", + "ImpulseNoise", + "JPEGCompression", + "MotionBlur", + "Pixelate", "Saturation", + "ShotNoise", + "Snow", + "SpeckleNoise", + "ZoomBlur", "corruption_transforms", ] @@ -143,7 +130,7 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. 
""" super().__init__(severity) - if not skimage_installed: # coverage: ignore + if not skimage_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -168,7 +155,7 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. """ super().__init__(severity) - if not cv2_installed: # coverage: ignore + if not cv2_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -196,7 +183,7 @@ def forward(self, img: Tensor) -> Tensor: class GlassBlur(TUCorruption): # TODO: batch def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed or not cv2_installed: # coverage: ignore + if not skimage_installed or not cv2_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -230,20 +217,23 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): xs, ys = np.meshgrid(size, size) aliased_disk = np.array((xs**2 + ys**2) <= radius**2, dtype=dtype) aliased_disk /= np.sum(aliased_disk) - return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) class MotionBlur(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a motion blur corruption on the image. + + Note: + Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency with + with `Wand` we changed the transform to a simpler motion blur and kept the values of + sigma as the new half kernel sizes. 
+ """ super().__init__(severity) self.rng = np.random.default_rng() - self.radius = [10, 15, 15, 15, 20][severity - 1] - self.sigma = [3, 5, 8, 12, 15][severity - 1] - self.to_pil_img = ToPILImage() - self.to_tensor = ToTensor() + self.radius = [3, 5, 8, 12, 15][severity - 1] - if not wand_installed: # coverage: ignore + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -252,18 +242,16 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - output = BytesIO() - pil_img = self.to_pil_img(img) - pil_img.save(output, "PNG") - x = MotionImage(blob=output.getvalue()) - x.motion_blur( - radius=self.radius, - sigma=self.sigma, - angle=self.rng.uniform(-45, 45), + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = motion_blur( + img, kernel_size=self.radius * 2 + 1, angle=self.rng.uniform(-45, 45), direction=0 ) - x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED) - x = np.clip(x[..., [2, 1, 0]], 0, 255) - return self.to_tensor(x) + if no_batch: + out = out.squeeze(0) + return out def clipped_zoom(img, zoom_factor): @@ -294,7 +282,7 @@ def __init__(self, severity: int) -> None: np.arange(1, 1.31, 0.03), ][severity - 1] - if not scipy_installed: # coverage: ignore + if not scipy_installed: raise ImportError( "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" @@ -313,17 +301,22 @@ def forward(self, img: Tensor) -> Tensor: class Snow(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a snow effect on the image. + + Note: + The transformation has been slightly modified, see MotionBlur for details. 
+ """ super().__init__(severity) self.mix = [ - (0.1, 0.3, 3, 0.5, 10, 4, 0.8), - (0.2, 0.3, 2, 0.5, 12, 4, 0.7), - (0.55, 0.3, 4, 0.9, 12, 8, 0.7), - (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65), - (0.55, 0.3, 2.5, 0.85, 12, 12, 0.55), + (0.1, 0.3, 3, 0.5, 4, 0.8), + (0.2, 0.3, 2, 0.5, 4, 0.7), + (0.55, 0.3, 4, 0.9, 8, 0.7), + (0.55, 0.3, 4.5, 0.85, 8, 0.65), + (0.55, 0.3, 2.5, 0.85, 12, 0.55), ][severity - 1] self.rng = np.random.default_rng() - if not wand_installed: # coverage: ignore + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -339,27 +332,20 @@ def forward(self, img: Tensor) -> Tensor: ] snow_layer = clipped_zoom(snow_layer, self.mix[2]) snow_layer[snow_layer < self.mix[3]] = 0 - snow_layer = Image.fromarray( - (np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), - mode="L", - ) - output = BytesIO() - snow_layer.save(output, format="PNG") - snow_layer = MotionImage(blob=output.getvalue()) - snow_layer.motion_blur( - radius=self.mix[4], - sigma=self.mix[5], - angle=self.rng.uniform(-135, -45), - ) + snow_layer = np.clip(snow_layer.squeeze(), 0, 1) + snow_layer = ( - cv2.imdecode( - np.fromstring(snow_layer.make_blob(), np.uint8), - cv2.IMREAD_UNCHANGED, + motion_blur( + torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), + kernel_size=self.mix[4] * 2 + 1, + angle=self.rng.uniform(-135, -45), + direction=0, ) - / 255.0 + .squeeze(0) + .numpy() ) - snow_layer = snow_layer[np.newaxis, ...] 
- x = self.mix[6] * x + (1 - self.mix[6]) * np.maximum( + + x = self.mix[5] * x + (1 - self.mix[5]) * np.maximum( x, cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape(1, height, width) * 1.5 + 0.5, @@ -516,7 +502,7 @@ def forward(self, img: Tensor) -> Tensor: class Elastic(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - if not cv2_installed or not scipy_installed: # coverage: ignore + if not cv2_installed or not scipy_installed: raise ImportError( "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" @@ -620,7 +606,7 @@ def forward(self, img: Tensor) -> Tensor: class GaussianBlur(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed: # coverage: ignore + if not skimage_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index b4695489..cb28dd6e 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -240,52 +240,6 @@ def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Ima class RandomRescale(Transform): - """Randomly rescale the input. - - This transformation can be used together with ``RandomCrop`` as data augmentations to train - models on image segmentation task. - - Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: - - .. code-block:: python - - scale = uniform_sample(min_scale, max_scale) - output_width = input_width * scale - output_height = input_height * scale - - If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, - :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) - it can have arbitrary number of leading batch dimensions. 
For example, - the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. - - Args: - min_scale (int): Minimum scale for random sampling - max_scale (int): Maximum scale for random sampling - interpolation (InterpolationMode, optional): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, - ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. - The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. - antialias (bool, optional): Whether to apply antialiasing. - It only affects **tensors** with bilinear or bicubic modes and it is - ignored otherwise: on PIL images, antialiasing is always applied on - bilinear or bicubic modes; on other modes (for PIL images and - tensors), antialiasing makes no sense and this parameter is ignored. - Possible values are: - - - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. - Other mode aren't affected. This is probably what you want to use. - - ``False``: will not apply antialiasing for tensors on any mode. PIL - images are still antialiased on bilinear or bicubic modes, because - PIL doesn't support no antialias. - - ``None``: equivalent to ``False`` for tensors and ``True`` for - PIL images. This value exists for legacy reasons and you probably - don't want to use it unless you really know what you are doing. - - The default value changed from ``None`` to ``True`` in - v0.17, for the PIL and Tensor backends to be consistent. - """ - def __init__( self, min_scale: int, @@ -293,18 +247,69 @@ def __init__( interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, antialias: bool | None = True, ) -> None: + """Randomly rescale the input. 
+ + This transformation can be used together with ``RandomCrop`` as data augmentations to train + models on image segmentation task. + + Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: + + .. code-block:: python + + scale = uniform_sample(min_scale, max_scale) + output_width = input_width * scale + output_height = input_height * scale + + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + + Args: + min_scale (int): Minimum scale for random sampling + max_scale (int): Maximum scale for random sampling + interpolation (InterpolationMode, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True`` (default): will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. 
+ - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The default value changed from ``None`` to ``True`` in + v0.17, for the PIL and Tensor backends to be consistent. + """ super().__init__() self.min_scale = min_scale self.max_scale = max_scale self.interpolation = interpolation self.antialias = antialias + # Compatibility with torchvision < 0.21 def _get_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) scale = torch.rand(1) scale = self.min_scale + scale * (self.max_scale - self.min_scale) return {"size": (int(height * scale), int(width * scale))} + # Compatibility with torchvision >= 0.21 + def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: + return self._get_params(flat_inputs) + + # Compatibility with torchvision < 0.21 def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel( F.resize, @@ -313,3 +318,7 @@ def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: interpolation=self.interpolation, antialias=self.antialias, ) + + # Compatibility with torchvision >= 0.21 + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + return self._transform(inpt, params) diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index 3e115e93..860b96d5 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -33,7 +33,7 @@ def get_dist_class(dist_family: str) -> type[Distribution]: if dist_family == "student": return StudentT raise NotImplementedError( - f"{dist_family} distribution is not supported." "Raise an issue if needed." + f"{dist_family} distribution is not supported. Raise an issue if needed." 
) @@ -52,7 +52,7 @@ def get_dist_estimate(dist: Distribution, dist_estimate: str) -> Tensor: if dist_estimate == "mode": return dist.mode raise NotImplementedError( - f"{dist_estimate} estimate is not supported." "Raise an issue if needed." + f"{dist_estimate} estimate is not supported. Raise an issue if needed." )