From a6d0e97f84aa88f8dc0011e3d6c76d14433e289f Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Sun, 2 Mar 2025 14:09:05 +0100 Subject: [PATCH 01/68] :whale: Add Dockerfile --- Dockerfile | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..9d1b118c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +FROM pytorch/pytorch:2.6.0-cuda12.6-cudnn9-runtime + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + LC_ALL=C.UTF-8 \ + LANG=C.UTF-8 + +RUN apt-get update && apt-get install -y \ + git \ + curl \ + wget \ + vim \ + python3-pip \ + python3-venv \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY pyproject.toml . +RUN pip install --no-cache-dir . + +EXPOSE 8888 +CMD [ "/bin/bash" ] From 8b416c0be18f88d3af521002f21cacb7fd2832c2 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Sun, 2 Mar 2025 18:10:08 +0100 Subject: [PATCH 02/68] :whale: Fix flit build errors --- Dockerfile | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9d1b118c..9bb48918 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,17 +7,22 @@ ENV DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && apt-get install -y \ git \ - curl \ - wget \ - vim \ python3-pip \ - python3-venv \ && rm -rf /var/lib/apt/lists/* WORKDIR /workspace +# Copy README.md and torch_uncertainy module (required by pyproject.toml, othwise flit build will fail) +COPY README.md . +COPY torch_uncertainty ./torch_uncertainty + +# Copy dependency file COPY pyproject.toml . + +# Install dependencies RUN pip install --no-cache-dir . +# Expose port 8888 for TensorBoard and Jupyter Notebook EXPOSE 8888 + CMD [ "/bin/bash" ] From 3ef913257a71f4cc366cecfaae330a5a1cf6f666 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Sun, 2 Mar 2025 18:25:20 +0100 Subject: [PATCH 03/68] :whale: Install OpenSSH and start server when container starts --- Dockerfile | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 9bb48918..cfb550bb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,7 @@ ENV DEBIAN_FRONTEND=noninteractive \ LC_ALL=C.UTF-8 \ LANG=C.UTF-8 +# Install git and pip RUN apt-get update && apt-get install -y \ git \ python3-pip \ @@ -22,7 +23,19 @@ COPY pyproject.toml . # Install dependencies RUN pip install --no-cache-dir . 
+# Install OpenSSH Server +RUN apt-get update && apt-get install -y openssh-server && rm -rf /var/lib/apt/lists/* + +# Create SSH directory & keys +RUN mkdir -p /var/run/sshd && echo 'root:root' | chpasswd + +# Allow root login via SSH +RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config + # Expose port 8888 for TensorBoard and Jupyter Notebook EXPOSE 8888 +# Expose port 22 for SSH +EXPOSE 22 -CMD [ "/bin/bash" ] +# Ensure the SSH server starts on container launch +CMD ["/usr/sbin/sshd", "-D"] From 4ffeb34df21d4add9ac54b16646e66a4639ccf80 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Mon, 3 Mar 2025 14:05:31 +0100 Subject: [PATCH 04/68] :whale: Fix SSH key not being recognized and add shell prompt customization --- Dockerfile | 64 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/Dockerfile b/Dockerfile index cfb550bb..4ae9edca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,37 +5,55 @@ ENV DEBIAN_FRONTEND=noninteractive \ LC_ALL=C.UTF-8 \ LANG=C.UTF-8 -# Install git and pip +# Install Git and OpenSSH Server (PyTorch's base image alraedy includes Conda and Pip) RUN apt-get update && apt-get install -y \ git \ - python3-pip \ + openssh-server \ && rm -rf /var/lib/apt/lists/* WORKDIR /workspace # Copy README.md and torch_uncertainy module (required by pyproject.toml, othwise flit build will fail) -COPY README.md . -COPY torch_uncertainty ./torch_uncertainty +COPY README.md /workspace/ +COPY torch_uncertainty /workspace/torch_uncertainty # Copy dependency file -COPY pyproject.toml . +COPY pyproject.toml /workspace/ # Install dependencies -RUN pip install --no-cache-dir . - -# Install OpenSSH Server -RUN apt-get update && apt-get install -y openssh-server && rm -rf /var/lib/apt/lists/* - -# Create SSH directory & keys -RUN mkdir -p /var/run/sshd && echo 'root:root' | chpasswd - -# Allow root login via SSH -RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config - -# Expose port 8888 for TensorBoard and Jupyter Notebook -EXPOSE 8888 -# Expose port 22 for SSH -EXPOSE 22 - -# Ensure the SSH server starts on container launch -CMD ["/usr/sbin/sshd", "-D"] +RUN pip install --no-cache-dir ".[all]" + +# Always activate Conda when opening a new terminal +RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc + +# Customize the Bash prompt +RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \ + echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc && \ + echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc && \ + echo ' test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"' >> /root/.bashrc && \ + echo ' alias ls="ls --color=auto"' >> /root/.bashrc && \ + echo ' alias grep="grep --color=auto"' >> /root/.bashrc && \ + echo ' alias fgrep="fgrep --color=auto"' >> /root/.bashrc && \ + echo ' alias egrep="egrep --color=auto"' >> /root/.bashrc && \ + echo ' cd /workspace' >> /root/.bashrc && \ + echo 'fi' >> /root/.bashrc + +# Allow SSH login as root without a password (key-based authentication only) +RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ + echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \ + echo "AuthorizedKeysFile .ssh/authorized_keys" >> /etc/ssh/sshd_config + +# Set default environment variable for SSH key (empty by default) +ENV SSH_PUBLIC_KEY="" + +# Expose port 8888 for TensorBoard and Jupyter Notebook and port 22 for SSH +EXPOSE 8888 22 + +# Ensure SSH key is added when 
the container starts +CMD ["/bin/bash", "-c", "\ + mkdir -p /root/.ssh && \ + chmod 700 /root/.ssh && \ + echo \"$SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && \ + chmod 600 /root/.ssh/authorized_keys && \ + mkdir -p /run/sshd && \ + /usr/sbin/sshd -D"] From ed04e004cda475461ab23dd1c9ccec59a5671037 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Mon, 3 Mar 2025 17:43:39 +0100 Subject: [PATCH 05/68] :whale: Reformat Dockerfile --- Dockerfile | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4ae9edca..b7b86295 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ ENV DEBIAN_FRONTEND=noninteractive \ LC_ALL=C.UTF-8 \ LANG=C.UTF-8 -# Install Git and OpenSSH Server (PyTorch's base image alraedy includes Conda and Pip) +# Install Git and OpenSSH Server (PyTorch's base image already includes Conda and Pip) RUN apt-get update && apt-get install -y \ git \ openssh-server \ @@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \ WORKDIR /workspace -# Copy README.md and torch_uncertainy module (required by pyproject.toml, othwise flit build will fail) +# Copy README.md and torch_uncertainty module (required by pyproject.toml, otherwise flit build will fail) COPY README.md /workspace/ COPY torch_uncertainty /workspace/torch_uncertainty @@ -26,7 +26,7 @@ RUN pip install --no-cache-dir ".[all]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc -# Customize the Bash prompt +# Customize shell prompt RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \ echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc && \ echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc && \ @@ -38,22 +38,19 @@ RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \ echo ' cd /workspace' >> /root/.bashrc && \ echo 'fi' >> /root/.bashrc -# Allow SSH login as root without a password (key-based authentication only) +# Configure SSH server RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \ echo "AuthorizedKeysFile .ssh/authorized_keys" >> /etc/ssh/sshd_config -# Set default environment variable for SSH key (empty by default) -ENV SSH_PUBLIC_KEY="" - # Expose port 8888 for TensorBoard and Jupyter Notebook and port 22 for SSH EXPOSE 8888 22 -# Ensure SSH key is added when the container starts +# Ensure public key for RunPod-Auth is added when the container starts CMD ["/bin/bash", "-c", "\ mkdir -p /root/.ssh && \ chmod 700 /root/.ssh && \ - echo \"$SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && \ + echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && \ chmod 600 /root/.ssh/authorized_keys && \ mkdir -p /run/sshd && \ /usr/sbin/sshd -D"] From 5d731cd722a38f291351ef04372fc03f4ba49ea7 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Mon, 3 Mar 2025 22:18:04 +0100 Subject: [PATCH 06/68] :whale: Install OpenGL and run SSH setup for GitHub-Auth --- Dockerfile | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index b7b86295..8c88e414 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,10 +5,11 @@ ENV DEBIAN_FRONTEND=noninteractive \ LC_ALL=C.UTF-8 \ LANG=C.UTF-8 -# Install Git and OpenSSH Server (PyTorch's base image already includes Conda and Pip) +# Install Git, OpenSSH Server, and OpenGL (PyTorch's base image already includes Conda and Pip) RUN apt-get update && apt-get install -y \ git \ openssh-server \ + libgl1 \ && rm -rf /var/lib/apt/lists/* WORKDIR 
/workspace

@@ -52,5 +53,14 @@ CMD ["/bin/bash", "-c", "\
     chmod 700 /root/.ssh && \
     echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && \
     chmod 600 /root/.ssh/authorized_keys && \
+    echo \"$GITHUB_SSH_PRIVATE_KEY\" > /root/.ssh/github_rsa && \
+    chmod 600 /root/.ssh/github_rsa && \
+    echo 'Host github.com' > /root/.ssh/config && \
+    echo '  User git' >> /root/.ssh/config && \
+    echo '  IdentityFile /root/.ssh/github_rsa' >> /root/.ssh/config && \
+    chmod 600 /root/.ssh/config && \
+    eval $(ssh-agent -s) && ssh-add /root/.ssh/github_rsa && \
+    ssh-keyscan github.com >> /root/.ssh/known_hosts && \
+    git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace && \
     mkdir -p /run/sshd && \
     /usr/sbin/sshd -D"]

From 025f84dadedb4a6bc99a16c0260872cadae50b78 Mon Sep 17 00:00:00 2001
From: Anton Zamyatin
Date: Mon, 3 Mar 2025 23:46:46 +0100
Subject: [PATCH 07/68] :whale: Ensure setup only runs once when container is started for the first time

---
 Dockerfile | 46 +++++++++++++++++++++++++++++-----------------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 8c88e414..cd288ee3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -27,15 +27,18 @@ RUN pip install --no-cache-dir ".[all]"
 # Always activate Conda when opening a new terminal
 RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc

-# Customize shell prompt
+# Customize shell prompt (optional)
 RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \
+    # Blue working directory, no username, and no hostname, with $ at the end
     echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc && \
+    # Colorize ls, grep, fgrep, and egrep
     echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc && \
     echo '    test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"' >> /root/.bashrc && \
     echo '    alias ls="ls --color=auto"' >> /root/.bashrc && \
     echo '    alias grep="grep --color=auto"' >> /root/.bashrc && \
     echo '    alias fgrep="fgrep --color=auto"' >> /root/.bashrc && \
     echo '    alias egrep="egrep --color=auto"' >> /root/.bashrc && \
+    # Automatically change to workspace directory when opening a new terminal
     echo '    cd /workspace' >> /root/.bashrc && \
     echo 'fi' >> /root/.bashrc

@@ -47,20 +50,29 @@ RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \
 # Expose port 8888 for TensorBoard and Jupyter Notebook and port 22 for SSH
 EXPOSE 8888 22

-# Ensure public key for RunPod-Auth is added when the container starts
 CMD ["/bin/bash", "-c", "\
-    mkdir -p /root/.ssh && \
-    chmod 700 /root/.ssh && \
-    echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && \
-    chmod 600 /root/.ssh/authorized_keys && \
-    echo \"$GITHUB_SSH_PRIVATE_KEY\" > /root/.ssh/github_rsa && \
-    chmod 600 /root/.ssh/github_rsa && \
-    echo 'Host github.com' > /root/.ssh/config && \
-    echo '  User git' >> /root/.ssh/config && \
-    echo '  IdentityFile /root/.ssh/github_rsa' >> /root/.ssh/config && \
-    chmod 600 /root/.ssh/config && \
-    eval $(ssh-agent -s) && ssh-add /root/.ssh/github_rsa && \
-    ssh-keyscan github.com >> /root/.ssh/known_hosts && \
-    git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace && \
-    mkdir -p /run/sshd && \
-    /usr/sbin/sshd -D"]
+    # Ensure first-time setup only runs once
+    if [ ! 
-f /workspace/.setup_done ]; then \ + echo 'Running first-time setup...'; \ + # Add public key for RunPod-Auth and private key for GitHub-Auth + mkdir -p /root/.ssh && chmod 700 /root/.ssh; \ + echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys; \ + echo \"$GITHUB_SSH_PRIVATE_KEY\" > /root/.ssh/github_rsa && chmod 600 /root/.ssh/github_rsa; \ + # Add GitHub credentials to SSh config + echo 'Host github.com' > /root/.ssh/config; \ + echo ' User git' >> /root/.ssh/config; \ + echo ' IdentityFile /root/.ssh/github_rsa' >> /root/.ssh/config; \ + chmod 600 /root/.ssh/config; \ + # Add GitHub to known hosts + ssh-keyscan github.com >> /root/.ssh/known_hosts; \ + # Clone GitHub repo if not already cloned + if [ ! -d \"/workspace/.git\" ]; then git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace; fi; \ + # Mark first-time setup as done + touch /workspace/.setup_done; \ + else \ + echo 'Skipping first-time setup, already done.'; \ + fi; \ + # Always start SSH agent and add key every time the container starts + eval $(ssh-agent -s) && ssh-add /root/.ssh/github_rsa; \ + # Start SSH server + mkdir -p /run/sshd && /usr/sbin/sshd -D"] From 4106501a50235657a964975d23355fabd4a70a7c Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Tue, 4 Mar 2025 09:51:41 +0100 Subject: [PATCH 08/68] :whale: Optimize setup steps and make custom prompt optional via env variable --- .gitignore | 1 + Dockerfile | 70 +++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 9ad40cf4..fcd341af 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.setup_done # PyInstaller # Usually these files are written by a python script from a template diff --git a/Dockerfile b/Dockerfile index cd288ee3..8618c582 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,8 @@ RUN pip install --no-cache-dir ".[all]" RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc # Customize shell prompt (optional) -RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \ +RUN if [ ! -z "$USE_COMPACT_SHELL_PROMPT" ] && [ "$USE_COMPACT_SHELL_PROMPT" = "true" ]; then \ + echo 'force_color_prompt=yes' >> /root/.bashrc && \ # Blue working directory, no username, and no hostname, with $ at the end echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc && \ # Colorize ls, grep, fgrep, and egrep @@ -40,7 +41,8 @@ RUN echo 'force_color_prompt=yes' >> /root/.bashrc && \ echo ' alias egrep="egrep --color=auto"' >> /root/.bashrc && \ # Automatically change to workspace directory when opening a new terminal echo ' cd /workspace' >> /root/.bashrc && \ - echo 'fi' >> /root/.bashrc + echo 'fi' >> /root/.bashrc \ + fi; # Configure SSH server RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ @@ -50,29 +52,65 @@ RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ # Expose port 8888 for TensorBoard and Jupyter Notebook and port 22 for SSH EXPOSE 8888 22 +# Entrypoint script (runs every time the container starts) CMD ["/bin/bash", "-c", "\ - # Ensure first-time setup only runs once - if [ ! -f /workspace/.setup_done ]; then \ - echo 'Running first-time setup...'; \ - # Add public key for RunPod-Auth and private key for GitHub-Auth + # Create SSH directory and set permissions if not present \ + if [ ! 
-d /root/.ssh ]; then \ mkdir -p /root/.ssh && chmod 700 /root/.ssh; \ + fi; \ + # Add public key for RunPod-Auth if not present \ + if [ -z \"$RUNPOD_SSH_PUBLIC_KEY\" ]; then \ + echo 'Please set the RUNPOD_SSH_PUBLIC_KEY environment variable.'; \ + exit 1; \ + fi; \ + if [ ! -f /root/.ssh/authorized_keys ] || ! grep -q \"$RUNPOD_SSH_PUBLIC_KEY\" /root/.ssh/authorized_keys; then \ echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys; \ + fi; \ + # Add private key for GitHub-Auth if not present \ + if [ -z \"$GITHUB_SSH_PRIVATE_KEY\" ]; then \ + echo 'Please set the GITHUB_SSH_PRIVATE_KEY environment variable.'; \ + exit 1; \ + fi; \ + if [ ! -f /root/.ssh/github_rsa ]; then \ echo \"$GITHUB_SSH_PRIVATE_KEY\" > /root/.ssh/github_rsa && chmod 600 /root/.ssh/github_rsa; \ - # Add GitHub credentials to SSh config + fi; \ + # Add GitHub credentials to SSH config if not present \ + if [ ! -f /root/.ssh/config ] || ! grep -q 'Host github.com' /root/.ssh/config; then \ echo 'Host github.com' > /root/.ssh/config; \ echo ' User git' >> /root/.ssh/config; \ echo ' IdentityFile /root/.ssh/github_rsa' >> /root/.ssh/config; \ chmod 600 /root/.ssh/config; \ - # Add GitHub to known hosts - ssh-keyscan github.com >> /root/.ssh/known_hosts; \ - # Clone GitHub repo if not already cloned - if [ ! -d \"/workspace/.git\" ]; then git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace; fi; \ - # Mark first-time setup as done + fi; \ + # Add GitHub to known hosts if not already added \ + ssh-keygen -F github.com > /dev/null 2>&1 || ssh-keyscan github.com >> /root/.ssh/known_hosts; \ + # Start SSH agent if not running and add GitHub private key \ + if ! pgrep -x \"ssh-agent\" > /dev/null; then \ + eval $(ssh-agent -s); \ + fi; \ + ssh-add -l | grep github_rsa > /dev/null || ssh-add /root/.ssh/github_rsa; \ + # Ensure first-time setup only runs once \ + if [ ! -f /workspace/.setup_done ]; then \ + echo 'Running first-time setup...'; \ + # Clone GitHub repo if not already cloned \ + if [ -z \"$GITHUB_USER\" ]; then \ + echo 'Please set the GITHUB_USER environment variable.'; \ + exit 1; \ + fi; \ + if [ ! -d \"/workspace/.git\" ]; then \ + git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace; \ + fi; \ + # Set Git user name and email if provided \ + if [ ! -z \"$GIT_USER_NAME\" ]; then \ + git config --global user.name \"$GIT_USER_NAME\"; \ + fi; \ + if [ ! 
-z \"$GIT_USER_EMAIL\" ]; then \ + git config --global user.email \"$GIT_USER_EMAIL\"; \ + fi; \ + # Mark first-time setup as done \ touch /workspace/.setup_done; \ else \ echo 'Skipping first-time setup, already done.'; \ fi; \ - # Always start SSH agent and add key every time the container starts - eval $(ssh-agent -s) && ssh-add /root/.ssh/github_rsa; \ - # Start SSH server - mkdir -p /run/sshd && /usr/sbin/sshd -D"] + # Start SSH server \ + mkdir -p /run/sshd && chmod 755 /run/sshd; \ + /usr/sbin/sshd -D"] From c8cd50b5f019f86a4411e4c85e3dd8a5cb78802e Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Tue, 4 Mar 2025 13:32:24 +0100 Subject: [PATCH 09/68] :whale: Optimize build time and run shell prompt in initial setup --- Dockerfile | 49 ++++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8618c582..6bfd1f74 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,9 +14,8 @@ RUN apt-get update && apt-get install -y \ WORKDIR /workspace -# Copy README.md and torch_uncertainty module (required by pyproject.toml, otherwise flit build will fail) -COPY README.md /workspace/ -COPY torch_uncertainty /workspace/torch_uncertainty +# Create an empty README.md file and an empty torch_uncertainty module to satisfy flit +RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__init__.py # Copy dependency file COPY pyproject.toml /workspace/ @@ -27,23 +26,6 @@ RUN pip install --no-cache-dir ".[all]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc -# Customize shell prompt (optional) -RUN if [ ! -z "$USE_COMPACT_SHELL_PROMPT" ] && [ "$USE_COMPACT_SHELL_PROMPT" = "true" ]; then \ - echo 'force_color_prompt=yes' >> /root/.bashrc && \ - # Blue working directory, no username, and no hostname, with $ at the end - echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc && \ - # Colorize ls, grep, fgrep, and egrep - echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc && \ - echo ' test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"' >> /root/.bashrc && \ - echo ' alias ls="ls --color=auto"' >> /root/.bashrc && \ - echo ' alias grep="grep --color=auto"' >> /root/.bashrc && \ - echo ' alias fgrep="fgrep --color=auto"' >> /root/.bashrc && \ - echo ' alias egrep="egrep --color=auto"' >> /root/.bashrc && \ - # Automatically change to workspace directory when opening a new terminal - echo ' cd /workspace' >> /root/.bashrc && \ - echo 'fi' >> /root/.bashrc \ - fi; - # Configure SSH server RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ echo "PubkeyAuthentication yes" >> /etc/ssh/sshd_config && \ @@ -58,14 +40,16 @@ CMD ["/bin/bash", "-c", "\ if [ ! -d /root/.ssh ]; then \ mkdir -p /root/.ssh && chmod 700 /root/.ssh; \ fi; \ - # Add public key for RunPod-Auth if not present \ - if [ -z \"$RUNPOD_SSH_PUBLIC_KEY\" ]; then \ - echo 'Please set the RUNPOD_SSH_PUBLIC_KEY environment variable.'; \ + \ + # Add public key for VM-Auth if not present \ + if [ -z \"$VM_SSH_PUBLIC_KEY\" ]; then \ + echo 'Please set the VM_SSH_PUBLIC_KEY environment variable.'; \ exit 1; \ fi; \ - if [ ! -f /root/.ssh/authorized_keys ] || ! grep -q \"$RUNPOD_SSH_PUBLIC_KEY\" /root/.ssh/authorized_keys; then \ - echo \"$RUNPOD_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys; \ + if [ ! -f /root/.ssh/authorized_keys ] || ! 
grep -q \"$VM_SSH_PUBLIC_KEY\" /root/.ssh/authorized_keys; then \
+        echo \"$VM_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys; \
     fi; \
+    \
     # Add private key for GitHub-Auth if not present \
     if [ -z \"$GITHUB_SSH_PRIVATE_KEY\" ]; then \
         echo 'Please set the GITHUB_SSH_PRIVATE_KEY environment variable.'; \
@@ -88,6 +72,7 @@ CMD ["/bin/bash", "-c", "\
         eval $(ssh-agent -s); \
     fi; \
     ssh-add -l | grep github_rsa > /dev/null || ssh-add /root/.ssh/github_rsa; \
+    \
     # Ensure first-time setup only runs once \
     if [ ! -f /workspace/.setup_done ]; then \
         echo 'Running first-time setup...'; \
@@ -111,6 +96,20 @@ CMD ["/bin/bash", "-c", "\
     else \
         echo 'Skipping first-time setup, already done.'; \
     fi; \
+    \
+    # Apply shell prompt customization \
+    if [ ! -z \"$USE_COMPACT_SHELL_PROMPT\" ]; then \
+        echo 'force_color_prompt=yes' >> /root/.bashrc; \
+        echo 'PS1=\"\[\033[01;34m\]\W\[\033[00m\]\$ \"' >> /root/.bashrc; \
+        echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc; \
+        echo '    test -r ~/.dircolors && eval \"$(dircolors -b ~/.dircolors)\" || eval \"$(dircolors -b)\"' >> /root/.bashrc; \
+        echo '    alias ls=\"ls --color=auto\"' >> /root/.bashrc; \
+        echo '    alias grep=\"grep --color=auto\"' >> /root/.bashrc; \
+        echo '    alias fgrep=\"fgrep --color=auto\"' >> /root/.bashrc; \
+        echo '    alias egrep=\"egrep --color=auto\"' >> /root/.bashrc; \
+        echo 'fi' >> /root/.bashrc; \
+    fi; \
+    \
     # Start SSH server \
     mkdir -p /run/sshd && chmod 755 /run/sshd; \
     /usr/sbin/sshd -D"]

From a49bd3ec0477f2cc3b2adfa6fad119e8d348ab3b Mon Sep 17 00:00:00 2001
From: Anton Zamyatin
Date: Tue, 4 Mar 2025 13:48:28 +0100
Subject: [PATCH 10/68] :books: Add documentation for using Docker image

---
 DOCKER.md | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 README.md |  6 +++++
 2 files changed, 75 insertions(+)
 create mode 100644 DOCKER.md

diff --git a/DOCKER.md b/DOCKER.md
new file mode 100644
index 00000000..1b53beb4
--- /dev/null
+++ b/DOCKER.md
@@ -0,0 +1,69 @@
+# :whale: Docker image for contributors
+### Pre-built Docker image
+1. To pull the pre-built image from Docker Hub, simply run:
+   ```bash
+   docker pull docker.io/tonyzamyatin/torch-uncertainty:latest
+   ```
+
+   This image includes:
+   - PyTorch with CUDA support
+   - OpenGL (for visualization tasks)
+   - Git, OpenSSH, and all Python dependencies
+
+   Check out the [registry on Docker Hub](https://hub.docker.com/repository/docker/tonyzamyatin/torch-uncertainty/general) for all available images.
+
+2. To start a container using this image, set up the necessary environment variables and run:
+   ```bash
+   docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \
+   -e VM_SSH_PUBLIC_KEY="your-public-key" \
+   -e GITHUB_SSH_PRIVATE_KEY="your-github-key" \
+   -e GITHUB_USER="your-github-username" \
+   -e GIT_USER_EMAIL="your-git-email" \
+   -e GIT_USER_NAME="your-git-name" \
+   docker.io/tonyzamyatin/torch-uncertainty
+   ```
+
+   Optionally, you can also set `-e USE_COMPACT_SHELL_PROMPT="true"`
+   to make the VM's shell prompts compact and colorized.
+
+   **Note:** Some cloud providers offer templates, in which you can preconfigure
+   in advance which Docker image to pull and which environment variables to set.
+   In this case, the provider will pull the image, set all environment variables,
+   and start the container for you.
+
+3. Once your cloud provider has deployed the VM, it will display the host address and SSH port.
+   You can connect to the container via SSH using:
+   ```bash
+   ssh -i /path/to/private_key root@<host> -p <port>
+   ```
+
+   Replace `<host>` and `<port>` with the values provided by your cloud provider,
+   and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`.
+
+4. The container exposes port `8888` in case you want to run Jupyter Notebooks or TensorBoard.
+
+   **Note:** The `/workspace` directory is mounted from your local machine or cloud storage,
+   so changes persist across container restarts.
+   If using a cloud provider, ensure your network volume is correctly attached to avoid losing data.
+
+### Modifying and publishing custom Docker image
+If you want to make changes to the Dockerfile, follow these steps:
+1. Edit the Dockerfile to fit your needs.
+
+2. Build the modified image:
+   ```
+   docker build -t my-custom-image .
+   ```
+
+3. Push to a Docker registry (if you want to use it on another VM):
+   ```
+   docker tag my-custom-image mydockerhubuser/my-custom-image:tag
+   docker push mydockerhubuser/my-custom-image:tag
+   ```
+
+4. Pull the custom image onto your VM:
+   ```
+   docker pull mydockerhubuser/my-custom-image
+   ```
+
+5. Run the container using the same docker run command with the new image name.
diff --git a/README.md b/README.md
index 52b0ed8e..aca06bf3 100644
--- a/README.md
+++ b/README.md
@@ -95,3 +95,9 @@ The following projects use TorchUncertainty:
 - *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287)

 **If you are using TorchUncertainty in your project, please let us know, we will add your project to this list!**
+
+## :whale: Docker image for contributors
+For contributors who want to run experiments on cloud GPU instances, we provide a pre-built Docker image that includes all necessary dependencies and configurations and the Dockerfile for building your custom Docker images.
+This allows you to quickly launch an experiment-ready container with minimal setup.
+
+Please refer to [DOCKER.md](DOCKER.md) for further details.

From bb79e0bf1071e73469ddde403f2f774cef023542 Mon Sep 17 00:00:00 2001
From: Anton Zamyatin
Date: Tue, 4 Mar 2025 18:39:24 +0100
Subject: [PATCH 11/68] :whale: Install dependencies in editable mode

---
 Dockerfile | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 6bfd1f74..255291d1 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -20,8 +20,8 @@ RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__i
 # Copy dependency file
 COPY pyproject.toml /workspace/

-# Install dependencies
-RUN pip install --no-cache-dir ".[all]"
+# Install dependencies (in editable mode!)
+RUN pip install --no-cache-dir "-e .[all]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc @@ -109,7 +109,6 @@ CMD ["/bin/bash", "-c", "\ echo ' alias egrep=\"egrep --color=auto\"' >> /root/.bashrc; \ echo 'fi' >> /root/.bashrc; \ fi; \ - \ # Start SSH server \ mkdir -p /run/sshd && chmod 755 /run/sshd; \ /usr/sbin/sshd -D"] From eaa09d641b775810ab66bfb413b068823e218279 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Tue, 4 Mar 2025 22:16:24 +0100 Subject: [PATCH 12/68] :bug: Fix pip install in editable mode --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 255291d1..db632e3a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,7 +21,7 @@ RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__i COPY pyproject.toml /workspace/ # Install dependencies (in editable mode!) -RUN pip install --no-cache-dir "-e .[all]" +RUN pip install --no-cache-dir -e ".[all]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc From 2522a31c04c73158118f15a884c5d9a2d12f30b2 Mon Sep 17 00:00:00 2001 From: Anton Date: Tue, 4 Mar 2025 22:48:22 +0000 Subject: [PATCH 13/68] :bug: Fix MNIST test dataloader for shifted data --- torch_uncertainty/datamodules/classification/mnist.py | 4 +++- torch_uncertainty/datasets/classification/mnist_c.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index a49fc168..d50421bf 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,10 +171,12 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data) and FashionMNIST test split + distribution data), MNISTC (shifted data), and FashionMNIST test split (out-of-distribution data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: dataloader.append(self._data_loader(self.ood)) + if self.eval_shift: + dataloader.append(self._data_loader(self.shift)) return dataloader diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index 6d8086cb..aeffaa29 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -23,6 +23,7 @@ class MNISTC(VisionDataset): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``mnistc_subsets``. + shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. @@ -70,6 +71,7 @@ def __init__( target_transform: Callable | None = None, split: Literal["train", "test"] = "test", subset: str = "all", + shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -90,6 +92,12 @@ def __init__( raise ValueError(f"The subset '{subset}' does not exist in MNIST-C.") self.subset = subset + self.shift_severity = shift_severity + if shift_severity not in list(range(1, 6)): + raise ValueError( + "Corruptions shift_severity should be chosen between 1 and 5 " "included." 
+ ) + if split not in ["train", "test"]: raise ValueError(f"The split '{split}' should be either 'train' or 'test'.") self.split = split From ae0405ad999bbb94c68c483430740dbf073f1a7a Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Wed, 5 Mar 2025 01:22:12 +0100 Subject: [PATCH 14/68] :hammer: Fix and format container start script and move it to entrypoint.sh --- Dockerfile | 87 +++-------------------- README.md | 2 +- DOCKER.md => docker/DOCKER.md | 0 docker/entrypoint.sh | 127 ++++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 80 deletions(-) rename DOCKER.md => docker/DOCKER.md (100%) create mode 100644 docker/entrypoint.sh diff --git a/Dockerfile b/Dockerfile index db632e3a..cc89a22b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,8 +20,8 @@ RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__i # Copy dependency file COPY pyproject.toml /workspace/ -# Install dependencies (in editable mode!) -RUN pip install --no-cache-dir -e ".[all]" +# Install dependencies all dependencies +RUN pip install --no-cache-dir -e ".[dev]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc @@ -35,80 +35,9 @@ RUN echo "PermitRootLogin yes" >> /etc/ssh/sshd_config && \ EXPOSE 8888 22 # Entrypoint script (runs every time the container starts) -CMD ["/bin/bash", "-c", "\ - # Create SSH directory and set permissions if not present \ - if [ ! -d /root/.ssh ]; then \ - mkdir -p /root/.ssh && chmod 700 /root/.ssh; \ - fi; \ - \ - # Add public key for VM-Auth if not present \ - if [ -z \"$VM_SSH_PUBLIC_KEY\" ]; then \ - echo 'Please set the VM_SSH_PUBLIC_KEY environment variable.'; \ - exit 1; \ - fi; \ - if [ ! -f /root/.ssh/authorized_keys ] || ! grep -q \"$VM_SSH_PUBLIC_KEY\" /root/.ssh/authorized_keys; then \ - echo \"$VM_SSH_PUBLIC_KEY\" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys; \ - fi; \ - \ - # Add private key for GitHub-Auth if not present \ - if [ -z \"$GITHUB_SSH_PRIVATE_KEY\" ]; then \ - echo 'Please set the GITHUB_SSH_PRIVATE_KEY environment variable.'; \ - exit 1; \ - fi; \ - if [ ! -f /root/.ssh/github_rsa ]; then \ - echo \"$GITHUB_SSH_PRIVATE_KEY\" > /root/.ssh/github_rsa && chmod 600 /root/.ssh/github_rsa; \ - fi; \ - # Add GitHub credentials to SSH config if not present \ - if [ ! -f /root/.ssh/config ] || ! grep -q 'Host github.com' /root/.ssh/config; then \ - echo 'Host github.com' > /root/.ssh/config; \ - echo ' User git' >> /root/.ssh/config; \ - echo ' IdentityFile /root/.ssh/github_rsa' >> /root/.ssh/config; \ - chmod 600 /root/.ssh/config; \ - fi; \ - # Add GitHub to known hosts if not already added \ - ssh-keygen -F github.com > /dev/null 2>&1 || ssh-keyscan github.com >> /root/.ssh/known_hosts; \ - # Start SSH agent if not running and add GitHub private key \ - if ! pgrep -x \"ssh-agent\" > /dev/null; then \ - eval $(ssh-agent -s); \ - fi; \ - ssh-add -l | grep github_rsa > /dev/null || ssh-add /root/.ssh/github_rsa; \ - \ - # Ensure first-time setup only runs once \ - if [ ! -f /workspace/.setup_done ]; then \ - echo 'Running first-time setup...'; \ - # Clone GitHub repo if not already cloned \ - if [ -z \"$GITHUB_USER\" ]; then \ - echo 'Please set the GITHUB_USER environment variable.'; \ - exit 1; \ - fi; \ - if [ ! -d \"/workspace/.git\" ]; then \ - git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace; \ - fi; \ - # Set Git user name and email if provided \ - if [ ! 
-z \"$GIT_USER_NAME\" ]; then \ - git config --global user.name \"$GIT_USER_NAME\"; \ - fi; \ - if [ ! -z \"$GIT_USER_EMAIL\" ]; then \ - git config --global user.email \"$GIT_USER_EMAIL\"; \ - fi; \ - # Mark first-time setup as done \ - touch /workspace/.setup_done; \ - else \ - echo 'Skipping first-time setup, already done.'; \ - fi; \ - \ - # Apply shell prompt customization \ - if [ ! -z \"$USE_COMPACT_SHELL_PROMPT\" ]; then \ - echo 'force_color_prompt=yes' >> /root/.bashrc; \ - echo 'PS1=\"\[\033[01;34m\]\W\[\033[00m\]\$ \"' >> /root/.bashrc; \ - echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc; \ - echo ' test -r ~/.dircolors && eval \"$(dircolors -b ~/.dircolors)\" || eval \"$(dircolors -b)\"' >> /root/.bashrc; \ - echo ' alias ls=\"ls --color=auto\"' >> /root/.bashrc; \ - echo ' alias grep=\"grep --color=auto\"' >> /root/.bashrc; \ - echo ' alias fgrep=\"fgrep --color=auto\"' >> /root/.bashrc; \ - echo ' alias egrep=\"egrep --color=auto\"' >> /root/.bashrc; \ - echo 'fi' >> /root/.bashrc; \ - fi; \ - # Start SSH server \ - mkdir -p /run/sshd && chmod 755 /run/sshd; \ - /usr/sbin/sshd -D"] +COPY docker/entrypoint.sh /usr/local/bin/ +RUN chmod +x /usr/local/bin/entrypoint.sh +ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] + +# Note that if /workspace/ is a mounted volume, any files copied to /workspace/ during the build will be overwritten by the mounted volume +# This is why we copy the entrypoint script to /usr/local/bin/ instead of /workspace/ diff --git a/README.md b/README.md index aca06bf3..9130926e 100644 --- a/README.md +++ b/README.md @@ -100,4 +100,4 @@ The following projects use TorchUncertainty: For contributors who want to run experiments on cloud GPU instances, we provide a pre-built Docker image that includes all necessary dependencies and configurations and the Dockerfile for building your custom Docker images. This allows you to quickly launch an experiment-ready container with minimal setup. -Please refer to [DOCKER.md](DOCKER.md) for further details. +Please refer to [DOCKER.md](docker/DOCKER.md) for further details. diff --git a/DOCKER.md b/docker/DOCKER.md similarity index 100% rename from DOCKER.md rename to docker/DOCKER.md diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 00000000..ffef3e62 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,127 @@ +#!/bin/bash +set -e # Exit immediately if a command fails + +echo "🚀 Starting container..." + +# Ensure SSH directory exists and has correct permissions +if [ ! -d /root/.ssh ]; then + echo "📂 Creating SSH directory..." + mkdir -p /root/.ssh && chmod 700 /root/.ssh +fi + +# Ensure the VM's public SSH key is added for authentication +if [ -z "$VM_SSH_PUBLIC_KEY" ]; then + echo "❌ Error: Please set the VM_SSH_PUBLIC_KEY environment variable." + exit 1 +fi +if [ ! -f /root/.ssh/authorized_keys ] || ! grep -q "$VM_SSH_PUBLIC_KEY" /root/.ssh/authorized_keys; then + echo "🔑 Adding VM SSH public key..." + echo "$VM_SSH_PUBLIC_KEY" > /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys +fi + +# Ensure GitHub SSH private key is set up for authentication +if [ -z "$GITHUB_SSH_PRIVATE_KEY" ]; then + echo "❌ Error: Please set the GITHUB_SSH_PRIVATE_KEY environment variable." + exit 1 +fi +if [ ! -f /root/.ssh/github_rsa ]; then + echo "🔐 Adding GitHub SSH private key..." + echo "$GITHUB_SSH_PRIVATE_KEY" > /root/.ssh/github_rsa && chmod 600 /root/.ssh/github_rsa +fi + +# Configure SSH client for GitHub authentication +if [ ! -f /root/.ssh/config ] || ! 
grep -q 'Host github.com' /root/.ssh/config; then
+    echo "⚙️ Configuring SSH client for GitHub authentication..."
+    cat <<EOF > /root/.ssh/config
+Host github.com
+    User git
+    IdentityFile /root/.ssh/github_rsa
+EOF
+    chmod 600 /root/.ssh/config
+fi
+
+# Add GitHub to known hosts (to avoid SSH verification prompts)
+echo "📌 Ensuring GitHub is a known host..."
+ssh-keygen -F github.com > /dev/null 2>&1 || ssh-keyscan github.com >> /root/.ssh/known_hosts
+
+# Start SSH agent and add GitHub private key (if not already added)
+if ! pgrep -x "ssh-agent" > /dev/null; then
+    echo "🕵️ Starting SSH agent..."
+    eval "$(ssh-agent -s)"
+fi
+if ssh-add -l | grep -q github_rsa; then
+    echo "✅ GitHub SSH key already added."
+else
+    echo "🔑 Adding GitHub SSH key to agent..."
+    ssh-add /root/.ssh/github_rsa
+fi
+
+# Set Git user name and email (if provided)
+if [ -n "$GIT_USER_NAME" ]; then
+    echo "👤 Setting Git username: $GIT_USER_NAME"
+    git config --global user.name "$GIT_USER_NAME"
+fi
+if [ -n "$GIT_USER_EMAIL" ]; then
+    echo "📧 Setting Git email: $GIT_USER_EMAIL"
+    git config --global user.email "$GIT_USER_EMAIL"
+fi
+
+# Ensure first-time setup runs only once
+if [ ! -f /workspace/.setup_done ]; then
+    echo "🛠️ Running first-time setup..."
+
+    # Ensure GitHub username is set
+    if [ -z "$GITHUB_USER" ]; then
+        echo "❌ Error: Please set the GITHUB_USER environment variable."
+        exit 1
+    fi
+
+    # Clone GitHub repo if not already cloned
+    if [ ! -d "/workspace/.git" ]; then
+        echo "📦 Cloning repository: $GITHUB_USER/torch-uncertainty..."
+        git clone git@github.com:$GITHUB_USER/torch-uncertainty.git /workspace
+    fi
+
+    # Mark setup as completed
+    touch /workspace/.setup_done
+    echo "✅ First-time setup complete!"
+else
+    echo "⏩ Skipping first-time setup (already done)."
+fi
+
+# Apply compact shell prompt customization (if enabled)
+if [ -n "$USE_COMPACT_SHELL_PROMPT" ]; then
+    echo "🎨 Applying compact shell prompt customization..."
+    echo 'force_color_prompt=yes' >> /root/.bashrc
+    echo 'PS1="\[\033[01;34m\]\W\[\033[00m\]\$ "' >> /root/.bashrc
+    echo 'if [ -x /usr/bin/dircolors ]; then' >> /root/.bashrc
+    echo '    test -r ~/.dircolors && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"' >> /root/.bashrc
+    echo '    alias ls="ls --color=auto"' >> /root/.bashrc
+    echo '    alias grep="grep --color=auto"' >> /root/.bashrc
+    echo '    alias fgrep="fgrep --color=auto"' >> /root/.bashrc
+    echo '    alias egrep="egrep --color=auto"' >> /root/.bashrc
+    echo 'fi' >> /root/.bashrc
+fi
+
+# Ensure /workspace is in PYTHONPATH
+if ! echo "$PYTHONPATH" | grep -q "/workspace"; then
+    echo "📌 Adding /workspace to PYTHONPATH"
+    export PYTHONPATH="/workspace:$PYTHONPATH"
+else
+    echo "✅ PYTHONPATH is already correctly set."
+fi
+
+# Check if torch_uncertainty is installed in editable mode
+if pip show torch_uncertainty | grep -q "Editable project location: /workspace"; then
+    echo "✅ torch_uncertainty is already installed in editable mode. 🎉"
+else
+    echo "🔄 Reinstalling torch_uncertainty in editable mode..."
+    pip uninstall -y torch-uncertainty
+    pip install -e /workspace
+    echo "✅ torch_uncertainty is now installed in editable mode! 🚀"
+fi
+
+# Ensure SSH server is started
+echo "🔑 Starting SSH server..."
+mkdir -p /run/sshd && chmod 755 /run/sshd +/usr/sbin/sshd -D From 5d1d27f79f34118028bb117d1375e3bfc4ba70ac Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 7 Mar 2025 00:17:38 +0100 Subject: [PATCH 15/68] :bug: Fix `RandomRescale` --- torch_uncertainty/transforms/image.py | 47 ++++++++++++++++----------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index b4695489..56941b9f 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -240,25 +240,32 @@ def forward(self, img: Tensor | Image.Image, level: float) -> Tensor | Image.Ima class RandomRescale(Transform): - """Randomly rescale the input. + def __init__( + self, + min_scale: int, + max_scale: int, + interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, + antialias: bool | None = True, + ) -> None: + """Randomly rescale the input. - This transformation can be used together with ``RandomCrop`` as data augmentations to train - models on image segmentation task. + This transformation can be used together with ``RandomCrop`` as data augmentations to train + models on image segmentation task. - Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: + Output spatial size is randomly sampled from the interval ``[min_size, max_size]``: - .. code-block:: python + .. code-block:: python scale = uniform_sample(min_scale, max_scale) output_width = input_width * scale output_height = input_height * scale - If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, - :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) - it can have arbitrary number of leading batch dimensions. For example, - the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. + If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, + :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) + it can have arbitrary number of leading batch dimensions. For example, + the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. - Args: + Args: min_scale (int): Minimum scale for random sampling max_scale (int): Maximum scale for random sampling interpolation (InterpolationMode, optional): Desired interpolation enum defined by @@ -284,27 +291,25 @@ class RandomRescale(Transform): The default value changed from ``None`` to ``True`` in v0.17, for the PIL and Tensor backends to be consistent. 
- """ - - def __init__( - self, - min_scale: int, - max_scale: int, - interpolation: InterpolationMode | int = InterpolationMode.BILINEAR, - antialias: bool | None = True, - ) -> None: + """ super().__init__() self.min_scale = min_scale self.max_scale = max_scale self.interpolation = interpolation self.antialias = antialias + # Compatibility with torchvision < 0.21 def _get_params(self, flat_inputs: list[Any]) -> dict[str, Any]: height, width = query_size(flat_inputs) scale = torch.rand(1) scale = self.min_scale + scale * (self.max_scale - self.min_scale) return {"size": (int(height * scale), int(width * scale))} + # Compatibility with torchvision >= 0.21 + def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]: + return self._get_params(flat_inputs) + + # Compatibility with torchvision < 0.21 def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel( F.resize, @@ -313,3 +318,7 @@ def _transform(self, inpt: Any, params: dict[str, Any]) -> Any: interpolation=self.interpolation, antialias=self.antialias, ) + + # Compatibility with torchvision >= 0.21 + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + return self._transform(inpt, params) From 9624b2cd0d676c6ceb03a9079d5ac806c964b48f Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 10:39:25 +0100 Subject: [PATCH 16/68] :shirt: Add scaler exception --- torch_uncertainty/post_processing/calibration/scaler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 9a1a9ef6..a3e90eed 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -60,6 +60,11 @@ def fit( progress (bool, optional): Whether to show a progress bar. Defaults to True. """ + if self.model is None: + raise ValueError( + "Cannot fit a Scaler method without model. Call .set_model(model) first." + ) + logits_list = [] labels_list = [] calibration_dl = DataLoader(calibration_set, batch_size=32, shuffle=False, drop_last=False) From 748f698a9bc73674d780a9dbaf3acaf5cb60187f Mon Sep 17 00:00:00 2001 From: Olivier Laurent <62881275+o-laurent@users.noreply.github.com> Date: Fri, 7 Mar 2025 11:03:46 +0100 Subject: [PATCH 17/68] :shirt: Move Docker in the installation section --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 9130926e..e339c989 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,11 @@ pip install torch-uncertainty The installation procedure for contributors is different: have a look at the [contribution page](https://torch-uncertainty.github.io/contributing.html). +### :whale: Docker image for contributors + +For contributors who want to run experiments on cloud GPU instances, we provide a pre-built Docker image that includes all necessary dependencies and configurations and the Dockerfile for building your custom Docker images. +This allows you to quickly launch an experiment-ready container with minimal setup. Please refer to [DOCKER.md](docker/DOCKER.md) for further details. + ## :racehorse: Quickstart We make a quickstart available at [torch-uncertainty.github.io/quickstart](https://torch-uncertainty.github.io/quickstart.html). 
@@ -94,10 +99,5 @@ The following projects use TorchUncertainty:

 - *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287)

-**If you are using TorchUncertainty in your project, please let us know, we will add your project to this list!**
-
-## :whale: Docker image for contributors
-For contributors who want to run experiments on cloud GPU instances, we provide a pre-built Docker image that includes all necessary dependencies and configurations and the Dockerfile for building your custom Docker images.
-This allows you to quickly launch an experiment-ready container with minimal setup.
+**If you are using TorchUncertainty in your project, please let us know, and we will add your project to this list!**

-Please refer to [DOCKER.md](docker/DOCKER.md) for further details.

From 3b6c7c62b1116efbf8c0f463f70537b7e0a75bcc Mon Sep 17 00:00:00 2001
From: Olivier
Date: Fri, 7 Mar 2025 11:55:15 +0100
Subject: [PATCH 18/68] :bug: Fast track config fixes

Co-authored-by: Anton
---
 .gitignore                                                      | 1 +
 experiments/classification/mnist/configs/bayesian_lenet.yaml    | 1 -
 experiments/classification/mnist/configs/lenet.yaml             | 1 -
 .../classification/mnist/configs/lenet_checkpoint_ensemble.yaml | 1 -
 experiments/classification/mnist/configs/lenet_ema.yaml         | 1 -
 experiments/classification/mnist/configs/lenet_swa.yaml         | 1 -
 experiments/classification/mnist/configs/lenet_swag.yaml        | 1 -
 7 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9ad40cf4..ca0ee960 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 # Custom
 .vscode/
+.idea/
 data/
 logs/
 lightning_logs/
diff --git a/experiments/classification/mnist/configs/bayesian_lenet.yaml b/experiments/classification/mnist/configs/bayesian_lenet.yaml
index 55f6b3c6..70f5cf8e 100644
--- a/experiments/classification/mnist/configs/bayesian_lenet.yaml
+++ b/experiments/classification/mnist/configs/bayesian_lenet.yaml
@@ -41,7 +41,6 @@ model:
       norm: torch.nn.Identity
       groups: 1
       dropout_rate: 0
-      last_layer_dropout: false
       layer_args: {}
     num_samples: 16
     num_classes: 10
diff --git a/experiments/classification/mnist/configs/lenet.yaml b/experiments/classification/mnist/configs/lenet.yaml
index 0c7989ab..3f8b63c2 100644
--- a/experiments/classification/mnist/configs/lenet.yaml
+++ b/experiments/classification/mnist/configs/lenet.yaml
@@ -38,7 +38,6 @@ model:
     norm: torch.nn.Identity
     groups: 1
     dropout_rate: 0
-    last_layer_dropout: false
    layer_args: {}
   num_classes: 10
   loss: CrossEntropyLoss
diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
index c5398a87..c387ef47 100644
--- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
+++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
@@ -41,7 +41,6 @@ model:
     norm: torch.nn.Identity
     groups: 1
     dropout_rate: 0
-    last_layer_dropout: false
     layer_args: {}
   save_schedule:
   - 20
diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml
index 363461c6..2ea72001 100644
--- a/experiments/classification/mnist/configs/lenet_ema.yaml
+++ b/experiments/classification/mnist/configs/lenet_ema.yaml
@@ -41,7 +41,6 @@ model:
     norm: torch.nn.Identity
     groups: 1
     dropout_rate: 0
-    last_layer_dropout: false
     layer_args: {}
   momentum: 0.99
   num_classes: 10
diff --git a/experiments/classification/mnist/configs/lenet_swa.yaml 
b/experiments/classification/mnist/configs/lenet_swa.yaml index 2274bdb5..09d7d506 100644 --- a/experiments/classification/mnist/configs/lenet_swa.yaml +++ b/experiments/classification/mnist/configs/lenet_swa.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 19 cycle_length: 5 diff --git a/experiments/classification/mnist/configs/lenet_swag.yaml b/experiments/classification/mnist/configs/lenet_swag.yaml index ddff0067..e33d954f 100644 --- a/experiments/classification/mnist/configs/lenet_swag.yaml +++ b/experiments/classification/mnist/configs/lenet_swag.yaml @@ -41,7 +41,6 @@ model: norm: torch.nn.Identity groups: 1 dropout_rate: 0 - last_layer_dropout: false layer_args: {} cycle_start: 10 cycle_length: 5 From 11801da484c0dc3651f7dfb740e5df2c52a88afc Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:06:27 +0100 Subject: [PATCH 19/68] :fire: Remove shift-severity in MNISTC --- torch_uncertainty/datasets/classification/mnist_c.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index aeffaa29..e61071d8 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -71,7 +71,6 @@ def __init__( target_transform: Callable | None = None, split: Literal["train", "test"] = "test", subset: str = "all", - shift_severity: int = 1, download: bool = False, ) -> None: self.root = Path(root) @@ -92,12 +91,6 @@ def __init__( raise ValueError(f"The subset '{subset}' does not exist in MNIST-C.") self.subset = subset - self.shift_severity = shift_severity - if shift_severity not in list(range(1, 6)): - raise ValueError( - "Corruptions shift_severity should be chosen between 1 and 5 " "included." - ) - if split not in ["train", "test"]: raise ValueError(f"The split '{split}' should be either 'train' or 'test'.") self.split = split From 16cfb44d8286f95921869f7f3b63bf3a6ac185c0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:06:55 +0100 Subject: [PATCH 20/68] :book: Update test dataloaders docstring --- torch_uncertainty/datamodules/classification/cifar10.py | 3 +-- torch_uncertainty/datamodules/classification/cifar100.py | 4 ++-- torch_uncertainty/datamodules/classification/imagenet.py | 4 ++-- torch_uncertainty/datamodules/classification/mnist.py | 4 ++-- torch_uncertainty/datamodules/classification/tiny_imagenet.py | 4 ++-- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_uncertainty/datamodules/classification/cifar10.py b/torch_uncertainty/datamodules/classification/cifar10.py index 3ee9c8a3..1e1441c2 100644 --- a/torch_uncertainty/datamodules/classification/cifar10.py +++ b/torch_uncertainty/datamodules/classification/cifar10.py @@ -228,8 +228,7 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or CIFAR-10C data. 
""" dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/cifar100.py b/torch_uncertainty/datamodules/classification/cifar100.py index 2717cc12..11d0a7fa 100644 --- a/torch_uncertainty/datamodules/classification/cifar100.py +++ b/torch_uncertainty/datamodules/classification/cifar100.py @@ -210,8 +210,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, SVHN data, and/or + CIFAR-100C data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index d436fd16..24fd5c6d 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -289,8 +289,8 @@ def test_dataloader(self) -> list[DataLoader]: """Get the test dataloaders for ImageNet. Return: - list[DataLoader]: ImageNet test set (in distribution data) and - Textures test split (out-of-distribution data). + list[DataLoader]: ImageNet test set (in distribution data), OOD dataset test split + (out-of-distribution data), and/or ImageNetC data. """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/mnist.py b/torch_uncertainty/datamodules/classification/mnist.py index d50421bf..f6879c1a 100644 --- a/torch_uncertainty/datamodules/classification/mnist.py +++ b/torch_uncertainty/datamodules/classification/mnist.py @@ -171,8 +171,8 @@ def test_dataloader(self) -> list[DataLoader]: Return: list[DataLoader]: Dataloaders of the MNIST test set (in - distribution data), MNISTC (shifted data), and FashionMNIST test split - (out-of-distribution data). + distribution data), FashionMNIST or NotMNIST test split + (out-of-distribution data), and/or MNISTC (shifted data). """ dataloader = [self._data_loader(self.test)] if self.eval_ood: diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 3c3c7ec4..bf95159e 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -237,8 +237,8 @@ def test_dataloader(self) -> list[DataLoader]: r"""Get test dataloaders for TinyImageNet. Return: - list[DataLoader]: test set for in distribution data - and out-of-distribution data. + list[DataLoader]: test set for in distribution data, OOD data, and/or + TinyImageNetC data. 
""" dataloader = [self._data_loader(self.test)] if self.eval_ood: From 9d761f1c699486221ab88da369bf430ef2fd8831 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 12:08:09 +0100 Subject: [PATCH 21/68] :hammer: use datamodule's shift-severity --- torch_uncertainty/datamodules/abstract.py | 2 ++ torch_uncertainty/routines/classification.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 0dc88033..40168429 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -15,6 +15,8 @@ class TUDataModule(ABC, LightningDataModule): val: Dataset test: Dataset + shift_severity = 1 + def __init__( self, root: str | Path, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index bb616c58..8487f93f 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -606,9 +606,7 @@ def on_test_epoch_end(self) -> None: if self.eval_shift: tmp_metrics = self.test_shift_metrics.compute() - shift_severity = self.trainer.test_dataloaders[ - 2 if self.eval_ood else 1 - ].dataset.shift_severity + shift_severity = self.trainer.datamodule.shift_severity tmp_metrics["shift/shift_severity"] = shift_severity self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) From 3a186b835b554bf431cfd46462b4dd3ae970800a Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 13:05:52 +0100 Subject: [PATCH 22/68] :white_check_mark: Fix coverage --- tests/datamodules/classification/test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index 7a967edc..73304283 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -47,6 +47,7 @@ def test_mnist_cutout(self): dm.setup("other") dm.eval_ood = True + dm.eval_shift = True dm.ood_transform = dm.test_transform dm.val_split = 0.1 dm.prepare_data() From d4d09f5806f59ac713da4e2387a65c4a13505688 Mon Sep 17 00:00:00 2001 From: Olivier Laurent <62881275+o-laurent@users.noreply.github.com> Date: Fri, 7 Mar 2025 13:16:14 +0100 Subject: [PATCH 23/68] :fire: Remove documentation for removed argument --- torch_uncertainty/datasets/classification/mnist_c.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index e61071d8..6d8086cb 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -23,7 +23,6 @@ class MNISTC(VisionDataset): takes in the target and transforms it. Defaults to None. subset (str): The subset to use, one of ``all`` or the keys in ``mnistc_subsets``. - shift_severity (int): The shift_severity of the corruption, between 1 and 5. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. Defaults to False. 
From 6df09149ac69ab3277feadb902f950f1a82b6610 Mon Sep 17 00:00:00 2001
From: Anton
Date: Fri, 7 Mar 2025 11:43:49 +0000
Subject: [PATCH 24/68] :bug: Add missing learning rate scheduler class paths
 to lenet configs

---
 .../mnist/configs/lenet_checkpoint_ensemble.yaml | 10 ++++++----
 .../classification/mnist/configs/lenet_ema.yaml  | 10 ++++++----
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
index c5398a87..10ea9fe8 100644
--- a/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
+++ b/experiments/classification/mnist/configs/lenet_checkpoint_ensemble.yaml
@@ -67,7 +67,9 @@ optimizer:
   weight_decay: 5e-4
   nesterov: true
 lr_scheduler:
-  milestones:
-    - 25
-    - 50
-  gamma: 0.1
+  class_path: torch.optim.lr_scheduler.MultiStepLR
+  init_args:
+    milestones:
+      - 25
+      - 50
+    gamma: 0.1
diff --git a/experiments/classification/mnist/configs/lenet_ema.yaml b/experiments/classification/mnist/configs/lenet_ema.yaml
index 363461c6..713bf607 100644
--- a/experiments/classification/mnist/configs/lenet_ema.yaml
+++ b/experiments/classification/mnist/configs/lenet_ema.yaml
@@ -55,7 +55,9 @@ optimizer:
   weight_decay: 5e-4
   nesterov: true
 lr_scheduler:
-  milestones:
-    - 25
-    - 50
-  gamma: 0.1
+  class_path: torch.optim.lr_scheduler.MultiStepLR
+  init_args:
+    milestones:
+      - 25
+      - 50
+    gamma: 0.1

From 56ab6527239de90d3ab995dc7663766546642079 Mon Sep 17 00:00:00 2001
From: Anton
Date: Fri, 7 Mar 2025 12:05:42 +0000
Subject: [PATCH 25/68] :sparkles: Add BatchEnsemble wrapper

---
 .../models/wrappers/batch_ensemble.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 torch_uncertainty/models/wrappers/batch_ensemble.py

diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py
new file mode 100644
index 00000000..d62adcb8
--- /dev/null
+++ b/torch_uncertainty/models/wrappers/batch_ensemble.py
@@ -0,0 +1,34 @@
+import torch
+from torch import nn
+
+class BatchEnsemble(nn.Module):
+    """Wraps a BatchEnsemble model to ensure correct batch replication.
+
+    In a BatchEnsemble architecture, each estimator operates on a **sub-batch**
+    of the input. This means that the input batch must be **repeated**
+    :attr:`num_estimators` times before being processed.
+
+    This wrapper automatically **duplicates the input batch** along the first axis,
+    ensuring that each estimator receives the correct data format.
+
+    **Usage Example:**
+    ```python
+    model = lenet(in_channels=1, num_classes=10)
+    wrapped_model = BatchEnsemble(model, num_estimators=5)
+    logits = wrapped_model(x)  # `x` is automatically repeated `num_estimators` times
+    ```
+
+    Args:
+        model (nn.Module): The BatchEnsemble model.
+        num_estimators (int): Number of ensemble members.
+    """
+
+    def __init__(self, model: nn.Module, num_estimators: int):
+        super().__init__()
+        self.model = model
+        self.num_estimators = num_estimators
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Repeats the input batch and passes it through the model."""
+        x = x.repeat(self.num_estimators, 1, 1, 1)
+        return self.model(x)

From 918de33527e9fa4fff9c1897d5a780d68c773c20 Mon Sep 17 00:00:00 2001
From: Anton
Date: Fri, 7 Mar 2025 12:06:58 +0000
Subject: [PATCH 26/68] :books: Update documentation regarding (batch)
 ensemble usage

---
 torch_uncertainty/layers/batch_ensemble.py   | 14 ++++++++++++--
 torch_uncertainty/routines/classification.py | 10 +++++++++-
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py
index dde72139..3f15ac21 100644
--- a/torch_uncertainty/layers/batch_ensemble.py
+++ b/torch_uncertainty/layers/batch_ensemble.py
@@ -79,7 +79,13 @@ def __init__(
         :math:`H_{out} = \text{out_features}`.
 
     Warning:
-        Make sure that :attr:`num_estimators` divides :attr:`out_features` when calling :func:`forward()`.
+        Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`.
+        In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators`
+        times along the first axis. Incorrect batch size may lead to unexpected results.
+
+        To simplify batch handling, wrap your model with `BatchEnsemble`, which automatically
+        repeats the batch before passing it through the network. See `BatchEnsemble` for details.
+
 
     Examples:
         >>> # With three estimators
@@ -273,8 +279,12 @@ def __init__(
                   {\text{stride}[1]} + 1\right\rfloor
 
     Warning:
-        Make sure that :attr:`num_estimators` divides :attr:`out_channels` when calling :func:`forward()`.
+        Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`.
+        In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators`
+        times along the first axis. Incorrect batch size may lead to unexpected results.
 
+        To simplify batch handling, wrap your model with `BatchEnsemble`, which automatically
+        repeats the batch before passing it through the network. See `BatchEnsemble` for details.
 
     Examples:
         >>> # With square kernels, four estimators and equal stride
diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index bb616c58..82189fbf 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -123,6 +123,14 @@ def __init__(
         Warning:
             You must define :attr:`optim_recipe` if you do not use the Lightning CLI.
 
+        Warning:
+            When using an ensemble model, you must:
+            1. Set :attr:`is_ensemble` to ``True``.
+            2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`.
+            3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes.
+
+            For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`.
+
         Note:
             :attr:`optim_recipe` can be anything that can be returned by
             :meth:`LightningModule.configure_optimizers()`.
Find more details @@ -475,7 +483,7 @@ def test_step( """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) - logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) + logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) probs_per_est = torch.sigmoid(logits) if self.binary_cls else F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) confs = probs.max(-1)[0] From c5b62d8bc9e6c4b652b993daa21b27a7db7fc88d Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 12:09:32 +0000 Subject: [PATCH 27/68] :sparkles: Add LeNet BatchEnsemble and Deep Ensemble --- .../mnist/configs/lenet_batch_ensemble.yaml | 67 ++++++++++++++++ .../mnist/configs/lenet_deep_ensemble.yaml | 78 +++++++++++++++++++ torch_uncertainty/models/lenet.py | 28 +++++++ 3 files changed, 173 insertions(+) create mode 100644 experiments/classification/mnist/configs/lenet_batch_ensemble.yaml create mode 100644 experiments/classification/mnist/configs/lenet_deep_ensemble.yaml diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml new file mode 100644 index 00000000..0d2a94ee --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -0,0 +1,67 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: batch_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # BatchEnsemble + class_path: torch_uncertainty.models.lenet.batchensemble_lenet + init_args: + in_channels: 1 + num_classes: 10 + num_estimators: 5 + activation: torch.nn.ReLU + norm: torch.nn.BatchNorm2d + groups: 1 + dropout_rate: 0 + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml new file mode 100644 index 00000000..1d47b782 --- /dev/null +++ b/experiments/classification/mnist/configs/lenet_deep_ensemble.yaml @@ -0,0 +1,78 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + fast_dev_run: false + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 10 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/lenet_trajectory + name: deep_ensemble + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true 
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + # ClassificationRoutine + model: + # DeepEnsemble + class_path: torch_uncertainty.models.wrappers.deep_ensembles.deep_ensembles + init_args: + models: + # LeNet + class_path: torch_uncertainty.models.lenet._LeNet + init_args: + in_channels: 1 + num_classes: 10 + linear_layer: torch.nn.Linear + conv2d_layer: torch.nn.Conv2d + activation: torch.nn.ReLU + norm: torch.nn.Identity + groups: 1 + dropout_rate: 0 + # last_layer_dropout: false + layer_args: {} + num_estimators: 5 + task: classification + probabilistic: false + reset_model_parameters: true + num_classes: 10 + loss: CrossEntropyLoss + is_ensemble: true + format_batch_fn: + class_path: torch_uncertainty.transforms.batch.RepeatTarget + init_args: + num_repeats: 5 +data: + root: ./data + batch_size: 128 + num_workers: 127 + eval_ood: true + eval_shift: true +optimizer: + lr: 0.05 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 25 + - 50 + gamma: 0.1 diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index 4b76c9e6..aca916ee 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -5,10 +5,12 @@ import torch.nn.functional as F from torch import nn +from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear from torch_uncertainty.models import StochasticModel +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble __all__ = ["bayesian_lenet", "lenet", "packed_lenet"] @@ -119,6 +121,32 @@ def lenet( ) +def batchensemble_lenet( + in_channels: int, + num_classes: int, + num_estimators: int = 4, + activation: Callable = F.relu, + norm: type[nn.Module] = nn.BatchNorm2d, + groups: int = 1, + dropout_rate: float = 0.0, +) -> _LeNet: + model = _lenet( + stochastic=False, + in_channels=in_channels, + num_classes=num_classes, + linear_layer=BatchLinear, + conv2d_layer=BatchConv2d, + layer_args={ + "num_estimators": num_estimators, + }, + activation=activation, + norm=norm, + groups=groups, + dropout_rate=dropout_rate, + ) + return BatchEnsemble(model, num_estimators) + + def packed_lenet( in_channels: int, num_classes: int, From 940f721df13a04b7a67bd0e17f43e2b39efe6171 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 13:26:16 +0100 Subject: [PATCH 28/68] :book: State support of Python 3.13 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f3edcd96..da5063cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ "timm", From 0a6e32fe961a2119c8658f3dd81ddda48317d9a8 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 13:29:31 +0100 Subject: [PATCH 29/68] :bug: Make sklearn optional. 
Fix #136
---
 torch_uncertainty/datamodules/abstract.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py
index 40168429..f4e548a4 100644
--- a/torch_uncertainty/datamodules/abstract.py
+++ b/torch_uncertainty/datamodules/abstract.py
@@ -1,10 +1,18 @@
 from abc import ABC, abstractmethod
+from importlib import util
 from pathlib import Path
 from typing import Literal
 
 from lightning.pytorch.core import LightningDataModule
 from numpy.typing import ArrayLike
-from sklearn.model_selection import StratifiedKFold
+
+if util.find_spec("sklearn"):
+    from sklearn.model_selection import StratifiedKFold
+
+    sklearn_installed = True
+else:  # coverage: ignore
+    sklearn_installed = False
+
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.sampler import SubsetRandomSampler
 
@@ -122,7 +130,11 @@ def _get_train_targets(self) -> ArrayLike:
         raise NotImplementedError
 
     def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list:
+        if not sklearn_installed:
+            raise ImportError("Please install scikit-learn to use Cross-validation splits.")
+
         self.setup("fit")
+
         skf = StratifiedKFold(n_splits)
         cv_dm = []
 

From b3c16875498fd2a8a510b5740bd93cf611fbfd44 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Fri, 7 Mar 2025 13:38:38 +0100
Subject: [PATCH 30/68] :shirt: Refine dependencies and import error messages

---
 pyproject.toml                              | 4 ++--
 torch_uncertainty/datamodules/abstract.py   | 6 +++++-
 torch_uncertainty/metrics/sparsification.py | 6 +++++-
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index da5063cf..f5cbd2ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,9 +48,9 @@ image = [
 ]
 tabular = ["pandas"]
 dev = [
-    "scikit-learn",
-    "huggingface-hub",
     "torch_uncertainty[image]",
+    "huggingface-hub",
+    "safetensors",
     "ruff==0.7.4",
     "pytest-cov",
     "pre-commit",
diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py
index f4e548a4..fdaddd14 100644
--- a/torch_uncertainty/datamodules/abstract.py
+++ b/torch_uncertainty/datamodules/abstract.py
@@ -131,7 +131,11 @@ def _get_train_targets(self) -> ArrayLike:
 
     def make_cross_val_splits(self, n_splits: int = 10, train_over: int = 4) -> list:
         if not sklearn_installed:
-            raise ImportError("Please install scikit-learn to use Cross-validation splits.")
+            raise ImportError(
+                "Please install torch_uncertainty with the image option "
+                "to use crossval: "
+                """pip install -U "torch_uncertainty[image]"."""
+            )
 
         self.setup("fit")
 
diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/sparsification.py
index 8977b7e0..3f169efb 100644
--- a/torch_uncertainty/metrics/sparsification.py
+++ b/torch_uncertainty/metrics/sparsification.py
@@ -57,7 +57,11 @@ def __init__(self, **kwargs) -> None:
         self.add_state("errors", default=[], dist_reduce_fx="cat")
 
         if not sklearn_installed:
-            raise ImportError("Please install scikit-learn to use AUSE.")
+            raise ImportError(
+                "Please install torch_uncertainty with the image option "
+                "to use the AUSE: "
+                """pip install -U "torch_uncertainty[image]"."""
+            )
 
     def update(self, scores: Tensor, errors: Tensor) -> None:
         """Store the scores and their associated errors for later computation.
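PATCH 29 and PATCH 30 follow the usual optional-dependency pattern: probe for the package once at import time with `importlib.util.find_spec`, and raise an actionable `ImportError` only when the optional feature is actually called. A minimal sketch of the same pattern in isolation; the helper name and the install hint below are illustrative assumptions, not code from the patches:

```python
from importlib import util

# Probe at import time; scikit-learn stays an optional dependency.
if util.find_spec("sklearn"):
    from sklearn.model_selection import StratifiedKFold

    sklearn_installed = True
else:
    sklearn_installed = False


def make_cv_splitter(n_splits: int = 10) -> "StratifiedKFold":
    """Hypothetical helper: fail lazily, only when cross-validation is requested."""
    if not sklearn_installed:
        raise ImportError(
            "Please install scikit-learn to use cross-validation splits: "
            "pip install -U scikit-learn"
        )
    return StratifiedKFold(n_splits)
```

The benefit of this design is that simply importing the module never fails on a machine without scikit-learn; only the code paths that genuinely need it do.
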
From 1b0ce535d4462e19ff83054466410c8ccefcff37 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 13:52:36 +0100 Subject: [PATCH 31/68] :zap: Update ruff and lint --- pyproject.toml | 2 +- .../datasets/classification/cifar/cifar_c.py | 2 +- .../datasets/classification/cifar/cifar_h.py | 6 +- .../datasets/classification/cifar/cifar_n.py | 4 +- .../datasets/classification/cub.py | 2 +- .../datasets/classification/imagenet/base.py | 2 +- .../datasets/classification/not_mnist.py | 2 +- torch_uncertainty/datasets/fractals.py | 2 +- torch_uncertainty/datasets/frost.py | 2 +- torch_uncertainty/datasets/muad.py | 2 +- .../datasets/regression/uci_regression.py | 2 +- .../datasets/segmentation/camvid.py | 4 +- torch_uncertainty/layers/batch_ensemble.py | 2 +- .../layers/bayesian/bayes_conv.py | 4 +- .../layers/filter_response_norm.py | 4 +- torch_uncertainty/layers/functional/packed.py | 66 +++++++++---------- torch_uncertainty/layers/packed.py | 12 ++-- torch_uncertainty/losses/bayesian.py | 6 +- torch_uncertainty/losses/classification.py | 12 ++-- torch_uncertainty/losses/regression.py | 4 +- .../metrics/classification/fpr.py | 2 +- .../post_processing/calibration/scaler.py | 2 +- .../post_processing/mc_batch_norm.py | 2 +- torch_uncertainty/routines/classification.py | 6 +- .../routines/pixel_regression.py | 4 +- torch_uncertainty/routines/regression.py | 4 +- torch_uncertainty/routines/segmentation.py | 2 +- torch_uncertainty/transforms/corruption.py | 26 ++++---- torch_uncertainty/utils/distributions.py | 4 +- 29 files changed, 96 insertions(+), 98 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f5cbd2ab..661ad447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dev = [ "torch_uncertainty[image]", "huggingface-hub", "safetensors", - "ruff==0.7.4", + "ruff==0.9.9", "pytest-cov", "pre-commit", "pre-commit-hooks", diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index ebafe7f8..b9020b8c 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -111,7 +111,7 @@ def __init__( if shift_severity not in list(range(1, 6)): raise ValueError( - "Corruptions shift_severity should be chosen between 1 and 5 " "included." + "Corruptions shift_severity should be chosen between 1 and 5 included." 
) samples, labels = self.make_dataset(self.root, self.subset, self.shift_severity) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index f8354a0b..4cbbecb1 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -27,7 +27,7 @@ class CIFAR10H(CIFAR10): """ h_test_list = ["cifar-10h-probs.npy", "7b41f73eee90fdefc73bfc820ab29ba8"] - h_url = "https://github.com/jcpeterson/cifar-10h/raw/master/data/" "cifar10h-probs.npy" + h_url = "https://github.com/jcpeterson/cifar-10h/raw/master/data/cifar10h-probs.npy" def __init__( self, @@ -39,7 +39,7 @@ def __init__( ) -> None: if train: raise ValueError("CIFAR10H does not support training data.") - print("WARNING: CIFAR10H cannot be used with Classification routines " "for now.") + print("WARNING: CIFAR10H cannot be used within Classification routines for now.") super().__init__( Path(root), train=False, @@ -53,7 +53,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) self.targets = list(torch.as_tensor(np.load(self.root / self.h_test_list[0]))) diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_n.py b/torch_uncertainty/datasets/classification/cifar/cifar_n.py index 6f6f8c95..e8193d68 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_n.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_n.py @@ -61,7 +61,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) @@ -112,7 +112,7 @@ def __init__( if not self._check_specific_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) self.targets = list(torch.load(self.root / self.filename)[file_arg]) diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 079d20ec..38bacf70 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -51,7 +51,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) diff --git a/torch_uncertainty/datasets/classification/imagenet/base.py b/torch_uncertainty/datasets/classification/imagenet/base.py index 722955cf..4e5ed41a 100644 --- a/torch_uncertainty/datasets/classification/imagenet/base.py +++ b/torch_uncertainty/datasets/classification/imagenet/base.py @@ -52,7 +52,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. " "You can use download=True to download it." + "Dataset not found or corrupted. You can use download=True to download it." 
) super().__init__( diff --git a/torch_uncertainty/datasets/classification/not_mnist.py b/torch_uncertainty/datasets/classification/not_mnist.py index e0b28ae0..ff3e364b 100644 --- a/torch_uncertainty/datasets/classification/not_mnist.py +++ b/torch_uncertainty/datasets/classification/not_mnist.py @@ -60,7 +60,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/fractals.py b/torch_uncertainty/datasets/fractals.py index 329a2086..8153f2a0 100644 --- a/torch_uncertainty/datasets/fractals.py +++ b/torch_uncertainty/datasets/fractals.py @@ -40,7 +40,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__(self.root, transform=transform, target_transform=target_transform) diff --git a/torch_uncertainty/datasets/frost.py b/torch_uncertainty/datasets/frost.py index dbde6d89..6aa069bb 100644 --- a/torch_uncertainty/datasets/frost.py +++ b/torch_uncertainty/datasets/frost.py @@ -44,7 +44,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to " "download it." + "Dataset not found or corrupted. You can use download=True to download it." ) super().__init__( diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index 73ea2ba7..cc93c3d9 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -232,7 +232,7 @@ def _make_dataset(self, path: Path) -> None: """ if "depth" in path.name: raise NotImplementedError( - "Depth mode is not implemented yet. Raise an issue " "if you need it." + "Depth mode is not implemented yet. Raise an issue if you need it." ) self.samples = sorted((path / "leftImg8bit/").glob("**/*")) if self.target_type == "semantic": diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index a2c50bc3..502b9a1e 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -191,7 +191,7 @@ def download(self) -> None: logging.info("Files already downloaded and verified") return if self.url is None: - raise ValueError(f"The dataset {self.dataset_name} is not available for " "download.") + raise ValueError(f"The dataset {self.dataset_name} is not available for download.") download_root = self.root / self.root_appendix / self.dataset_name if self.dataset_name == "boston": download_url( diff --git a/torch_uncertainty/datasets/segmentation/camvid.py b/torch_uncertainty/datasets/segmentation/camvid.py index 937b83f7..ec9b8e4c 100644 --- a/torch_uncertainty/datasets/segmentation/camvid.py +++ b/torch_uncertainty/datasets/segmentation/camvid.py @@ -133,7 +133,7 @@ def __init__( """ if split not in ["train", "val", "test", None]: raise ValueError( - f"Unknown split '{split}'. " "Supported splits are ['train', 'val', 'test', None]" + f"Unknown split '{split}'. Supported splits are ['train', 'val', 'test', None]" ) super().__init__(root, transforms, None, None) @@ -153,7 +153,7 @@ def __init__( if not self._check_integrity(): raise RuntimeError( - "Dataset not found or corrupted. 
" "You can use download=True to download it" + "Dataset not found or corrupted. You can use download=True to download it" ) # get filenames for split diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index dde72139..a1ce068a 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -143,7 +143,7 @@ def forward(self, inputs: Tensor) -> Tensor: def extra_repr(self) -> str: return ( - f"in_features={ self.in_features}," + f"in_features={self.in_features}," f" out_features={self.out_features}," f" num_estimators={self.num_estimators}," f" bias={self.bias is not None}" diff --git a/torch_uncertainty/layers/bayesian/bayes_conv.py b/torch_uncertainty/layers/bayesian/bayes_conv.py index e328aa39..f1277ed4 100644 --- a/torch_uncertainty/layers/bayesian/bayes_conv.py +++ b/torch_uncertainty/layers/bayesian/bayes_conv.py @@ -89,7 +89,7 @@ def __init__( if transposed: raise NotImplementedError( - "Bayesian transposed convolution not implemented yet. Raise an" " issue if needed." + "Bayesian transposed convolution not implemented yet. Raise an issue if needed." ) self.in_channels = in_channels @@ -164,7 +164,7 @@ def sample(self) -> tuple[Tensor, Tensor | None]: return weight, bias def extra_repr(self) -> str: # coverage: ignore - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" ", stride={stride}" + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" if self.padding != (0,) * len(self.padding): s += ", padding={padding}" if self.dilation != (1,) * len(self.dilation): diff --git a/torch_uncertainty/layers/filter_response_norm.py b/torch_uncertainty/layers/filter_response_norm.py index ef306a8f..0f22717d 100644 --- a/torch_uncertainty/layers/filter_response_norm.py +++ b/torch_uncertainty/layers/filter_response_norm.py @@ -23,13 +23,13 @@ def __init__( super().__init__() if dimension < 1 or not isinstance(dimension, int): raise ValueError( - "dimension should be an integer greater or equal than 1. " f"got {dimension}." + f"dimension should be an integer greater or equal than 1. Got {dimension}." ) self.dimension = dimension if num_channels < 1 or not isinstance(num_channels, int): raise ValueError( - "num_channels should be an integer greater or equal than 1. " f"got {num_channels}." + f"num_channels should be an integer greater or equal than 1. Got {num_channels}." 
) shape = (1, num_channels) + (1,) * dimension self.eps = eps diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index c962531e..a5ab40f9 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -79,15 +79,15 @@ def packed_in_projection( emb_q // num_groups, emb_v // num_groups, ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == ( - emb_q, - ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" - assert b_k is None or b_k.shape == ( - emb_q, - ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" - assert b_v is None or b_v.shape == ( - emb_q, - ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + assert b_q is None or b_q.shape == (emb_q,), ( + f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" + ) + assert b_k is None or b_k.shape == (emb_q,), ( + f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + ) + assert b_v is None or b_v.shape == (emb_q,), ( + f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + ) return ( packed_linear(q, w_q, num_groups, implementation, b_q), @@ -324,47 +324,47 @@ def packed_multi_head_attention_forward( # noqa: D417 # longer causal. is_causal = False - assert ( - embed_dim == embed_dim_to_check - ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) if isinstance(embed_dim, Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert head_dim * num_heads == embed_dim, ( + f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + ) if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert ( - key.shape[:2] == value.shape[:2] - ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) else: - assert ( - key.shape == value.shape - ), f"key shape {key.shape} does not match value shape {value.shape}" + assert key.shape == value.shape, ( + f"key shape {key.shape} does not match value shape {value.shape}" + ) # # compute in-projection # if not use_separate_proj_weight: - assert ( - in_proj_weight is not None - ), "use_separate_proj_weight is False but in_proj_weight is None" + assert in_proj_weight is not None, ( + "use_separate_proj_weight is False but in_proj_weight is None" + ) q, k, v = packed_in_projection_packed( q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias ) else: - assert ( - q_proj_weight is not None - ), "use_separate_proj_weight is True but q_proj_weight is None" - assert ( - k_proj_weight is not None - ), "use_separate_proj_weight is True but k_proj_weight is None" - assert ( - v_proj_weight is not None - ), "use_separate_proj_weight is True but v_proj_weight is None" + assert q_proj_weight is not None, ( + "use_separate_proj_weight is True but q_proj_weight is None" + ) + assert k_proj_weight is not None, ( + "use_separate_proj_weight is True 
but k_proj_weight is None"
+        )
+        assert v_proj_weight is not None, (
+            "use_separate_proj_weight is True but v_proj_weight is None"
+        )
         if in_proj_bias is None:
             b_q = b_k = b_v = None
         else:
diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py
index 71537b82..647390a7 100644
--- a/torch_uncertainty/layers/packed.py
+++ b/torch_uncertainty/layers/packed.py
@@ -33,11 +33,9 @@ def check_packed_parameters_consistency(alpha: float, gamma: int, num_estimators
     if num_estimators is None:
         raise ValueError("You must specify the value of the arg. `num_estimators`")
     if not isinstance(num_estimators, int):
-        raise TypeError(
-            "Attribute `num_estimators` should be an int, not " f"{type(num_estimators)}"
-        )
+        raise TypeError(f"Attribute `num_estimators` should be an int, not {type(num_estimators)}")
     if num_estimators <= 0:
-        raise ValueError("Attribute `num_estimators` should be >= 1, not " f"{num_estimators}")
+        raise ValueError(f"Attribute `num_estimators` should be >= 1, not {num_estimators}")
 
 
 class PackedLinear(nn.Module):
@@ -697,9 +695,9 @@ def __init__(
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = self.embed_dim // self.num_heads
-        assert (
-            self.head_dim * self.num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * self.num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
 
         self.num_estimators = num_estimators
         self.alpha = alpha
diff --git a/torch_uncertainty/losses/bayesian.py b/torch_uncertainty/losses/bayesian.py
index 3e0e632e..4af6d77b 100644
--- a/torch_uncertainty/losses/bayesian.py
+++ b/torch_uncertainty/losses/bayesian.py
@@ -98,12 +98,12 @@ def set_model(self, model: nn.Module | None) -> None:
 
 def _elbo_loss_checks(inner_loss: nn.Module, kl_weight: float, num_samples: int) -> None:
     if isinstance(inner_loss, type):
-        raise TypeError("The inner_loss should be an instance of a class." f"Got {inner_loss}.")
+        raise TypeError(f"The inner_loss should be an instance of a class. Got {inner_loss}.")
 
     if kl_weight < 0:
         raise ValueError(f"The KL weight should be non-negative. Got {kl_weight}.")
 
     if num_samples < 1:
-        raise ValueError("The number of samples should not be lower than 1." f"Got {num_samples}.")
+        raise ValueError(f"The number of samples should not be lower than 1. Got {num_samples}.")
 
     if not isinstance(num_samples, int):
-        raise TypeError("The number of samples should be an integer. " f"Got {type(num_samples)}.")
+        raise TypeError(f"The number of samples should be an integer. Got {type(num_samples)}.")
diff --git a/torch_uncertainty/losses/classification.py b/torch_uncertainty/losses/classification.py
index 82c48280..96769338 100644
--- a/torch_uncertainty/losses/classification.py
+++ b/torch_uncertainty/losses/classification.py
@@ -31,12 +31,12 @@ def __init__(
 
         if reg_weight is not None and (reg_weight < 0):
             raise ValueError(
-                "The regularization weight should be non-negative, but got " f"{reg_weight}."
+                f"The regularization weight should be non-negative, but got {reg_weight}."
) self.reg_weight = reg_weight if annealing_step is not None and (annealing_step <= 0): - raise ValueError("The annealing step should be positive, but got " f"{annealing_step}.") + raise ValueError(f"The annealing step should be positive, but got {annealing_step}.") self.annealing_step = annealing_step if reduction not in ("none", "mean", "sum") and reduction is not None: @@ -178,11 +178,11 @@ def __init__( self.reduction = reduction if eps < 0: - raise ValueError("The epsilon value should be non-negative, but got " f"{eps}.") + raise ValueError(f"The epsilon value should be non-negative, but got {eps}.") self.eps = eps if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." ) self.reg_weight = reg_weight @@ -233,7 +233,7 @@ def __init__( self.reduction = reduction if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." ) self.reg_weight = reg_weight @@ -287,7 +287,7 @@ def __init__( if gamma < 0: raise ValueError( - "The gamma term of the focal loss should be non-negative, but got " f"{gamma}." + f"The gamma term of the focal loss should be non-negative, but got {gamma}." ) self.gamma = gamma diff --git a/torch_uncertainty/losses/regression.py b/torch_uncertainty/losses/regression.py index 888de286..479f7ff5 100644 --- a/torch_uncertainty/losses/regression.py +++ b/torch_uncertainty/losses/regression.py @@ -67,7 +67,7 @@ def __init__(self, reg_weight: float, reduction: str | None = "mean") -> None: if reg_weight < 0: raise ValueError( - "The regularization weight should be non-negative, but got " f"{reg_weight}." + f"The regularization weight should be non-negative, but got {reg_weight}." ) self.reg_weight = reg_weight @@ -114,7 +114,7 @@ def __init__(self, beta: float = 0.5, reduction: str | None = "mean") -> None: super().__init__() if beta < 0 or beta > 1: - raise ValueError("The beta parameter should be in range [0, 1], but got " f"{beta}.") + raise ValueError(f"The beta parameter should be in range [0, 1], but got {beta}.") self.beta = beta self.nll_loss = nn.GaussianNLLLoss(reduction="none") if reduction not in ("none", "mean", "sum"): diff --git a/torch_uncertainty/metrics/classification/fpr.py b/torch_uncertainty/metrics/classification/fpr.py index 53e3e779..d0779b44 100644 --- a/torch_uncertainty/metrics/classification/fpr.py +++ b/torch_uncertainty/metrics/classification/fpr.py @@ -35,7 +35,7 @@ def __init__(self, recall_level: float, pos_label: int, **kwargs) -> None: self.add_state("targets", [], dist_reduce_fx="cat") rank_zero_warn( - f"Metric `FPR{int(recall_level*100)}` will save all targets and predictions" + f"Metric `FPR{int(recall_level * 100)}` will save all targets and predictions" " in buffer. For large datasets this may lead to large memory" " footprint." ) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index a3e90eed..a5398df2 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -94,7 +94,7 @@ def calib_eval() -> float: def forward(self, inputs: Tensor) -> Tensor: if not self.trained: logging.error( - "TemperatureScaler has not been trained yet. Returning " "manually tempered inputs." + "TemperatureScaler has not been trained yet. 
Returning manually tempered inputs."
             )
 
         return self._scale(self.model(inputs))
diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index 5357d954..0a1d250d 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -168,4 +168,4 @@ def _mcbn_checks(model, num_estimators, mc_batch_size, convert):
     if mc_batch_size < 1 or not isinstance(mc_batch_size, int):
         raise ValueError(f"mc_batch_size must be a positive integer, got {mc_batch_size}.")
     if not convert and not has_mcbn(model):
-        raise ValueError("model does not contain any MCBatchNorm2d nor is not to be " "converted.")
+        raise ValueError("model does not contain any MCBatchNorm2d nor is it to be converted.")
diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index 8487f93f..7c752a11 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -717,7 +717,7 @@ def _classification_routine_checks(
 
     if not is_ensemble and ood_criterion in ["mi", "vr"]:
         raise ValueError(
-            "You cannot use mutual information or variation ratio with a single" " model."
+            "You cannot use mutual information or variation ratio with a single model."
        )
 
     if is_ensemble and eval_grouping_loss:
@@ -727,12 +727,12 @@ def _classification_routine_checks(
 
     if num_classes < 1:
         raise ValueError(
-            "The number of classes must be a positive integer >= 1." f"Got {num_classes}."
+            f"The number of classes must be a positive integer >= 1. Got {num_classes}."
        )
 
     if eval_grouping_loss and not hasattr(model, "feats_forward"):
         raise ValueError(
-            "Your model must have a `feats_forward` method to compute the " "grouping loss."
+            "Your model must have a `feats_forward` method to compute the grouping loss."
        )
 
     if eval_grouping_loss and not (
diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py
index 3fa450f2..92b9b9da 100644
--- a/torch_uncertainty/routines/pixel_regression.py
+++ b/torch_uncertainty/routines/pixel_regression.py
@@ -86,7 +86,7 @@ def __init__(
         _depth_routine_checks(output_dim, num_image_plot, log_plots)
         if eval_shift:
             raise NotImplementedError(
-                "Distribution shift evaluation not implemented yet. Raise an issue " "if needed."
+                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
             )
 
         self.model = model
@@ -294,7 +294,7 @@ def test_step(
         """
         if dataloader_idx != 0:
             raise NotImplementedError(
-                "Depth OOD detection not implemented yet. Raise an issue " "if needed."
+                "Depth OOD detection not implemented yet. Raise an issue if needed."
             )
         inputs, targets = batch
         if self.one_dim_depth:
diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py
index 07a43bb8..88514d39 100644
--- a/torch_uncertainty/routines/regression.py
+++ b/torch_uncertainty/routines/regression.py
@@ -76,7 +76,7 @@ def __init__(
         _regression_routine_checks(output_dim)
         if eval_shift:
             raise NotImplementedError(
-                "Distribution shift evaluation not implemented yet. Raise an issue " "if needed."
+                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
             )
 
         self.model = model
@@ -271,7 +271,7 @@ def test_step(
         """
         if dataloader_idx != 0:
             raise NotImplementedError(
-                "Regression OOD detection not implemented yet. Raise an issue " "if needed."
+                "Regression OOD detection not implemented yet. Raise an issue if needed."
)
 
         inputs, targets = batch
diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py
index ba08e367..fac59ecb 100644
--- a/torch_uncertainty/routines/segmentation.py
+++ b/torch_uncertainty/routines/segmentation.py
@@ -76,7 +76,7 @@ def __init__(
         )
         if eval_shift:
             raise NotImplementedError(
-                "Distribution shift evaluation not implemented yet. Raise an issue " "if needed."
+                "Distribution shift evaluation not implemented yet. Raise an issue if needed."
             )
 
         self.model = model
diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py
index ffa5f491..2c9f59bd 100644
--- a/torch_uncertainty/transforms/corruption.py
+++ b/torch_uncertainty/transforms/corruption.py
@@ -66,24 +66,24 @@ def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
 from .image import Saturation as ISaturation
 
 __all__ = [
-    "GaussianNoise",
-    "ShotNoise",
-    "ImpulseNoise",
-    "DefocusBlur",
-    "GlassBlur",
-    "MotionBlur",
-    "ZoomBlur",
-    "Snow",
-    "Frost",
-    "Fog",
     "Brightness",
     "Contrast",
+    "DefocusBlur",
     "Elastic",
-    "Pixelate",
-    "JPEGCompression",
+    "Fog",
+    "Frost",
     "GaussianBlur",
-    "SpeckleNoise",
+    "GaussianNoise",
+    "GlassBlur",
+    "ImpulseNoise",
+    "JPEGCompression",
+    "MotionBlur",
+    "Pixelate",
     "Saturation",
+    "ShotNoise",
+    "Snow",
+    "SpeckleNoise",
+    "ZoomBlur",
     "corruption_transforms",
 ]
diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py
index 3e115e93..860b96d5 100644
--- a/torch_uncertainty/utils/distributions.py
+++ b/torch_uncertainty/utils/distributions.py
@@ -33,7 +33,7 @@ def get_dist_class(dist_family: str) -> type[Distribution]:
     if dist_family == "student":
         return StudentT
     raise NotImplementedError(
-        f"{dist_family} distribution is not supported." "Raise an issue if needed."
+        f"{dist_family} distribution is not supported. Raise an issue if needed."
     )
 
 
@@ -52,7 +52,7 @@ def get_dist_estimate(dist: Distribution, dist_estimate: str) -> Tensor:
     if dist_estimate == "mode":
         return dist.mode
     raise NotImplementedError(
-        f"{dist_estimate} estimate is not supported." "Raise an issue if needed."
+        f"{dist_estimate} estimate is not supported. Raise an issue if needed."
) From 017d040ce2324624b54a5eb9faf48c1c06b5675c Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 14:09:04 +0100 Subject: [PATCH 32/68] :fire: Remove torchaudio from workflows --- .github/workflows/build-docs.yml | 2 +- .github/workflows/run-tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index f9fc48f3..f5c1236d 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -40,7 +40,7 @@ jobs: - name: Install dependencies run: | - python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Sphinx build diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 1c74dd5d..9b23e989 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -64,7 +64,7 @@ jobs: - name: Install dependencies if: steps.changed-files-specific.outputs.only_changed != 'true' run: | - python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu + python3 -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python3 -m pip install .[all] - name: Check style & format From a92385eab837c4302f4db44a40288e95c6ac7eb5 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 14:41:32 +0100 Subject: [PATCH 33/68] :shirt: Lint --- torch_uncertainty/models/wrappers/batch_ensemble.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index d62adcb8..88b2312b 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -1,6 +1,7 @@ import torch from torch import nn + class BatchEnsemble(nn.Module): """Wraps a BatchEnsemble model to ensure correct batch replication. 
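With the wrapper from PATCH 25 and the `batchensemble_lenet` factory from PATCH 27, end-to-end usage looks roughly as follows. A minimal sketch; the batch size of 8 and the 28x28 MNIST-sized input are illustrative assumptions:

```python
import torch

from torch_uncertainty.models.lenet import batchensemble_lenet

# A 5-member BatchEnsemble LeNet; the wrapper repeats each batch internally.
model = batchensemble_lenet(in_channels=1, num_classes=10, num_estimators=5)

x = torch.randn(8, 1, 28, 28)  # assumed batch of 8 MNIST-sized images
logits = model(x)  # input repeated along dim 0 -> output of shape (M * B, C)
print(logits.shape)  # expected: torch.Size([40, 10])
```

This (M * B, C) layout is exactly what the classification routine expects from ensembles, which it then rearranges to (B, M, C) before averaging the per-estimator probabilities.
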
From eda886b3aaff9880c4bbef965129c5233f39a3e7 Mon Sep 17 00:00:00 2001 From: Olivier Laurent <62881275+o-laurent@users.noreply.github.com> Date: Fri, 7 Mar 2025 14:43:28 +0100 Subject: [PATCH 34/68] :whale: Update Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index cc89a22b..1fd3f3d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,7 +21,7 @@ RUN touch README.md && mkdir -p torch_uncertainty && touch torch_uncertainty/__i COPY pyproject.toml /workspace/ # Install dependencies all dependencies -RUN pip install --no-cache-dir -e ".[dev]" +RUN pip install --no-cache-dir -e ".[all]" # Always activate Conda when opening a new terminal RUN echo "source /opt/conda/bin/activate" >> /root/.bashrc From ca73b4ba2ca2d5f7269d74a8303e6223483b7145 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 14:51:24 +0100 Subject: [PATCH 35/68] :shirt: Also format --- torch_uncertainty/models/wrappers/batch_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 88b2312b..c6995ca8 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -18,7 +18,7 @@ class BatchEnsemble(nn.Module): wrapped_model = BatchEnsembleWrapper(model, num_estimators=5) logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times ``` - + Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. From 597691904fb1c1fa8f53b3c37b9dbfc604be94bc Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 7 Mar 2025 15:46:55 +0000 Subject: [PATCH 36/68] :white_check_mark: Add test for BatchEnsemble wrapper and LeNet implementation --- tests/models/test_lenet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_lenet.py b/tests/models/test_lenet.py index 8519ffdf..c6a08180 100644 --- a/tests/models/test_lenet.py +++ b/tests/models/test_lenet.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_uncertainty.models.lenet import bayesian_lenet, lenet, packed_lenet +from torch_uncertainty.models.lenet import batchensemble_lenet, bayesian_lenet, lenet, packed_lenet class TestLeNet: @@ -18,6 +18,7 @@ def test_main(self): model.eval() model(torch.randn(1, 1, 20, 20)) + batchensemble_lenet(1, 1) packed_lenet(1, 1) bayesian_lenet(1, 1) bayesian_lenet( From 96e3eed08a0f204191110693a22313181cdaedb8 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 7 Mar 2025 17:01:37 +0100 Subject: [PATCH 37/68] :bug: Fix `extended_out_features` computation in `PackedLinear` --- torch_uncertainty/layers/packed.py | 35 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 647390a7..ab9a2a61 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -71,8 +71,12 @@ def __init__( network. Defaults to ``False``. last (bool, optional): Whether this is the last layer of the network. Defaults to ``False``. - implementation (str, optional): The implementation to use. Defaults - to ``"legacy"``. + implementation (str, optional): The implementation to use. Available implementations: + + - ``"legacy"`` (default): The legacy implementation of the linear layer. + - ``"sparse"``: The sparse implementation of the linear layer. 
+            - ``"full"``: The full implementation of the linear layer.
+            - ``"einsum"``: The einsum implementation of the linear layer.
         rearrange (bool, optional): Rearrange the input and outputs for
             compatibility with previous and later layers. Defaults to ``True``.
         device (torch.device, optional): The device to use for the layer's
@@ -101,6 +105,13 @@ def __init__(
             default.
         """
         check_packed_parameters_consistency(alpha, gamma, num_estimators)
+
+        if implementation not in ["legacy", "sparse", "full", "einsum"]:
+            raise ValueError(
+                f"Unknown implementation: {implementation} for PackedLinear. "
+                "Available implementations are: 'legacy', 'sparse', 'full', 'einsum'"
+            )
+
         factory_kwargs = {"device": device, "dtype": dtype}
 
         super().__init__()
@@ -119,16 +130,10 @@ def __init__(
         # fix if not divisible by groups
         if extended_in_features % actual_groups:
             extended_in_features += num_estimators - extended_in_features % (actual_groups)
-        if extended_out_features % actual_groups:
-            extended_out_features += num_estimators - extended_out_features % (actual_groups)
-
-        # FIXME: This is a temporary check
-        assert implementation in [
-            "legacy",
-            "sparse",
-            "full",
-            "einsum",
-        ], f"Unknown implementation: {implementation} for PackedLinear"
+        if extended_out_features % num_estimators * gamma:
+            extended_out_features += num_estimators - extended_out_features % (
+                num_estimators * gamma
+            )
 
         if self.implementation == "legacy":
             self.weight = nn.Parameter(
@@ -695,9 +700,9 @@ def __init__(
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = self.embed_dim // self.num_heads
-        assert self.head_dim * self.num_heads == self.embed_dim, (
-            "embed_dim must be divisible by num_heads"
-        )
+        assert (
+            self.head_dim * self.num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
 
         self.num_estimators = num_estimators
         self.alpha = alpha

From c001ea686be08d22b688f0624cfdf8429c340998 Mon Sep 17 00:00:00 2001
From: alafage
Date: Fri, 7 Mar 2025 17:27:42 +0100
Subject: [PATCH 38/68] :white_check_mark: Update PackedLinear tests

---
 tests/layers/test_packed.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py
index cfcee746..e30e7cd3 100644
--- a/tests/layers/test_packed.py
+++ b/tests/layers/test_packed.py
@@ -196,7 +196,23 @@ def test_linear_einsum_implementation(
         assert out.shape == torch.Size([1, 2, 3, 4, 2])
 
     def test_linear_extend(self):
-        _ = PackedConv2d(5, 3, kernel_size=1, alpha=1, num_estimators=2, gamma=1)
+        layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy")
+        assert layer.weight.shape == torch.Size([4, 3, 1])
+        assert layer.bias.shape == torch.Size([4])
+        layer = PackedLinear(5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full")
+        assert layer.weight.shape == torch.Size([2, 2, 3])
+        assert layer.bias.shape == torch.Size([4])
+        # with first=True
+        layer = PackedLinear(
+            5, 3, alpha=1, num_estimators=2, gamma=1, implementation="legacy", first=True
+        )
+        assert layer.weight.shape == torch.Size([4, 5, 1])
+        assert layer.bias.shape == torch.Size([4])
+        layer = PackedLinear(
+            5, 3, alpha=1, num_estimators=2, gamma=1, implementation="full", first=True
+        )
+        assert layer.weight.shape == torch.Size([1, 4, 5])
+        assert layer.bias.shape == torch.Size([4])
 
     def test_linear_failures(self):
         with pytest.raises(ValueError):
@@ -220,7 +236,7 @@ def test_linear_failures(self):
         with pytest.raises(ValueError):
             _ = PackedLinear(5, 2, alpha=1, num_estimators=1,
gamma=-1, rearrange=True) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): _ = PackedLinear( 5, 2, From 22428050954ee7a7dcd0d1d8d631661419b3b275 Mon Sep 17 00:00:00 2001 From: Anton Zamyatin Date: Fri, 7 Mar 2025 17:40:00 +0100 Subject: [PATCH 39/68] :whale: Install pre-commit hooks on container start --- README.md | 1 - docker/entrypoint.sh | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e339c989..e0f735a4 100644 --- a/README.md +++ b/README.md @@ -100,4 +100,3 @@ The following projects use TorchUncertainty: - *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) **If you are using TorchUncertainty in your project, please let us know, and we will add your project to this list!** - diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index ffef3e62..1c1af5e5 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -121,6 +121,10 @@ else echo "✅ torch_uncertainty is now installed in editable mode! 🚀" fi +# Activate pre-commit hooks (if enabled) +echo "🔗 Activating pre-commit hooks..." +pre-commit install + # Ensure SSH server is started echo "🔑 Starting SSH server..." mkdir -p /run/sshd && chmod 755 /run/sshd From 145181abad20a49d10c952e4c84214bc07ececc5 Mon Sep 17 00:00:00 2001 From: alafage Date: Fri, 7 Mar 2025 17:45:40 +0100 Subject: [PATCH 40/68] :art: Format PackedLinear --- torch_uncertainty/layers/functional/packed.py | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index a5ab40f9..c962531e 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -79,15 +79,15 @@ def packed_in_projection( emb_q // num_groups, emb_v // num_groups, ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == (emb_q,), ( - f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" - ) - assert b_k is None or b_k.shape == (emb_q,), ( - f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" - ) - assert b_v is None or b_v.shape == (emb_q,), ( - f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" - ) + assert b_q is None or b_q.shape == ( + emb_q, + ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == ( + emb_q, + ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == ( + emb_q, + ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" return ( packed_linear(q, w_q, num_groups, implementation, b_q), @@ -324,47 +324,47 @@ def packed_multi_head_attention_forward( # noqa: D417 # longer causal. 
is_causal = False - assert embed_dim == embed_dim_to_check, ( - f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" - ) + assert ( + embed_dim == embed_dim_to_check + ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" if isinstance(embed_dim, Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, ( - f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" - ) + assert ( + head_dim * num_heads == embed_dim + ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert key.shape[:2] == value.shape[:2], ( - f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" - ) + assert ( + key.shape[:2] == value.shape[:2] + ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" else: - assert key.shape == value.shape, ( - f"key shape {key.shape} does not match value shape {value.shape}" - ) + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" # # compute in-projection # if not use_separate_proj_weight: - assert in_proj_weight is not None, ( - "use_separate_proj_weight is False but in_proj_weight is None" - ) + assert ( + in_proj_weight is not None + ), "use_separate_proj_weight is False but in_proj_weight is None" q, k, v = packed_in_projection_packed( q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias ) else: - assert q_proj_weight is not None, ( - "use_separate_proj_weight is True but q_proj_weight is None" - ) - assert k_proj_weight is not None, ( - "use_separate_proj_weight is True but k_proj_weight is None" - ) - assert v_proj_weight is not None, ( - "use_separate_proj_weight is True but v_proj_weight is None" - ) + assert ( + q_proj_weight is not None + ), "use_separate_proj_weight is True but q_proj_weight is None" + assert ( + k_proj_weight is not None + ), "use_separate_proj_weight is True but k_proj_weight is None" + assert ( + v_proj_weight is not None + ), "use_separate_proj_weight is True but v_proj_weight is None" if in_proj_bias is None: b_q = b_k = b_v = None else: From a62291031f13043998bd4e8449e188073948dff3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 7 Mar 2025 17:50:48 +0100 Subject: [PATCH 41/68] :shirt: Lint --- torch_uncertainty/layers/functional/packed.py | 66 +++++++++---------- torch_uncertainty/layers/packed.py | 6 +- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py index c962531e..a5ab40f9 100644 --- a/torch_uncertainty/layers/functional/packed.py +++ b/torch_uncertainty/layers/functional/packed.py @@ -79,15 +79,15 @@ def packed_in_projection( emb_q // num_groups, emb_v // num_groups, ), f"expecting value weights shape of {(emb_q, emb_v)}, but got {w_v.shape}" - assert b_q is None or b_q.shape == ( - emb_q, - ), f"expecting query bias shape of {(emb_q,)}, but got {b_q.shape}" - assert b_k is None or b_k.shape == ( - emb_q, - ), f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" - assert b_v is None or b_v.shape == ( - emb_q, - ), f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + assert b_q is None or b_q.shape == (emb_q,), ( + f"expecting 
query bias shape of {(emb_q,)}, but got {b_q.shape}" + ) + assert b_k is None or b_k.shape == (emb_q,), ( + f"expecting key bias shape of {(emb_k,)}, but got {b_k.shape}" + ) + assert b_v is None or b_v.shape == (emb_q,), ( + f"expecting value bias shape of {(emb_v,)}, but got {b_v.shape}" + ) return ( packed_linear(q, w_q, num_groups, implementation, b_q), @@ -324,47 +324,47 @@ def packed_multi_head_attention_forward( # noqa: D417 # longer causal. is_causal = False - assert ( - embed_dim == embed_dim_to_check - ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + assert embed_dim == embed_dim_to_check, ( + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + ) if isinstance(embed_dim, Tensor): # embed_dim can be a tensor when JIT tracing head_dim = embed_dim.div(num_heads, rounding_mode="trunc") else: head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + assert head_dim * num_heads == embed_dim, ( + f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + ) if use_separate_proj_weight: # allow MHA to have different embedding dimensions when separate projection weights are used - assert ( - key.shape[:2] == value.shape[:2] - ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + assert key.shape[:2] == value.shape[:2], ( + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + ) else: - assert ( - key.shape == value.shape - ), f"key shape {key.shape} does not match value shape {value.shape}" + assert key.shape == value.shape, ( + f"key shape {key.shape} does not match value shape {value.shape}" + ) # # compute in-projection # if not use_separate_proj_weight: - assert ( - in_proj_weight is not None - ), "use_separate_proj_weight is False but in_proj_weight is None" + assert in_proj_weight is not None, ( + "use_separate_proj_weight is False but in_proj_weight is None" + ) q, k, v = packed_in_projection_packed( q=query, k=key, v=value, w=in_proj_weight, num_groups=num_groups, b=in_proj_bias ) else: - assert ( - q_proj_weight is not None - ), "use_separate_proj_weight is True but q_proj_weight is None" - assert ( - k_proj_weight is not None - ), "use_separate_proj_weight is True but k_proj_weight is None" - assert ( - v_proj_weight is not None - ), "use_separate_proj_weight is True but v_proj_weight is None" + assert q_proj_weight is not None, ( + "use_separate_proj_weight is True but q_proj_weight is None" + ) + assert k_proj_weight is not None, ( + "use_separate_proj_weight is True but k_proj_weight is None" + ) + assert v_proj_weight is not None, ( + "use_separate_proj_weight is True but v_proj_weight is None" + ) if in_proj_bias is None: b_q = b_k = b_v = None else: diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index ab9a2a61..66150367 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -700,9 +700,9 @@ def __init__( self.dropout = dropout self.batch_first = batch_first self.head_dim = self.embed_dim // self.num_heads - assert ( - self.head_dim * self.num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert self.head_dim * self.num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) self.num_estimators = num_estimators self.alpha = alpha From 18dacb376545b3a6355a81f0ab827763647abc27 Mon Sep 17 00:00:00 2001 From: Anton 
Date: Fri, 7 Mar 2025 18:16:46 +0000 Subject: [PATCH 42/68] :heavy_check_mark: Add test for BatchEnsemble wrapper and fix bug in batch replication --- tests/models/wrappers/test_batch_ensemble.py | 36 +++++++++++++++++++ .../models/wrappers/batch_ensemble.py | 3 +- 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 tests/models/wrappers/test_batch_ensemble.py diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py new file mode 100644 index 00000000..8e2271ee --- /dev/null +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -0,0 +1,36 @@ +import pytest +import torch +from torch import nn + +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble + + +# Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) +class SimpleModel(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.fc = nn.Linear(in_features, out_features) + self.r_group = nn.Parameter(torch.randn(in_features)) + self.s_group = nn.Parameter(torch.randn(out_features)) + self.bias = nn.Parameter(torch.randn(out_features)) + + def forward(self, x): + return self.fc(x) + + +# Test the BatchEnsemble wrapper +def test_batch_ensemble(): + in_features = 10 + out_features = 5 + num_estimators = 3 + model = SimpleModel(in_features, out_features) + wrapped_model = BatchEnsemble(model, num_estimators) + + # Test forward pass + x = torch.randn(2, in_features) # Batch size of 2 + logits = wrapped_model(x) + assert logits.shape == (2 * num_estimators, out_features) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index c6995ca8..5d6355d3 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -31,5 +31,6 @@ def __init__(self, model: nn.Module, num_estimators: int): def forward(self, x: torch.Tensor) -> torch.Tensor: """Repeats the input batch and passes it through the model.""" - x = x.repeat(self.num_estimators, 1, 1, 1) + repeat_shape = [self.num_estimators] + [1] * (x.dim() - 1) + x = x.repeat(repeat_shape) return self.model(x) From 79b567dd3bf50640cae82cae559333692cf355f1 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 11:02:37 +0100 Subject: [PATCH 43/68] :ok_hand: Comply with PEP 257 --- torch_uncertainty/models/wrappers/batch_ensemble.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 5d6355d3..c6066e74 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -3,7 +3,7 @@ class BatchEnsemble(nn.Module): - """Wraps a BatchEnsemble model to ensure correct batch replication. + """Wrap a BatchEnsemble model to ensure correct batch replication. In a BatchEnsemble architecture, each estimator operates on a **sub-batch** of the input. 
This means that the input batch must be **repeated** From dbf4df9f73f40a4af31991f2ff8a3a51da82e6ca Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 10 Mar 2025 11:40:57 +0000 Subject: [PATCH 44/68] :books: Add note that BatchEnsemble wrapper expects model to use BatchEnsemble layers --- torch_uncertainty/models/wrappers/batch_ensemble.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index c6066e74..e306dcb7 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -22,6 +22,9 @@ class BatchEnsemble(nn.Module): Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. + + Note: + This wrapper assumes that the model uses **BatchEnsemble layers** (see `torchensemble.layers.batch_ensemble`). """ def __init__(self, model: nn.Module, num_estimators: int): From 16a235b69401778dcd9571802670d12909ab07e9 Mon Sep 17 00:00:00 2001 From: Anton Date: Mon, 10 Mar 2025 11:54:29 +0000 Subject: [PATCH 45/68] :hammer: Refactor BatchEnsemble test case to use framework's test format --- tests/models/wrappers/test_batch_ensemble.py | 29 ++++++++------------ 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 8e2271ee..8230819e 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -1,4 +1,3 @@ -import pytest import torch from torch import nn @@ -6,7 +5,7 @@ # Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) -class SimpleModel(nn.Module): +class _DummyModel(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.fc = nn.Linear(in_features, out_features) @@ -18,19 +17,15 @@ def forward(self, x): return self.fc(x) -# Test the BatchEnsemble wrapper -def test_batch_ensemble(): - in_features = 10 - out_features = 5 - num_estimators = 3 - model = SimpleModel(in_features, out_features) - wrapped_model = BatchEnsemble(model, num_estimators) +class TestBatchEnsembleModel: + def test_forward_pass(self): + in_features = 10 + out_features = 5 + num_estimators = 3 + model = _DummyModel(in_features, out_features) + wrapped_model = BatchEnsemble(model, num_estimators) - # Test forward pass - x = torch.randn(2, in_features) # Batch size of 2 - logits = wrapped_model(x) - assert logits.shape == (2 * num_estimators, out_features) - - -if __name__ == "__main__": - pytest.main([__file__]) + # Test forward pass + x = torch.randn(2, in_features) # Batch size of 2 + logits = wrapped_model(x) + assert logits.shape == (2 * num_estimators, out_features) From 698ba1a16a15e9e2c004b76b14da2863ea6fa86b Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 13:59:25 +0100 Subject: [PATCH 46/68] :heavy_plus_sign: / :heavy_minus_sign: Replace wand by kornia --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 661ad447..8303472e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,9 +42,9 @@ dependencies = [ [project.optional-dependencies] image = [ "scikit-image", + "kornia", "h5py", "opencv-python", - "Wand", ] tabular = ["pandas"] dev = [ From 89b26f2286a1c9586a5eb02a5f43685570164c75 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 14:00:25 +0100 Subject: [PATCH 47/68] :hammer: Use 
kornia's motion blur & remove useless covignores --- torch_uncertainty/transforms/corruption.py | 114 +++++++++------------ 1 file changed, 50 insertions(+), 64 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index 2c9f59bd..c6a83a81 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -1,6 +1,5 @@ """Adapted from https://github.com/hendrycks/robustness.""" -import ctypes from importlib import util from io import BytesIO @@ -40,24 +39,12 @@ ToTensor, ) -if util.find_spec("wand"): - from wand.api import library as wandlibrary - from wand.image import Image as WandImage +if util.find_spec("kornia"): + from kornia.filters import motion_blur - wandlibrary.MagickMotionBlurImage.argtypes = ( - ctypes.c_void_p, # wand - ctypes.c_double, # radius - ctypes.c_double, # sigma - ctypes.c_double, - ) # angle - - class MotionImage(WandImage): - def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0): - wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle) - - wand_installed = True + kornia_installed = True else: # coverage: ignore - wand_installed = False + kornia_installed = False from torch_uncertainty.datasets import FrostImages @@ -143,7 +130,7 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. """ super().__init__(severity) - if not skimage_installed: # coverage: ignore + if not skimage_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -168,7 +155,7 @@ def __init__(self, severity: int) -> None: severity (int): Severity level of the corruption. """ super().__init__(severity) - if not cv2_installed: # coverage: ignore + if not cv2_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -196,7 +183,7 @@ def forward(self, img: Tensor) -> Tensor: class GlassBlur(TUCorruption): # TODO: batch def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed or not cv2_installed: # coverage: ignore + if not skimage_installed or not cv2_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -230,20 +217,23 @@ def disk(radius: int, alias_blur: float = 0.1, dtype=np.float32): xs, ys = np.meshgrid(size, size) aliased_disk = np.array((xs**2 + ys**2) <= radius**2, dtype=dtype) aliased_disk /= np.sum(aliased_disk) - return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) class MotionBlur(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a motion blur corruption on the image. + + Note: + Originally, Hendrycks et al. used gaussian motion blur. To remove the dependency with + with wand we changed the transform to a simpler motion blur and kept the values of + sigma as the new half kernel sizes. 
+ """ super().__init__(severity) self.rng = np.random.default_rng() - self.radius = [10, 15, 15, 15, 20][severity - 1] - self.sigma = [3, 5, 8, 12, 15][severity - 1] - self.to_pil_img = ToPILImage() - self.to_tensor = ToTensor() + self.radius = [3, 5, 8, 12, 15][severity - 1] - if not wand_installed: # coverage: ignore + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -252,18 +242,16 @@ def __init__(self, severity: int) -> None: def forward(self, img: Tensor) -> Tensor: if self.severity == 0: return img - output = BytesIO() - pil_img = self.to_pil_img(img) - pil_img.save(output, "PNG") - x = MotionImage(blob=output.getvalue()) - x.motion_blur( - radius=self.radius, - sigma=self.sigma, - angle=self.rng.uniform(-45, 45), + no_batch = False + if img.ndim == 3: + no_batch = True + img = img.unsqueeze(0) + out = motion_blur( + img, kernel_size=self.radius * 2 + 1, angle=self.rng.uniform(-45, 45), direction=0 ) - x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED) - x = np.clip(x[..., [2, 1, 0]], 0, 255) - return self.to_tensor(x) + if no_batch: + out = out.squeeze(0) + return out def clipped_zoom(img, zoom_factor): @@ -294,7 +282,7 @@ def __init__(self, severity: int) -> None: np.arange(1, 1.31, 0.03), ][severity - 1] - if not scipy_installed: # coverage: ignore + if not scipy_installed: raise ImportError( "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" @@ -313,17 +301,22 @@ def forward(self, img: Tensor) -> Tensor: class Snow(TUCorruption): def __init__(self, severity: int) -> None: + """Apply a snow effect on the image. + + Note: + The transformation has been slightly modified, see MotionBlur for details. + """ super().__init__(severity) self.mix = [ - (0.1, 0.3, 3, 0.5, 10, 4, 0.8), - (0.2, 0.3, 2, 0.5, 12, 4, 0.7), - (0.55, 0.3, 4, 0.9, 12, 8, 0.7), - (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65), - (0.55, 0.3, 2.5, 0.85, 12, 12, 0.55), + (0.1, 0.3, 3, 0.5, 4, 0.8), + (0.2, 0.3, 2, 0.5, 4, 0.7), + (0.55, 0.3, 4, 0.9, 8, 0.7), + (0.55, 0.3, 4.5, 0.85, 8, 0.65), + (0.55, 0.3, 2.5, 0.85, 12, 0.55), ][severity - 1] self.rng = np.random.default_rng() - if not wand_installed: # coverage: ignore + if not kornia_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" @@ -339,27 +332,20 @@ def forward(self, img: Tensor) -> Tensor: ] snow_layer = clipped_zoom(snow_layer, self.mix[2]) snow_layer[snow_layer < self.mix[3]] = 0 - snow_layer = Image.fromarray( - (np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), - mode="L", - ) - output = BytesIO() - snow_layer.save(output, format="PNG") - snow_layer = MotionImage(blob=output.getvalue()) - snow_layer.motion_blur( - radius=self.mix[4], - sigma=self.mix[5], - angle=self.rng.uniform(-135, -45), - ) + snow_layer = np.clip(snow_layer.squeeze(), 0, 1) + snow_layer = ( - cv2.imdecode( - np.fromstring(snow_layer.make_blob(), np.uint8), - cv2.IMREAD_UNCHANGED, + motion_blur( + torch.as_tensor(snow_layer).unsqueeze(0).unsqueeze(0), + kernel_size=self.mix[4] * 2 + 1, + angle=self.rng.uniform(-135, -45), + direction=0, ) - / 255.0 + .squeeze(0) + .numpy() ) - snow_layer = snow_layer[np.newaxis, ...] 
- x = self.mix[6] * x + (1 - self.mix[6]) * np.maximum( + + x = self.mix[5] * x + (1 - self.mix[5]) * np.maximum( x, cv2.cvtColor(x.transpose([1, 2, 0]), cv2.COLOR_RGB2GRAY).reshape(1, height, width) * 1.5 + 0.5, @@ -516,7 +502,7 @@ def forward(self, img: Tensor) -> Tensor: class Elastic(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - if not cv2_installed or not scipy_installed: # coverage: ignore + if not cv2_installed or not scipy_installed: raise ImportError( "Please install torch_uncertainty with the all option:" """pip install -U "torch_uncertainty[all]".""" @@ -620,7 +606,7 @@ def forward(self, img: Tensor) -> Tensor: class GaussianBlur(TUCorruption): def __init__(self, severity: int) -> None: super().__init__(severity) - if not skimage_installed: # coverage: ignore + if not skimage_installed: raise ImportError( "Please install torch_uncertainty with the image option:" """pip install -U "torch_uncertainty[image]".""" From 839444ebe519857c55f52e4aa188a5ff2567a1b3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 14:13:59 +0100 Subject: [PATCH 48/68] :book: Improve the clarity of the OOD detection behavior #139 --- torch_uncertainty/routines/classification.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 7c752a11..a447e53f 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -123,6 +123,12 @@ def __init__( Warning: You must define :attr:`optim_recipe` if you do not use the Lightning CLI. + Note: + If :attr:`eval_ood` is ``True``, we perform a binary classification and update the + OOD-related metrics twice: + - once during the test on ID values where the given binary label is 0 (for ID) + - once during the test on OOD values where the given binary label is 1 (for OOD) + Note: :attr:`optim_recipe` can be anything that can be returned by :meth:`LightningModule.configure_optimizers()`. Find more details From 4887f7819ef23a72b1c733d6d342b92d58604a56 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 14:17:59 +0100 Subject: [PATCH 49/68] :white_check_mark: Fix cov --- tests/transforms/test_corruption.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transforms/test_corruption.py b/tests/transforms/test_corruption.py index 1ee28555..ce6a76cc 100644 --- a/tests/transforms/test_corruption.py +++ b/tests/transforms/test_corruption.py @@ -88,6 +88,10 @@ def test_motion_blur(self): transform = MotionBlur(0) transform(inputs) + inputs = torch.rand(1, 3, 32, 32) + transform = MotionBlur(1) + transform(inputs) + def test_zoom_blur(self): inputs = torch.rand(3, 32, 32) transform = ZoomBlur(1) From 12a7423624a59f230c3f8d12eeba7bd86b977449 Mon Sep 17 00:00:00 2001 From: Olivier Laurent <62881275+o-laurent@users.noreply.github.com> Date: Mon, 10 Mar 2025 14:21:06 +0100 Subject: [PATCH 50/68] :shirt: Fix typo in corruption.py --- torch_uncertainty/transforms/corruption.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_uncertainty/transforms/corruption.py b/torch_uncertainty/transforms/corruption.py index c6a83a81..0b28c73a 100644 --- a/torch_uncertainty/transforms/corruption.py +++ b/torch_uncertainty/transforms/corruption.py @@ -225,8 +225,8 @@ def __init__(self, severity: int) -> None: """Apply a motion blur corruption on the image. Note: - Originally, Hendrycks et al. used gaussian motion blur. 
To remove the dependency with
-            with wand we changed the transform to a simpler motion blur and kept the values of
+            Originally, Hendrycks et al. used Gaussian motion blur. To remove the dependency
+            on `Wand`, we changed the transform to a simpler motion blur and kept the values of
             sigma as the new half kernel sizes.
         """
         super().__init__(severity)

From 6e2596de0c7f6e1c873722ad0197a57faf6ee5de Mon Sep 17 00:00:00 2001
From: alafage
Date: Mon, 10 Mar 2025 17:00:50 +0100
Subject: [PATCH 51/68] :hammer: Use `einops.repeat` instead of `torch.repeat`
 in `RepeatTarget`

---
 torch_uncertainty/transforms/batch.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py
index 1426992a..19e0edab 100644
--- a/torch_uncertainty/transforms/batch.py
+++ b/torch_uncertainty/transforms/batch.py
@@ -1,5 +1,5 @@
 import torch
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import Tensor, nn

@@ -21,7 +21,7 @@ def __init__(self, num_repeats: int) -> None:

     def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
         inputs, targets = batch
-        return inputs, targets.repeat(self.num_repeats, *[1] * (targets.ndim - 1))
+        return inputs, repeat(targets, "b ... -> (m b) ...", m=self.num_repeats)

 class MIMOBatchFormat(nn.Module):
@@ -79,6 +79,6 @@ def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
             [torch.index_select(targets, dim=0, index=indices) for indices in shuffle_indices],
             dim=0,
         )
-        inputs = rearrange(inputs, "m b c h w -> (m b) c h w", m=self.num_estimators)
+        inputs = rearrange(inputs, "m b ... -> (m b) ...", m=self.num_estimators)
         targets = rearrange(targets, "m b -> (m b)", m=self.num_estimators)
         return inputs, targets

From 16c77059c9d473efcf3f487b6ce1fbbe6666eeb5 Mon Sep 17 00:00:00 2001
From: alafage
Date: Mon, 10 Mar 2025 17:02:20 +0100
Subject: [PATCH 52/68] :hammer: Refine BatchEnsemble layers and add
 conversion methods

---
 torch_uncertainty/layers/batch_ensemble.py | 133 ++++++++++++---------
 1 file changed, 75 insertions(+), 58 deletions(-)

diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py
index 3f15ac21..abb481f8 100644
--- a/torch_uncertainty/layers/batch_ensemble.py
+++ b/torch_uncertainty/layers/batch_ensemble.py
@@ -1,6 +1,7 @@
 import math

 import torch
+from einops import repeat
 from torch import Tensor, nn
 from torch.nn.common_types import _size_2_t
 from torch.nn.modules.utils import _pair
@@ -79,17 +80,18 @@ def __init__(
         :math:`H_{out} = \text{out_features}`.

         Warning:
-            Ensure that `batch_size` is divisible by :attr:`num_estimators` when calling :func:`forward()`.
-            In a BatchEnsemble architecture, the input batch is typically **repeated** `num_estimators`
-            times along the first axis. Incorrect batch size may lead to unexpected results.
+            It is advised to ensure that `batch_size` is divisible by :attr:`num_estimators` when
+            calling :func:`forward()`, so each estimator receives the same number of examples.
+            In a BatchEnsemble architecture, the input is typically **repeated** `num_estimators`
+            times along the batch dimension. Incorrect batch size may lead to unexpected results.
+ To simplify batch handling, wrap your model with `torch_uncertainty.wrappers.BatchEnsemble`, + which automatically repeats the batch before passing it through the network. Examples: >>> # With three estimators - >>> m = LinearBE(20, 30, 3) + >>> m = BatchLinear(20, 30, 3) >>> input = torch.randn(8, 20) >>> output = m(input) >>> print(output.size()) @@ -116,6 +118,30 @@ def __init__( self.register_parameter("bias", None) self.reset_parameters() + @classmethod + def from_linear(cls, linear: nn.Linear, num_estimators: int) -> "BatchLinear": + r"""Create a BatchEnsemble-style Linear layer from an existing Linear layer. + + Args: + linear (nn.Linear): The Linear layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchLinear: The converted BatchEnsemble-style Linear layer. + + Example: + >>> linear = nn.Linear(20, 30) + >>> be_linear = BatchLinear.from_linear(linear, num_estimators=3) + """ + return cls( + in_features=linear.in_features, + out_features=linear.out_features, + num_estimators=num_estimators, + bias=linear.bias is not None, + device=linear.weight.device, + dtype=linear.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -131,16 +157,12 @@ def forward(self, inputs: Tensor) -> Tensor: ) extra = batch_size % self.num_estimators - r_group = torch.repeat_interleave(self.r_group, examples_per_estimator, dim=0) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = torch.repeat_interleave(self.s_group, examples_per_estimator, dim=0) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) + r_group = repeat(self.r_group, "m h -> (m b) h", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = torch.repeat_interleave( - self.bias, - examples_per_estimator, - dim=0, - ) + bias = repeat(self.bias, "m h -> (m b) h", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None @@ -288,7 +310,7 @@ def __init__( Examples: >>> # With square kernels, four estimators and equal stride - >>> m = Conv2dBE(3, 32, 3, 4, stride=1) + >>> m = BatchConv2d(3, 32, 3, 4, stride=1) >>> input = torch.randn(8, 3, 16, 16) >>> output = m(input) >>> print(output.size()) @@ -325,6 +347,38 @@ def __init__( self.reset_parameters() + @classmethod + def from_conv2d(cls, conv2d: nn.Conv2d, num_estimators: int) -> "BatchConv2d": + r"""Create a BatchEnsemble-style Conv2d layer from an existing Conv2d layer. + + Args: + conv2d (nn.Conv2d): The Conv2d layer to convert. + num_estimators (int): Number of ensemble members. + + Returns: + BatchConv2d: The converted BatchEnsemble-style Conv2d layer. + + Warning: + All parameters of the original Conv2d layer will be discarded. 
+ + Example: + >>> conv2d = nn.Conv2d(3, 32, kernel_size=3) + >>> be_conv2d = BatchConv2d.from_conv2d(conv2d, num_estimators=3) + """ + return cls( + in_channels=conv2d.in_channels, + out_channels=conv2d.out_channels, + kernel_size=conv2d.kernel_size, + stride=conv2d.stride, + padding=conv2d.padding, + dilation=conv2d.dilation, + groups=conv2d.groups, + bias=conv2d.bias is not None, + num_estimators=num_estimators, + device=conv2d.weight.device, + dtype=conv2d.weight.dtype, + ) + def reset_parameters(self) -> None: nn.init.normal_(self.r_group, mean=1.0, std=0.5) nn.init.normal_(self.s_group, mean=1.0, std=0.5) @@ -338,50 +392,13 @@ def forward(self, inputs: Tensor) -> Tensor: examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators - r_group = ( - torch.repeat_interleave( - self.r_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.r_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - r_group = torch.cat([r_group, r_group[:extra]], dim=0) # .unsqueeze(-1).unsqueeze(-1) - s_group = ( - torch.repeat_interleave( - self.s_group, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.s_group.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - s_group = torch.cat([s_group, s_group[:extra]], dim=0) # + r_group = repeat(self.r_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + r_group = torch.cat([r_group, r_group[:extra]], dim=0) + s_group = repeat(self.s_group, "m h -> (m b) h 1 1", b=examples_per_estimator) + s_group = torch.cat([s_group, s_group[:extra]], dim=0) if self.bias is not None: - bias = ( - torch.repeat_interleave( - self.bias, - torch.full( - [self.num_estimators], - examples_per_estimator, - device=self.bias.device, - ), - dim=0, - ) - .unsqueeze(-1) - .unsqueeze(-1) - ) - + bias = repeat(self.bias, "m h -> (m b) h 1 1", b=examples_per_estimator) bias = torch.cat([bias, bias[:extra]], dim=0) else: bias = None From 9d4b4a49f1dbed6130b75b393f81b0e94a5265ad Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:04:02 +0100 Subject: [PATCH 53/68] :hammer: `BatchEnsemble` wrapper overhaul - now the input is repeated during training depending on `repeat_training_inputs` argument - with `convert_layers==True` all `nn.Linear` and `nn.Conv2d` layers are replaced by `BatchLinear` and `BatchConv2d` --- .../mnist/configs/lenet_batch_ensemble.yaml | 1 + tests/layers/test_batch.py | 22 +++ tests/models/wrappers/test_batch_ensemble.py | 75 ++++++++-- torch_uncertainty/models/__init__.py | 2 + torch_uncertainty/models/lenet.py | 17 +-- torch_uncertainty/models/wrappers/__init__.py | 1 + .../models/wrappers/batch_ensemble.py | 139 ++++++++++++++++-- 7 files changed, 224 insertions(+), 33 deletions(-) diff --git a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml index 0d2a94ee..d385b100 100644 --- a/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml +++ b/experiments/classification/mnist/configs/lenet_batch_ensemble.yaml @@ -40,6 +40,7 @@ model: norm: torch.nn.BatchNorm2d groups: 1 dropout_rate: 0 + repeat_training_inputs: true num_classes: 10 loss: CrossEntropyLoss is_ensemble: true diff --git a/tests/layers/test_batch.py b/tests/layers/test_batch.py index bb7ca6e9..75e54485 100644 --- a/tests/layers/test_batch.py +++ b/tests/layers/test_batch.py @@ -33,6 +33,17 @@ def test_linear_one_estimator_no_bias(self, feat_input: torch.Tensor): out = 
layer(feat_input) assert out.shape == torch.Size([4, 2]) + def test_convert_from_linear(self, feat_input: torch.Tensor): + linear = torch.nn.Linear(6, 3) + layer = BatchLinear.from_linear(linear, num_estimators=2) + assert layer.linear.weight.shape == torch.Size([3, 6]) + assert layer.linear.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(feat_input) + assert out.shape == torch.Size([4, 3]) + class TestBatchConv2d: """Testing the BatchConv2d layer class.""" @@ -47,3 +58,14 @@ def test_conv_two_estimators(self, img_input: torch.Tensor): layer = BatchConv2d(6, 2, num_estimators=2, kernel_size=1) out = layer(img_input) assert out.shape == torch.Size([5, 2, 3, 3]) + + def test_convert_from_conv2d(self, img_input: torch.Tensor): + conv = torch.nn.Conv2d(6, 3, 1) + layer = BatchConv2d.from_conv2d(conv, num_estimators=2) + assert layer.conv.weight.shape == torch.Size([3, 6, 1, 1]) + assert layer.conv.bias is None + assert layer.r_group.shape == torch.Size([2, 6]) + assert layer.s_group.shape == torch.Size([2, 3]) + assert layer.bias.shape == torch.Size([2, 3]) + out = layer(img_input) + assert out.shape == torch.Size([5, 3, 3, 3]) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 8230819e..8c1b675e 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -1,31 +1,82 @@ +import pytest import torch from torch import nn +from torch_uncertainty.layers import BatchConv2d, BatchLinear from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble +@pytest.fixture() +def img_input() -> torch.Tensor: + return torch.rand((5, 6, 3, 3)) + + # Define a simple model for testing wrapper functionality (disregarding the actual BatchEnsemble architecture) class _DummyModel(nn.Module): def __init__(self, in_features, out_features): super().__init__() - self.fc = nn.Linear(in_features, out_features) - self.r_group = nn.Parameter(torch.randn(in_features)) - self.s_group = nn.Parameter(torch.randn(out_features)) - self.bias = nn.Parameter(torch.randn(out_features)) + self.conv = nn.Conv2d(in_features, out_features, 3) + self.fc = nn.Linear(out_features, out_features) def forward(self, x): + x = self.conv(x) + x = x.flatten(1) + return self.fc(x) + + +class _DummyBEModel(nn.Module): + def __init__(self, in_features, out_features, num_estimators): + super().__init__() + self.conv = BatchConv2d(in_features, out_features, 3, num_estimators) + self.fc = BatchLinear(out_features, out_features, num_estimators=num_estimators) + + def forward(self, x): + x = self.conv(x) + x = x.flatten(1) return self.fc(x) class TestBatchEnsembleModel: - def test_forward_pass(self): - in_features = 10 - out_features = 5 + def test_convert_layers(self): + in_features = 6 + out_features = 4 num_estimators = 3 + model = _DummyModel(in_features, out_features) - wrapped_model = BatchEnsemble(model, num_estimators) + wrapped_model = BatchEnsemble(model, num_estimators, convert_layers=True) + assert wrapped_model.num_estimators == num_estimators + assert isinstance(wrapped_model.model.conv, BatchConv2d) + assert isinstance(wrapped_model.model.fc, BatchLinear) + + def test_forward_pass(self, img_input): + batch_size = img_input.size(0) + in_features = img_input.size(1) + out_features = 4 + num_estimators = 3 + model = _DummyBEModel(in_features, out_features, num_estimators) + # with 
repeat_training_inputs=False + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=False) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (img_input.size(0), out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # with repeat_training_inputs=True + wrapped_model = BatchEnsemble(model, num_estimators, repeat_training_inputs=True) + # test forward pass for training + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) + # test forward pass for evaluation + wrapped_model.eval() + logits = wrapped_model(img_input) + assert logits.shape == (batch_size * num_estimators, out_features) - # Test forward pass - x = torch.randn(2, in_features) # Batch size of 2 - logits = wrapped_model(x) - assert logits.shape == (2 * num_estimators, out_features) + def test_errors(self): + with pytest.raises(ValueError): + BatchEnsemble(_DummyBEModel(10, 5, 1), 0) + with pytest.raises(ValueError): + BatchEnsemble(_DummyModel(10, 5), 1) + with pytest.raises(ValueError): + BatchEnsemble(nn.Identity(), 2, convert_layers=True) diff --git a/torch_uncertainty/models/__init__.py b/torch_uncertainty/models/__init__.py index 4b8964bc..a3faf1f4 100644 --- a/torch_uncertainty/models/__init__.py +++ b/torch_uncertainty/models/__init__.py @@ -5,9 +5,11 @@ STEP_UPDATE_MODEL, SWA, SWAG, + BatchEnsemble, CheckpointEnsemble, MCDropout, StochasticModel, + batch_ensemble, deep_ensembles, mc_dropout, ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index aca916ee..b31ba133 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch import nn -from torch_uncertainty.layers.batch_ensemble import BatchConv2d, BatchLinear from torch_uncertainty.layers.bayesian import BayesConv2d, BayesLinear from torch_uncertainty.layers.mc_batch_norm import MCBatchNorm2d from torch_uncertainty.layers.packed import PackedConv2d, PackedLinear @@ -129,22 +128,22 @@ def batchensemble_lenet( norm: type[nn.Module] = nn.BatchNorm2d, groups: int = 1, dropout_rate: float = 0.0, + repeat_training_inputs: bool = False, ) -> _LeNet: - model = _lenet( - stochastic=False, + model = lenet( in_channels=in_channels, num_classes=num_classes, - linear_layer=BatchLinear, - conv2d_layer=BatchConv2d, - layer_args={ - "num_estimators": num_estimators, - }, activation=activation, norm=norm, groups=groups, dropout_rate=dropout_rate, ) - return BatchEnsemble(model, num_estimators) + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=True, + ) def packed_lenet( diff --git a/torch_uncertainty/models/wrappers/__init__.py b/torch_uncertainty/models/wrappers/__init__.py index 75f37e66..fb4ff50c 100644 --- a/torch_uncertainty/models/wrappers/__init__.py +++ b/torch_uncertainty/models/wrappers/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .batch_ensemble import BatchEnsemble, batch_ensemble from .checkpoint_ensemble import ( CheckpointEnsemble, ) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index e306dcb7..9d6213f1 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -1,6 +1,9 @@ import 
torch +from einops import repeat from torch import nn +from torch_uncertainty.layers import BatchConv2d, BatchLinear + class BatchEnsemble(nn.Module): """Wrap a BatchEnsemble model to ensure correct batch replication. @@ -12,28 +15,140 @@ class BatchEnsemble(nn.Module): This wrapper automatically **duplicates the input batch** along the first axis, ensuring that each estimator receives the correct data format. - **Usage Example:** - ```python - model = lenet(in_channels=1, num_classes=10) - wrapped_model = BatchEnsembleWrapper(model, num_estimators=5) - logits = wrapped_model(x) # `x` is automatically repeated `num_estimators` times - ``` - Args: model (nn.Module): The BatchEnsemble model. num_estimators (int): Number of ensemble members. + repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. + If `True`, the input batch is repeated during both training and evaluation. If `False`, + the input batch is repeated only during evaluation. Default is `False`. + convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. + If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their + BatchEnsemble counterparts. Default is `False`. + + Raises: + ValueError: If neither `BatchLinear` nor `BatchConv2d` layers are found in the model at the + end of initialization. + ValueError: If `num_estimators` is less than or equal to 0. + ValueError: If `convert_layers=True` and neither `nn.Linear` nor `nn.Conv2d` layers are + found in the model. + + Warning: + If `convert_layers==True`, the wrapper will attempt to convert all `nn.Linear` and `nn.Conv2d` + layers in the model to their BatchEnsemble counterparts. If the model contains other types of + layers, the conversion won't happen for these layers. If don't have any `nn.Linear` or `nn.Conv2d` + layers in the model, the wrapper will raise an error during conversion. - Note: - This wrapper assumes that the model uses **BatchEnsemble layers** (see `torchensemble.layers.batch_ensemble`). + Warning: + If `repeat_training_inputs==True` and you want to use one of the `torch_uncertainty.routines` + for training, be sure to set `format_batch_fn=RepeatTarget(num_repeats=num_estimators)` when + initializing the routine. + + Example: + >>> model = nn.Sequential( + ... nn.Linear(10, 5), + ... nn.ReLU(), + ... nn.Linear(5, 2) + ... ) + >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) + >>> model + BatchEnsemble( + (model): Sequential( + (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) + (1): ReLU() + (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) + ) + ) """ - def __init__(self, model: nn.Module, num_estimators: int): + def __init__( + self, + model: nn.Module, + num_estimators: int, + repeat_training_inputs: bool = False, + convert_layers: bool = False, + ) -> None: super().__init__() self.model = model self.num_estimators = num_estimators + self.repeat_training_inputs = repeat_training_inputs + + if convert_layers: + self._convert_layers() + + filtered_modules = [ + module + for module in self.model.modules() + if isinstance(module, BatchLinear | BatchConv2d) + ] + _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: """Repeats the input batch and passes it through the model.""" - repeat_shape = [self.num_estimators] + [1] * (x.dim() - 1) - x = x.repeat(repeat_shape) + if not self.training or self.repeat_training_inputs: + x = repeat(x, "b ... 
-> (m b) ...", m=self.num_estimators)
         return self.model(x)
+
+    def _convert_layers(self) -> None:
+        """Converts the model's layers to BatchEnsemble layers."""
+        no_valid_layers = True
+        for name, layer in self.model.named_modules():
+            if isinstance(layer, nn.Linear):
+                setattr(
+                    self.model,
+                    name,
+                    BatchLinear.from_linear(layer, num_estimators=self.num_estimators),
+                )
+                no_valid_layers = False
+            elif isinstance(layer, nn.Conv2d):
+                setattr(
+                    self.model,
+                    name,
+                    BatchConv2d.from_conv2d(layer, num_estimators=self.num_estimators),
+                )
+                no_valid_layers = False
+        if no_valid_layers:
+            raise ValueError(
+                "No valid layers found in the model. "
+                "Please use `nn.Linear` or `nn.Conv2d` layers to apply BatchEnsemble."
+            )
+
+
+def _batch_ensemble_checks(filtered_modules, num_estimators):
+    """Check that the model contains BatchEnsemble layers and that `num_estimators` is valid."""
+    if len(filtered_modules) == 0:
+        raise ValueError(
+            "No BatchEnsemble layers found in the model. "
+            "Please use `BatchLinear` or `BatchConv2d` layers in your model "
+            "or set `convert_layers=True` when initializing the wrapper."
+        )
+    if num_estimators <= 0:
+        raise ValueError("`num_estimators` must be greater than 0.")
+
+
+def batch_ensemble(
+    model: nn.Module,
+    num_estimators: int,
+    repeat_training_inputs: bool = False,
+    convert_layers: bool = False,
+) -> BatchEnsemble:
+    """BatchEnsemble wrapper for a model.
+
+    Args:
+        model (nn.Module): model to wrap
+        num_estimators (int): number of ensemble members
+        repeat_training_inputs (bool, optional): whether to repeat the input batch during training.
+            If `True`, the input batch is repeated during both training and evaluation. If `False`,
+            the input batch is repeated only during evaluation. Default is `False`.
+        convert_layers (bool, optional): whether to convert the model's layers to BatchEnsemble layers.
+            If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their
+            BatchEnsemble counterparts. Default is `False`.
+ + Returns: + BatchEnsemble: BatchEnsemble wrapper for the model + """ + return BatchEnsemble( + model=model, + num_estimators=num_estimators, + repeat_training_inputs=repeat_training_inputs, + convert_layers=convert_layers, + ) From 8b29026dd21bbfeac5efc3403ca715e72e610851 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:06:44 +0100 Subject: [PATCH 54/68] :books: Add BatchEnsemble wrapper utilities in the API Reference --- docs/source/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 16f1f1c2..0165ad49 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -201,6 +201,7 @@ Functions :toctree: generated/ :nosignatures: + batch_ensemble deep_ensembles mc_dropout @@ -212,6 +213,7 @@ Classes :nosignatures: :template: class.rst + BatchEnsemble CheckpointEnsemble EMA MCDropout From f2ae628ba8481910291368917f258edd29c63fdb Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:12:42 +0100 Subject: [PATCH 55/68] :book: `BatchEnsemble` docstring update --- torch_uncertainty/models/wrappers/batch_ensemble.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 9d6213f1..d67230e5 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -83,7 +83,9 @@ def __init__( _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Repeats the input batch and passes it through the model.""" + """Repeat the input if `self.training == False` or `repeat_training_inputs==True` and pass + it through the model. + """ if not self.training or self.repeat_training_inputs: x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators) return self.model(x) From c737d07b24243191cfabccdba0b0ed41eee4043f Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 17:17:22 +0100 Subject: [PATCH 56/68] :book: `BatchEnsemble` docstring update --- .../models/wrappers/batch_ensemble.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index d67230e5..1e9bc46f 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -26,21 +26,21 @@ class BatchEnsemble(nn.Module): BatchEnsemble counterparts. Default is `False`. Raises: - ValueError: If neither `BatchLinear` nor `BatchConv2d` layers are found in the model at the + ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the end of initialization. - ValueError: If `num_estimators` is less than or equal to 0. - ValueError: If `convert_layers=True` and neither `nn.Linear` nor `nn.Conv2d` layers are + ValueError: If ``num_estimators`` is less than or equal to ``0``. + ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are found in the model. Warning: - If `convert_layers==True`, the wrapper will attempt to convert all `nn.Linear` and `nn.Conv2d` + If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` layers in the model to their BatchEnsemble counterparts. If the model contains other types of - layers, the conversion won't happen for these layers. 
If don't have any `nn.Linear` or `nn.Conv2d` + layers, the conversion won't happen for these layers. If don't have any ``nn.Linear`` or ``nn.Conv2d`` layers in the model, the wrapper will raise an error during conversion. Warning: - If `repeat_training_inputs==True` and you want to use one of the `torch_uncertainty.routines` - for training, be sure to set `format_batch_fn=RepeatTarget(num_repeats=num_estimators)` when + If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` + for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when initializing the routine. Example: @@ -83,8 +83,8 @@ def __init__( _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Repeat the input if `self.training == False` or `repeat_training_inputs==True` and pass - it through the model. + """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and + pass it through the model. """ if not self.training or self.repeat_training_inputs: x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators) From 0292306645badf64ed7816f4791ef67ead483ffd Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 10 Mar 2025 18:04:50 +0100 Subject: [PATCH 57/68] :wrench: Bump version --- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2fe7b638..2c9d46a0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.4.1" +release = "0.4.2" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 8303472e..567e152c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.4.1" +version = "0.4.2" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, From 8b81c132159e2d7dfce8c3542dcd5992d42c1e1e Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 18:02:26 +0100 Subject: [PATCH 58/68] :wrench: Add doc trigger when ready for review --- .github/workflows/build-docs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index f5c1236d..a0e5aa3e 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -4,6 +4,7 @@ on: branches: - main pull_request: + types: [opened, reopened, edited, ready_for_review] branches: - main schedule: From d3fa198f022f4a76dc255cff2963effc852bc524 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 18:05:56 +0100 Subject: [PATCH 59/68] :bug: Fix triggers --- .github/workflows/build-docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index a0e5aa3e..e1b5845e 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -4,7 +4,7 @@ on: branches: - main pull_request: - types: [opened, reopened, edited, ready_for_review] + types: [opened, reopened, ready_for_review, synchronize] branches: - main schedule: From a7127d4ea6d5742926739ca8838c8836f41ff910 Mon 
Sep 17 00:00:00 2001 From: Olivier Date: Mon, 10 Mar 2025 18:07:11 +0100 Subject: [PATCH 60/68] :bug: Fix pytest triggers --- .github/workflows/run-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 9b23e989..8435ac79 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -7,7 +7,6 @@ on: - dev pull_request: branches: - - main - dev schedule: - cron: "42 7 * * 0" From 30c0f8420207ded64691c8b2e861137b6c3f7e82 Mon Sep 17 00:00:00 2001 From: alafage Date: Tue, 11 Mar 2025 00:42:00 +0100 Subject: [PATCH 61/68] :shirt: Update according PR comments Co-authored-by: o-laurent --- .gitignore | 2 +- .../models/wrappers/batch_ensemble.py | 121 +++++++++--------- torch_uncertainty/routines/classification.py | 3 +- torch_uncertainty/transforms/image.py | 56 ++++---- 4 files changed, 90 insertions(+), 92 deletions(-) diff --git a/.gitignore b/.gitignore index c8032c5e..65bc7436 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # Custom .vscode/ -.itea/ +.idea/ data/ logs/ lightning_logs/ diff --git a/torch_uncertainty/models/wrappers/batch_ensemble.py b/torch_uncertainty/models/wrappers/batch_ensemble.py index 1e9bc46f..99132fec 100644 --- a/torch_uncertainty/models/wrappers/batch_ensemble.py +++ b/torch_uncertainty/models/wrappers/batch_ensemble.py @@ -6,60 +6,6 @@ class BatchEnsemble(nn.Module): - """Wrap a BatchEnsemble model to ensure correct batch replication. - - In a BatchEnsemble architecture, each estimator operates on a **sub-batch** - of the input. This means that the input batch must be **repeated** - :attr:`num_estimators` times before being processed. - - This wrapper automatically **duplicates the input batch** along the first axis, - ensuring that each estimator receives the correct data format. - - Args: - model (nn.Module): The BatchEnsemble model. - num_estimators (int): Number of ensemble members. - repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. - If `True`, the input batch is repeated during both training and evaluation. If `False`, - the input batch is repeated only during evaluation. Default is `False`. - convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. - If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their - BatchEnsemble counterparts. Default is `False`. - - Raises: - ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the - end of initialization. - ValueError: If ``num_estimators`` is less than or equal to ``0``. - ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are - found in the model. - - Warning: - If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` - layers in the model to their BatchEnsemble counterparts. If the model contains other types of - layers, the conversion won't happen for these layers. If don't have any ``nn.Linear`` or ``nn.Conv2d`` - layers in the model, the wrapper will raise an error during conversion. - - Warning: - If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` - for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when - initializing the routine. - - Example: - >>> model = nn.Sequential( - ... nn.Linear(10, 5), - ... nn.ReLU(), - ... nn.Linear(5, 2) - ... 
) - >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) - >>> model - BatchEnsemble( - (model): Sequential( - (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) - (1): ReLU() - (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) - ) - ) - """ - def __init__( self, model: nn.Module, @@ -67,6 +13,59 @@ def __init__( repeat_training_inputs: bool = False, convert_layers: bool = False, ) -> None: + """Wrap a BatchEnsemble model to ensure correct batch replication. + + In a BatchEnsemble architecture, each estimator operates on a **sub-batch** + of the input. This means that the input batch must be **repeated** + :attr:`num_estimators` times before being processed. + + This wrapper automatically **duplicates the input batch** along the first axis, + ensuring that each estimator receives the correct data format. + + Args: + model (nn.Module): The BatchEnsemble model. + num_estimators (int): Number of ensemble members. + repeat_training_inputs (optional, bool): Whether to repeat the input batch during training. + If ``True``, the input batch is repeated during both training and evaluation. If ``False``, + the input batch is repeated only during evaluation. Default is ``False``. + convert_layers (optional, bool): Whether to convert the model's layers to BatchEnsemble layers. + If ``True``, the wrapper will convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their + BatchEnsemble counterparts. Default is ``False``. + + Raises: + ValueError: If neither ``BatchLinear`` nor ``BatchConv2d`` layers are found in the model at the + end of initialization. + ValueError: If ``num_estimators`` is less than or equal to ``0``. + ValueError: If ``convert_layers=True`` and neither ``nn.Linear`` nor ``nn.Conv2d`` layers are + found in the model. + + Warning: + If ``convert_layers==True``, the wrapper will attempt to convert all ``nn.Linear`` and ``nn.Conv2d`` + layers in the model to their BatchEnsemble counterparts. If the model contains other types of + layers, the conversion won't happen for these layers. If don't have any ``nn.Linear`` or ``nn.Conv2d`` + layers in the model, the wrapper will raise an error during conversion. + + Warning: + If ``repeat_training_inputs==True`` and you want to use one of the ``torch_uncertainty.routines`` + for training, be sure to set ``format_batch_fn=RepeatTarget(num_repeats=num_estimators)`` when + initializing the routine. + + Example: + >>> model = nn.Sequential( + ... nn.Linear(10, 5), + ... nn.ReLU(), + ... nn.Linear(5, 2) + ... ) + >>> model = BatchEnsemble(model, num_estimators=4, convert_layers=True) + >>> model + BatchEnsemble( + (model): Sequential( + (0): BatchLinear(in_features=10, out_features=5, num_estimators=4) + (1): ReLU() + (2): BatchLinear(in_features=5, out_features=2, num_estimators=4) + ) + ) + """ super().__init__() self.model = model self.num_estimators = num_estimators @@ -83,15 +82,13 @@ def __init__( _batch_ensemble_checks(filtered_modules, num_estimators) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and - pass it through the model. - """ + """Repeat the input if ``self.training==False`` or ``repeat_training_inputs==True`` and pass it through the model.""" if not self.training or self.repeat_training_inputs: x = repeat(x, "b ... 
-> (m b) ...", m=self.num_estimators) return self.model(x) def _convert_layers(self) -> None: - """Converts the model's layers to BatchEnsemble layers.""" + """Convert the model's layers to BatchEnsemble layers.""" no_valid_layers = True for name, layer in self.model.named_modules(): if isinstance(layer, nn.Linear): @@ -139,11 +136,11 @@ def batch_ensemble( model (nn.Module): model to wrap num_estimators (int): number of ensemble members repeat_training_inputs (bool, optional): whether to repeat the input batch during training. - If `True`, the input batch is repeated during both training and evaluation. If `False`, - the input batch is repeated only during evaluation. Default is `False`. + If ``True``, the input batch is repeated during both training and evaluation. If ``False``, + the input batch is repeated only during evaluation. Default is ``False``. convert_layers (bool, optional): whether to convert the model's layers to BatchEnsemble layers. - If `True`, the wrapper will convert all `nn.Linear` and `nn.Conv2d` layers to their - BatchEnsemble counterparts. Default is `False`. + If ``True``, the wrapper will convert all ``nn.Linear`` and ``nn.Conv2d`` layers to their + BatchEnsemble counterparts. Default is ``False``. Returns: BatchEnsemble: BatchEnsemble wrapper for the model diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 29f5888f..e5469b5c 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -127,7 +127,8 @@ def __init__( When using an ensemble model, you must: 1. Set :attr:`is_ensemble` to ``True``. 2. Set :attr:`format_batch_fn` to :class:`torch_uncertainty.transforms.RepeatTarget(num_repeats=num_estimators)`. - 3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes. + 3. Ensure that the model's forward pass outputs a tensor of shape :math:`(M \times B, C)`, + where :math:`M` is the number of estimators, :math:`B` is the batch size, :math:`C` is the number of classes. For automated batch handling, consider using the available model wrappers in `torch_uncertainty.models.wrappers`. diff --git a/torch_uncertainty/transforms/image.py b/torch_uncertainty/transforms/image.py index 56941b9f..cb28dd6e 100644 --- a/torch_uncertainty/transforms/image.py +++ b/torch_uncertainty/transforms/image.py @@ -256,9 +256,9 @@ def __init__( .. code-block:: python - scale = uniform_sample(min_scale, max_scale) - output_width = input_width * scale - output_height = input_height * scale + scale = uniform_sample(min_scale, max_scale) + output_width = input_width * scale + output_height = input_height * scale If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`, :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.) @@ -266,31 +266,31 @@ def __init__( the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape. Args: - min_scale (int): Minimum scale for random sampling - max_scale (int): Maximum scale for random sampling - interpolation (InterpolationMode, optional): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 
-                If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
-                ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
-                The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
-            antialias (bool, optional): Whether to apply antialiasing.
-                It only affects **tensors** with bilinear or bicubic modes and it is
-                ignored otherwise: on PIL images, antialiasing is always applied on
-                bilinear or bicubic modes; on other modes (for PIL images and
-                tensors), antialiasing makes no sense and this parameter is ignored.
-                Possible values are:
-
-                - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
-                Other mode aren't affected. This is probably what you want to use.
-                - ``False``: will not apply antialiasing for tensors on any mode. PIL
-                images are still antialiased on bilinear or bicubic modes, because
-                PIL doesn't support no antialias.
-                - ``None``: equivalent to ``False`` for tensors and ``True`` for
-                PIL images. This value exists for legacy reasons and you probably
-                don't want to use it unless you really know what you are doing.
-
-                The default value changed from ``None`` to ``True`` in
-                v0.17, for the PIL and Tensor backends to be consistent.
+            min_scale (int): Minimum scale for random sampling
+            max_scale (int): Maximum scale for random sampling
+            interpolation (InterpolationMode, optional): Desired interpolation enum defined by
+                :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+                If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
+                ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
+                The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
+            antialias (bool, optional): Whether to apply antialiasing.
+                It only affects **tensors** with bilinear or bicubic modes and it is
+                ignored otherwise: on PIL images, antialiasing is always applied on
+                bilinear or bicubic modes; on other modes (for PIL images and
+                tensors), antialiasing makes no sense and this parameter is ignored.
+                Possible values are:
+
+                - ``True`` (default): will apply antialiasing for bilinear or bicubic modes.
+                  Other modes aren't affected. This is probably what you want to use.
+                - ``False``: will not apply antialiasing for tensors on any mode. PIL
+                  images are still antialiased on bilinear or bicubic modes, because
+                  PIL doesn't support disabling antialiasing.
+                - ``None``: equivalent to ``False`` for tensors and ``True`` for
+                  PIL images. This value exists for legacy reasons and you probably
+                  don't want to use it unless you really know what you are doing.
+
+                The default value changed from ``None`` to ``True`` in
+                v0.17, for the PIL and Tensor backends to be consistent.
         """
         super().__init__()
         self.min_scale = min_scale

From 7572ca07513cd7897b3e86f330d05d921fb642d8 Mon Sep 17 00:00:00 2001
From: Olivier
Date: Tue, 11 Mar 2025 10:40:13 +0100
Subject: [PATCH 62/68] :ok_hand: Comply with cosmetic comments

---
 docker/DOCKER.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docker/DOCKER.md b/docker/DOCKER.md
index 1b53beb4..35adbb81 100644
--- a/docker/DOCKER.md
+++ b/docker/DOCKER.md
@@ -1,4 +1,5 @@
 # :whale: Docker image for contributors
+
 ### Pre-built Docker image
 1. To pull the pre-built image from Docker Hub, simply run:
    ```bash
@@ -47,6 +48,7 @@ If using a cloud provider, ensure your network volume is correctly attached
 to avoid losing data.
 
### Modifying and publishing custom Docker image + If you want to make changes to the Dockerfile, follow these steps: 1. Edit the Dockerfile to fit your needs. From 618f7db4666b817acf3fe55dbb569e09c749f548 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:04:52 +0100 Subject: [PATCH 63/68] :white_check_mark: Finish coverage --- tests/models/wrappers/test_batch_ensemble.py | 4 ++-- tests/post_processing/test_scalers.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/models/wrappers/test_batch_ensemble.py b/tests/models/wrappers/test_batch_ensemble.py index 8c1b675e..3ec42082 100644 --- a/tests/models/wrappers/test_batch_ensemble.py +++ b/tests/models/wrappers/test_batch_ensemble.py @@ -3,7 +3,7 @@ from torch import nn from torch_uncertainty.layers import BatchConv2d, BatchLinear -from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble +from torch_uncertainty.models.wrappers.batch_ensemble import BatchEnsemble, batch_ensemble @pytest.fixture() @@ -43,7 +43,7 @@ def test_convert_layers(self): num_estimators = 3 model = _DummyModel(in_features, out_features) - wrapped_model = BatchEnsemble(model, num_estimators, convert_layers=True) + wrapped_model = batch_ensemble(model, num_estimators, convert_layers=True) assert wrapped_model.num_estimators == num_estimators assert isinstance(wrapped_model.model.conv, BatchConv2d) assert isinstance(wrapped_model.model.fc, BatchLinear) diff --git a/tests/post_processing/test_scalers.py b/tests/post_processing/test_scalers.py index eda1356e..fabe77f3 100644 --- a/tests/post_processing/test_scalers.py +++ b/tests/post_processing/test_scalers.py @@ -57,6 +57,12 @@ def test_errors(self): with pytest.raises(ValueError): scaler.set_temperature(val=-1) + scaler = TemperatureScaler( + model=None, + ) + with pytest.raises(ValueError, match="Cannot fit a Scaler method without model."): + scaler.fit(None) + class TestVectorScaler: """Testing the VectorScaler class.""" From 1c5c7d87c0eaa24262931e6771d1b81b8b696dc9 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:11:51 +0100 Subject: [PATCH 64/68] :wrench: rework optional dependency groups --- pyproject.toml | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 567e152c..90ed590d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,11 @@ dependencies = [ ] [project.optional-dependencies] +experiments = [ + "tensorboard", + "huggingface-hub", + "safetensors", +] image = [ "scikit-image", "kornia", @@ -48,16 +53,14 @@ image = [ ] tabular = ["pandas"] dev = [ - "torch_uncertainty[image]", - "huggingface-hub", - "safetensors", + "torch_uncertainty[experiments,image]", "ruff==0.9.9", "pytest-cov", "pre-commit", "pre-commit-hooks", ] docs = [ - "sphinx<6", + "sphinx<8", "tu_sphinx_theme", "sphinx-copybutton", "sphinx-gallery", @@ -65,11 +68,11 @@ docs = [ "sphinx-codeautolink", ] all = [ - "torch_uncertainty[dev,docs,image,tabular]", + "torch_uncertainty[dev,docs,image,tabular,experiments]", + "scikit-learn", "laplace-torch", - "glest==0.0.1a1", "scipy", - "tensorboard", + "glest==0.0.1a1", ] [project.urls] From f4ffbf7262379ceb9b4ab395e2e2751152d039b0 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:15:06 +0100 Subject: [PATCH 65/68] :fire: Remove sklearn and np in AUSE --- torch_uncertainty/metrics/sparsification.py | 49 +++++++-------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/torch_uncertainty/metrics/sparsification.py 
b/torch_uncertainty/metrics/sparsification.py index 3f169efb..aef72d6e 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/sparsification.py @@ -1,35 +1,25 @@ -from importlib import util - import matplotlib.pyplot as plt -import numpy as np import torch - -if util.find_spec("sklearn"): - from sklearn.metrics import auc - - sklearn_installed = True -else: # coverage: ignore - sklearn_installed = False - from torch import Tensor from torchmetrics.metric import Metric +from torchmetrics.utilities.compute import _auc_compute from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.plot import _AX_TYPE class AUSE(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False - plot_lower_bound: float = 0.0 - plot_upper_bound: float = 100.0 - plot_legend_name: str = "Sparsification Curves" + is_differentiable = False + higher_is_better = False + full_state_update = False + plot_lower_bound = 0.0 + plot_upper_bound = 100.0 + plot_legend_name = "Sparsification Curves" scores: list[Tensor] errors: list[Tensor] def __init__(self, **kwargs) -> None: - r"""The Area Under the Sparsification Error curve (AUSE) metric to estimate + r"""The Area Under the Sparsification Error curve (AUSE) metric to evaluate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors. @@ -56,13 +46,6 @@ def __init__(self, **kwargs) -> None: self.add_state("scores", default=[], dist_reduce_fx="cat") self.add_state("errors", default=[], dist_reduce_fx="cat") - if not sklearn_installed: - raise ImportError( - "Please install torch_uncertainty with the image option" - "to use the AUSE:" - """pip install -U "torch_uncertainty[image]".""" - ) - def update(self, scores: Tensor, errors: Tensor) -> None: """Store the scores and their associated errors for later computation. 
@@ -89,9 +72,11 @@ def compute(self) -> Tensor: """ error_rates, optimal_error_rates = self.partial_compute() num_samples = error_rates.size(0) - x = np.arange(1, num_samples + 1) / num_samples - y = (error_rates - optimal_error_rates).numpy() - return torch.tensor([auc(x, y)]) + if num_samples < 2: + return torch.tensor([float("nan")], device=self.device) + x = torch.arange(1, num_samples + 1, device=self.device) / num_samples + y = error_rates - optimal_error_rates + return torch.tensor([_auc_compute(x, y)]) def plot( self, @@ -118,12 +103,12 @@ def plot( # Computation of AUSEC error_rates, optimal_error_rates = self.partial_compute() num_samples = error_rates.size(0) - x = np.arange(num_samples) / num_samples - y = (error_rates - optimal_error_rates).numpy() + x = torch.arange(num_samples) / num_samples + y = error_rates - optimal_error_rates - ausec = auc(x, y) + ausec = _auc_compute(x, y).cpu().item() - rejection_rates = (np.arange(num_samples) / num_samples) * 100 + rejection_rates = torch.arange(num_samples) / num_samples * 100 ax.plot( rejection_rates, From 42a8c8392435f44995bd764618fb23b27b2631f3 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:15:52 +0100 Subject: [PATCH 66/68] :fire: Remove overriding of cst types --- .../adaptive_calibration_error.py | 6 ++--- .../metrics/classification/brier_score.py | 6 ++--- .../metrics/classification/disagreement.py | 6 ++--- .../metrics/classification/entropy.py | 6 ++--- .../metrics/classification/fpr.py | 6 ++--- .../metrics/classification/grouping_loss.py | 6 ++--- .../metrics/classification/mean_iou.py | 6 ++--- .../classification/mutual_information.py | 6 ++--- .../metrics/classification/risk_coverage.py | 22 +++++++++---------- .../metrics/classification/variation_ratio.py | 6 ++--- 10 files changed, 38 insertions(+), 38 deletions(-) diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py index 3bcc47fe..667a0343 100644 --- a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py +++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py @@ -92,9 +92,9 @@ def _ace_compute( class BinaryAdaptiveCalibrationError(Metric): r"""Adaptive Top-label Calibration Error for binary tasks.""" - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False confidences: list[Tensor] accuracies: list[Tensor] diff --git a/torch_uncertainty/metrics/classification/brier_score.py b/torch_uncertainty/metrics/classification/brier_score.py index d474c6aa..24045a60 100644 --- a/torch_uncertainty/metrics/classification/brier_score.py +++ b/torch_uncertainty/metrics/classification/brier_score.py @@ -8,9 +8,9 @@ class BrierScore(Metric): - is_differentiable: bool = True - higher_is_better: bool | None = False - full_state_update: bool = False + is_differentiable = True + higher_is_better = False + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/disagreement.py b/torch_uncertainty/metrics/classification/disagreement.py index 6dd74215..843b7ff5 100644 --- a/torch_uncertainty/metrics/classification/disagreement.py +++ b/torch_uncertainty/metrics/classification/disagreement.py @@ -8,9 +8,9 @@ class Disagreement(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = 
False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/entropy.py b/torch_uncertainty/metrics/classification/entropy.py index ef5df817..400e3968 100644 --- a/torch_uncertainty/metrics/classification/entropy.py +++ b/torch_uncertainty/metrics/classification/entropy.py @@ -6,9 +6,9 @@ class Entropy(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/fpr.py b/torch_uncertainty/metrics/classification/fpr.py index d0779b44..ceaa5f73 100644 --- a/torch_uncertainty/metrics/classification/fpr.py +++ b/torch_uncertainty/metrics/classification/fpr.py @@ -6,9 +6,9 @@ class FPRx(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False conf: list[Tensor] targets: list[Tensor] diff --git a/torch_uncertainty/metrics/classification/grouping_loss.py b/torch_uncertainty/metrics/classification/grouping_loss.py index 137ee264..3cef6b57 100644 --- a/torch_uncertainty/metrics/classification/grouping_loss.py +++ b/torch_uncertainty/metrics/classification/grouping_loss.py @@ -26,9 +26,9 @@ def fit(self, probs: Tensor, targets: Tensor, features: Tensor) -> "GLEstimator" class GroupingLoss(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py index 5d9ef56d..8bd3b0cc 100644 --- a/torch_uncertainty/metrics/classification/mean_iou.py +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -4,9 +4,9 @@ class MeanIntersectionOverUnion(MulticlassStatScores): - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False + is_differentiable = False + higher_is_better = True + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/mutual_information.py b/torch_uncertainty/metrics/classification/mutual_information.py index c1224eec..8e1d5691 100644 --- a/torch_uncertainty/metrics/classification/mutual_information.py +++ b/torch_uncertainty/metrics/classification/mutual_information.py @@ -6,9 +6,9 @@ class MutualInformation(Metric): - is_differentiable: bool = False - higher_is_better: bool | None = None - full_state_update: bool = False + is_differentiable = False + higher_is_better = None + full_state_update = False def __init__( self, diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 6f76ef1f..5ca4dae0 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -11,9 +11,9 @@ class AURC(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] @@ -44,8 +44,8 @@ def __init__(self, **kwargs) -> None: kwargs: Additional keyword arguments. Reference: - Geifman & El-Yaniv. 
"Selective classification for deep neural - networks." In NeurIPS, 2017. + Geifman & El-Yaniv. "Selective classification for deep neural networks." In NeurIPS, + 2017. """ super().__init__(**kwargs) self.add_state("scores", default=[], dist_reduce_fx="cat") @@ -273,9 +273,9 @@ def plot( class CovAtxRisk(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] @@ -347,9 +347,9 @@ def __init__(self, **kwargs) -> None: class RiskAtxCov(Metric): - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False + is_differentiable = False + higher_is_better = False + full_state_update = False scores: list[Tensor] errors: list[Tensor] diff --git a/torch_uncertainty/metrics/classification/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py index ff1ffd73..9a244734 100644 --- a/torch_uncertainty/metrics/classification/variation_ratio.py +++ b/torch_uncertainty/metrics/classification/variation_ratio.py @@ -8,9 +8,9 @@ class VariationRatio(Metric): - full_state_update: bool = False - is_differentiable: bool = True - higher_is_better: bool = False + full_state_update = False + is_differentiable = True + higher_is_better = False def __init__( self, From 1f4db811bd0fc43dd533dfd31c6c4c391b8fde82 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:37:01 +0100 Subject: [PATCH 67/68] :bug: Fix NaN handling in AUSE & add test --- pyproject.toml | 4 ++-- tests/metrics/test_sparsification.py | 6 ++++++ .../metrics/classification/risk_coverage.py | 3 ++- torch_uncertainty/metrics/sparsification.py | 9 ++++++--- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90ed590d..e404e640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ dev = [ "pre-commit-hooks", ] docs = [ - "sphinx<8", + "sphinx<7", "tu_sphinx_theme", "sphinx-copybutton", "sphinx-gallery", @@ -68,7 +68,7 @@ docs = [ "sphinx-codeautolink", ] all = [ - "torch_uncertainty[dev,docs,image,tabular,experiments]", + "torch_uncertainty[dev,docs,tabular]", "scikit-learn", "laplace-torch", "scipy", diff --git a/tests/metrics/test_sparsification.py b/tests/metrics/test_sparsification.py index e89df5fe..114183b0 100644 --- a/tests/metrics/test_sparsification.py +++ b/tests/metrics/test_sparsification.py @@ -33,3 +33,9 @@ def test_plot(self) -> None: assert ax.get_xlabel() == "Rejection Rate (%)" assert ax.get_ylabel() == "Error Rate (%)" plt.close(fig) + + def test_compute_nan(self) -> None: + probs = torch.Tensor([[0.1, 0.9]]) + targets = torch.Tensor([1]).long() + metric = AUSE() + assert torch.isnan(metric(probs, targets)).all() diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 5ca4dae0..e2682313 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -87,6 +87,7 @@ def compute(self) -> Tensor: num_samples = error_rates.size(0) if num_samples < 2: return torch.tensor([float("nan")], device=self.device) + # There is no error rate associated to 0 coverage: starting at 1 cov = torch.arange(1, num_samples + 1, device=self.device) / num_samples return _auc_compute(cov, error_rates) / (1 - 1 / num_samples) @@ -115,7 +116,7 @@ def plot( error_rates = 
self.partial_compute().cpu().flip(0) num_samples = error_rates.size(0) - x = torch.arange(num_samples) / num_samples + x = torch.arange(1, num_samples + 1) / num_samples aurc = _auc_compute(x, error_rates).cpu().item() # reduce plot size diff --git a/torch_uncertainty/metrics/sparsification.py b/torch_uncertainty/metrics/sparsification.py index aef72d6e..cee99884 100644 --- a/torch_uncertainty/metrics/sparsification.py +++ b/torch_uncertainty/metrics/sparsification.py @@ -59,6 +59,9 @@ def update(self, scores: Tensor, errors: Tensor) -> None: def partial_compute(self) -> tuple[Tensor, Tensor]: scores = dim_zero_cat(self.scores) errors = dim_zero_cat(self.errors) + if scores.shape[0] < 2: + nan = torch.tensor([float("nan")], device=self.device) + return nan, nan error_rates = _ause_rejection_rate_compute(scores, errors) optimal_error_rates = _ause_rejection_rate_compute(errors, errors) return error_rates.cpu(), optimal_error_rates.cpu() @@ -71,10 +74,10 @@ def compute(self) -> Tensor: Tensor: The AUSE. """ error_rates, optimal_error_rates = self.partial_compute() - num_samples = error_rates.size(0) - if num_samples < 2: + if torch.isnan(error_rates[0]).item(): return torch.tensor([float("nan")], device=self.device) - x = torch.arange(1, num_samples + 1, device=self.device) / num_samples + num_samples = error_rates.size(0) + x = torch.arange(0, num_samples, device=self.device) / num_samples y = error_rates - optimal_error_rates return torch.tensor([_auc_compute(x, y)]) From 1a07f0d58e29deb6c29bf8a8989a2c9c1a686e02 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 11 Mar 2025 11:51:49 +0100 Subject: [PATCH 68/68] :racehorse: Don't download MNISTC in tests --- tests/datamodules/classification/test_mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/datamodules/classification/test_mnist.py b/tests/datamodules/classification/test_mnist.py index 73304283..f6ab4f8e 100644 --- a/tests/datamodules/classification/test_mnist.py +++ b/tests/datamodules/classification/test_mnist.py @@ -37,6 +37,7 @@ def test_mnist_cutout(self): dm.dataset = DummyClassificationDataset dm.ood_dataset = DummyClassificationDataset + dm.shift_dataset = DummyClassificationDataset dm.setup("fit") dm.setup("test") dm.train_dataloader()
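
Taken together, patches 65 through 67 replace the `sklearn.metrics.auc` call in `AUSE` with `torchmetrics`' `_auc_compute` and make degenerate inputs yield NaN rather than an exception. A minimal usage sketch follows (illustrative only, not part of the patch series; the import path `torch_uncertainty.metrics` is an assumption based on how the tests above use `AUSE`, while the `update(scores, errors)`/`compute()` API is taken from the diffs):

```python
# Minimal sketch, assuming AUSE is exported from torch_uncertainty.metrics.
import torch

from torch_uncertainty.metrics import AUSE

metric = AUSE()
scores = torch.rand(100)  # predicted uncertainty score per sample
errors = torch.rand(100)  # observed error per sample
metric.update(scores, errors)
print(metric.compute())  # one-element tensor: area between the sparsification curves

# With fewer than two samples, compute() now returns NaN (patch 67) instead of
# failing inside sklearn.metrics.auc, which requires at least two points.
degenerate = AUSE()
degenerate.update(torch.rand(1), torch.rand(1))
print(degenerate.compute())  # tensor([nan])
```

Returning NaN keeps `compute()` total over any input size, which matches the behavior checked by `test_compute_nan` in `tests/metrics/test_sparsification.py`.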