diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c5fcfa3..7aa85a6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,40 +1,47 @@ repos: -- repo: local - hooks: - - id: end-of-file-fixer - name: end-of-file-fixer - entry: end-of-file-fixer - language: system - types: [text] - - id: check-toml - name: check-toml - entry: check-toml - language: system - types: [toml] - - id: check-yaml - name: check-yaml - language: system - entry: check-yaml - types: [yaml] - - id: check-merge-conflict - name: check-merge-conflict - entry: check-merge-conflict - language: system - - id: check-added-large-files - name: check-added-large-files - entry: check-added-large-files - language: system - - id: ruff-check - name: ruff-check - entry: ruff check - language: python - types_or: [python, pyi] - require_serial: true - args: [--force-exclude, --fix, --exit-non-zero-on-fix, --output-format, concise] - exclude: ^auto_tutorial_source/ - - id: ruff-format - name: ruff-format - entry: ruff format - language: python - types_or: [python, pyi] - exclude: ^auto_tutorials_source/ + - repo: local + hooks: + - id: end-of-file-fixer + name: end-of-file-fixer + entry: end-of-file-fixer + language: system + types: [text] + - id: check-toml + name: check-toml + entry: check-toml + language: system + types: [toml] + - id: check-yaml + name: check-yaml + language: system + entry: check-yaml + types: [yaml] + - id: check-merge-conflict + name: check-merge-conflict + entry: check-merge-conflict + language: system + - id: check-added-large-files + name: check-added-large-files + entry: check-added-large-files + language: system + - id: ruff-check + name: ruff-check + entry: ruff check + language: python + types_or: [python, pyi] + require_serial: true + args: + [ + --force-exclude, + --fix, + --exit-non-zero-on-fix, + --output-format, + concise, + ] + exclude: ^auto_tutorial_source/ + - id: ruff-format + name: ruff-format + entry: ruff format + language: python + types_or: [python, pyi] + exclude: ^auto_tutorials_source/ diff --git a/docker/DOCKER.md b/docker/DOCKER.md index 4e2de285..2eb96cc1 100644 --- a/docker/DOCKER.md +++ b/docker/DOCKER.md @@ -3,29 +3,37 @@ This Docker image is designed for users and contributors who want to run experiments with `torch-uncertainty` on remote virtual machines with GPU support. It is particularly useful for those who do not have access to a local GPU and need a pre-configured environment for development and experimentation. --- + ## How to Use The Docker Image + ### Step 1: Fork the Repository Before proceeding, ensure you have forked the `torch-uncertainty` repository to your own GitHub account. You can do this by visiting the [torch-uncertainty GitHub repository](https://github.com/ENSTA-U2IS-AI/torch-uncertainty) and clicking the **Fork** button in the top-right corner. Once forked, clone your forked repository to your local machine: + ```bash git clone git@github.com:<your-username>/torch-uncertainty.git cd torch-uncertainty ``` > ### ⚠️ IMPORTANT NOTE: Keep Your Fork Synced -> +> > **To ensure that you are working with the latest stable version and bug fixes, you must manually sync your fork with the upstream repository before building the Docker image. Failure to sync your fork may result in outdated dependencies or missing bug fixes in the Docker image.** ### Step 2: Build the Docker image locally + Build the modified image locally and push it to a Docker registry: -``` + +```bash docker build -t my-torch-uncertainty-docker:version .
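# NOTE: `my-torch-uncertainty-docker` and `my-dockerhub-user/my-torch-uncertainty-image`
# are placeholder names. `docker push` requires the local tag to match the registry
# repository name, so retag the image before pushing (a suggested intermediate step,
# not part of the original instructions):
docker tag my-torch-uncertainty-docker:version my-dockerhub-user/my-torch-uncertainty-image:version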
docker push my-dockerhub-user/my-torch-uncertainty-image:version ``` + ### Step 3: Set environment variables on your VM + Connect to your VM and set the following environment variables: + ```bash export VM_SSH_PUBLIC_KEY="$(cat ~/.ssh/id_rsa.pub)" export GITHUB_SSH_PRIVATE_KEY="$(cat ~/.ssh/id_rsa)" @@ -36,6 +44,7 @@ export USE_COMPACT_SHELL_PROMPT=true ``` Here is a brief explanation of the environment variables used in the Docker setup: + - **`VM_SSH_PUBLIC_KEY`**: The public SSH key used to authenticate with the container via SSH. - **`GITHUB_SSH_PRIVATE_KEY`**: The private SSH key used to authenticate with GitHub for cloning and pushing repositories. - **`GITHUB_USER`**: The GitHub username used to clone the repository during the first-time setup. @@ -44,8 +53,10 @@ Here is a brief explanation of the environment variables used in the Docker setup - **`USE_COMPACT_SHELL_PROMPT`** (optional): Enables a compact and colorized shell prompt inside the container if set to `"true"`. ### Step 4: Run the Docker container + First, authenticate with your Docker registry if you use a private registry. Then run the following command to start a container from the image in your Docker registry: + ```bash docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \ -e VM_SSH_PUBLIC_KEY \ @@ -58,10 +69,13 @@ docker run --rm -it --gpus all -p 8888:8888 -p 22:22 \ ``` ### Step 5: Connect to your container + Once the container is up and running, you can connect to it via SSH: + ```bash ssh -i /path/to/private_key root@<host> -p <port> ``` + Replace `<host>` and `<port>` with the host and port of your VM, and `/path/to/private_key` with the private key that corresponds to `VM_SSH_PUBLIC_KEY`. @@ -74,7 +88,7 @@ If using a cloud provider, ensure your network volume is correctly attached to a ## Remote Development This Docker setup also allows for remote development on the VM, since GitHub SSH access is set up and the whole repo is cloned to the VM from your GitHub fork. -For example, you can seamlessly connect your VS Code editor to your remote VM and run experiments, as if on your local machine but with the GPU acceleration of your VM. +For example, you can seamlessly connect your VS Code editor to your remote VM and run experiments, as if on your local machine but with the GPU acceleration of your VM. See [VS Code Remote Development](https://code.visualstudio.com/docs/remote/remote-overview) for further details.
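Once connected, a quick way to verify that the container actually sees the VM's GPU is a short PyTorch check. The snippet below is a minimal sketch; it assumes the image ships a CUDA-enabled `torch` build (which `torch-uncertainty` requires for GPU training):

```python
# Minimal GPU sanity check inside the container (assumes a CUDA-enabled torch build).
import torch

print(torch.cuda.is_available())  # True when `--gpus all` exposed the GPU correctly
print(torch.cuda.device_count())  # number of GPUs visible to the container
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # the VM's GPU model
```

If `torch.cuda.is_available()` prints `False`, check that the container was started with `--gpus all` and that the NVIDIA drivers are installed on the host VM.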
## Streamline setup with your Cloud provider of choice diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml index ee27a0b8..55bdc780 100644 --- a/experiments/depth/kitti/configs/bts.yaml +++ b/experiments/depth/kitti/configs/bts.yaml @@ -13,16 +13,16 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val/reg/SILog - mode: min - save_last: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/SILog + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: - loss: + loss: class_path: torch_uncertainty.metrics.SILog init_args: sqrt: true @@ -34,7 +34,7 @@ model: data: root: ./data batch_size: 4 - crop_size: + crop_size: - 352 - 704 eval_size: diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml index 023869de..db399bfe 100644 --- a/experiments/depth/nyu/configs/bts.yaml +++ b/experiments/depth/nyu/configs/bts.yaml @@ -13,16 +13,16 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val/reg/SILog - mode: min - save_last: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/reg/SILog + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: - loss: + loss: class_path: torch_uncertainty.metrics.SILog init_args: sqrt: true @@ -34,7 +34,7 @@ model: data: root: ./data batch_size: 8 - crop_size: + crop_size: - 416 - 544 eval_size: diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml index 658623eb..32d25086 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 13 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml index a8d3fed9..a7563f94 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/normal.yaml @@ -13,22 +13,22 @@ 
trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 13 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml index 291ae2af..ff8111fd 100644 --- a/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/boston/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 13 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml index 4d441e2e..9bdfa9f2 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml index 6bdabac0..bcc61343 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml +++ 
b/experiments/regression/uci_datasets/configs/concrete/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml index e078d3ec..05cfd417 100644 --- a/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/concrete/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml index 5af41cae..86cc7e29 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml 
b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml index abc6fd38..41bc571d 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml index e078d3ec..05cfd417 100644 --- a/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/energy-efficiency/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml index c274098b..12b7b5e6 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss 
version: std dist_family: laplace diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml index 2ecdee6d..2b2eb698 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std dist_family: normal diff --git a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml index 94a25fe0..3cd74e4a 100644 --- a/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/kin8nm/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 8 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml index ac8b58b3..76d5394e 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 
16 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml index 1286d72a..43cb1c2b 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 16 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml index efcdb7a9..f68c63a8 100644 --- a/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/naval-propulsion-plant/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 16 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml index 95b87412..80ffa4ae 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + 
logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 4 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml index ab0fd5d1..5adc86d9 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 4 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml index 11f1242e..e83ea686 100644 --- a/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/power-plant/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 4 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml index 26c98034..809c6d88 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true 
+ - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 9 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml index 1c03c5a9..b6aa744f 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 9 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml index 2fc24032..3fe2bf3b 100644 --- a/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/protein/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 9 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml index c40efc24..f5e422e3 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - 
class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 11 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml index a246d27f..e1c552b8 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 11 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml index 8dbc05d8..ee8e813d 100644 --- a/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/wine-quality-red/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 11 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml index ff3b66cf..15658cda 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/laplace.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: 
lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 6 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: laplace diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml index 4d5e154b..9f836719 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/normal.yaml @@ -13,22 +13,22 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - init_args: - probabilistic: true - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/NLL - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + init_args: + probabilistic: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/NLL + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 6 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: torch_uncertainty.losses.DistributionNLLLoss version: std distribution: normal diff --git a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml index 15fca9b0..51640c4b 100644 --- a/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml +++ b/experiments/regression/uci_datasets/configs/yacht/mlp/point_wise.yaml @@ -13,20 +13,20 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TURegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val/reg/MSE - patience: 1000 - check_finite: true + - class_path: torch_uncertainty.callbacks.TURegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/reg/MSE + patience: 1000 + check_finite: true model: output_dim: 1 in_features: 6 - hidden_dims: - - 50 + hidden_dims: + - 50 loss: MSELoss version: std data: diff --git a/experiments/segmentation/camvid/configs/deeplab.yaml b/experiments/segmentation/camvid/configs/deeplab.yaml index aad8633c..613648a3 100644 --- a/experiments/segmentation/camvid/configs/deeplab.yaml +++ b/experiments/segmentation/camvid/configs/deeplab.yaml @@ -13,10 +13,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: 
lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: num_classes: 11 loss: CrossEntropyLoss diff --git a/experiments/segmentation/camvid/configs/segformer.yaml b/experiments/segmentation/camvid/configs/segformer.yaml index b3b2bcbf..bc46c03e 100644 --- a/experiments/segmentation/camvid/configs/segformer.yaml +++ b/experiments/segmentation/camvid/configs/segformer.yaml @@ -11,10 +11,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: num_classes: 11 loss: CrossEntropyLoss @@ -28,5 +28,5 @@ optimizer: lr: 0.01 lr_scheduler: milestones: - - 30 + - 30 gamma: 0.1 diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml index 0a3c99df..7d0c2415 100644 --- a/experiments/segmentation/cityscapes/configs/deeplab.yaml +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -13,10 +13,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: num_classes: 19 loss: CrossEntropyLoss @@ -30,8 +30,8 @@ data: batch_size: 8 crop_size: 768 eval_size: - - 1024 - - 2048 + - 1024 + - 2048 num_workers: 8 optimizer: lr: 1e-2 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index 6169c9e1..2ccb3745 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -12,10 +12,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: num_classes: 19 loss: CrossEntropyLoss @@ -26,8 +26,8 @@ data: batch_size: 8 crop_size: 1024 eval_size: - - 1024 - - 2048 + - 1024 + - 2048 num_workers: 8 optimizer: lr: 6e-5 diff --git a/experiments/segmentation/muad/configs/muad/deeplab/deeplabv3+.yaml b/experiments/segmentation/muad/configs/muad/deeplab/deeplabv3+.yaml index bb60b037..bb763b84 100644 --- a/experiments/segmentation/muad/configs/muad/deeplab/deeplabv3+.yaml +++ b/experiments/segmentation/muad/configs/muad/deeplab/deeplabv3+.yaml @@ -12,10 +12,10 @@ trainer: name: deeplabv3+ default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.deep_lab_v3_resnet50 @@ -30,28 +30,28 @@ model: 
class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 data: root: ./data batch_size: 12 crop_size: 768 eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.SGD init_args: diff --git a/experiments/segmentation/muad/configs/muad/unet/batch_ensemble.yaml b/experiments/segmentation/muad/configs/muad/unet/batch_ensemble.yaml index 66969dc8..af5d80e6 100644 --- a/experiments/segmentation/muad/configs/muad/unet/batch_ensemble.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/batch_ensemble.yaml @@ -56,6 +56,7 @@ data: root: ./data batch_size: 16 crop_size: 256 + eval_ood: true eval_size: - 512 - 1024 diff --git a/experiments/segmentation/muad/configs/muad/unet/deep_ensembles.yaml b/experiments/segmentation/muad/configs/muad/unet/deep_ensembles.yaml index f5027850..c56a26cc 100644 --- a/experiments/segmentation/muad/configs/muad/unet/deep_ensembles.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/deep_ensembles.yaml @@ -12,10 +12,10 @@ trainer: name: deep_ensembles default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.deep_ensembles @@ -28,7 +28,6 @@ model: bilinear: true num_estimators: 4 task: segmentation - ckpt_paths: ./logs/muad/unet/deep_ensembles/version_0/checkpoints num_classes: 15 loss: class_path: torch.nn.CrossEntropyLoss @@ -37,21 +36,21 @@ model: class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 format_batch_fn: class_path: torch_uncertainty.transforms.RepeatTarget init_args: @@ -60,9 +59,10 @@ data: root: ./data batch_size: 32 crop_size: 256 + eval_ood: true eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.Adam init_args: @@ -72,8 +72,8 @@ lr_scheduler: class_path: torch.optim.lr_scheduler.MultiStepLR init_args: milestones: - - 20 - - 40 - - 60 - - 80 + - 20 + - 40 + - 60 + - 80 gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad/unet/masksemble.yaml b/experiments/segmentation/muad/configs/muad/unet/masksemble.yaml index 5260887c..bb7d41ee 100644 --- a/experiments/segmentation/muad/configs/muad/unet/masksemble.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/masksemble.yaml @@ -4,6 +4,7 @@ trainer: accelerator: gpu precision: bf16-mixed max_epochs: 100 + accumulate_grad_batches: 4 logger: class_path: lightning.pytorch.loggers.TensorBoardLogger init_args: @@ -11,10 +12,10 @@ trainer: name: masksemble default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: 
lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.masked_unet @@ -32,32 +33,33 @@ model: class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 format_batch_fn: class_path: torch_uncertainty.transforms.RepeatTarget init_args: num_repeats: 4 data: root: ./data - batch_size: 32 + batch_size: 8 crop_size: 256 + eval_ood: true eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.Adam init_args: @@ -67,8 +69,8 @@ lr_scheduler: class_path: torch.optim.lr_scheduler.MultiStepLR init_args: milestones: - - 20 - - 40 - - 60 - - 80 + - 20 + - 40 + - 60 + - 80 gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad/unet/mc_dropout.yaml b/experiments/segmentation/muad/configs/muad/unet/mc_dropout.yaml new file mode 100644 index 00000000..5d0d41df --- /dev/null +++ b/experiments/segmentation/muad/configs/muad/unet/mc_dropout.yaml @@ -0,0 +1,76 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + precision: bf16-mixed + max_epochs: 100 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/muad/unet + name: mc_dropout + default_hp_metric: false + callbacks: + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + model: + class_path: torch_uncertainty.models.mc_dropout + init_args: + model: + class_path: torch_uncertainty.models.segmentation.unet + init_args: + in_channels: 3 + num_classes: 15 + bilinear: true + dropout_rate: 0.1 + num_estimators: 10 + on_batch: false + num_classes: 15 + loss: + class_path: torch.nn.CrossEntropyLoss + init_args: + weight: + class_path: torch.Tensor + dict_kwargs: + data: + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 +data: + root: ./data + batch_size: 32 + crop_size: 256 + eval_ood: true + eval_size: + - 512 + - 1024 +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.001 + weight_decay: 1e-4 +lr_scheduler: + class_path: torch.optim.lr_scheduler.MultiStepLR + init_args: + milestones: + - 20 + - 40 + - 60 + - 80 + gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad/unet/mimo.yaml b/experiments/segmentation/muad/configs/muad/unet/mimo.yaml index 8dd4c4d4..2be464fa 100644 --- a/experiments/segmentation/muad/configs/muad/unet/mimo.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/mimo.yaml @@ -12,10 +12,10 @@ trainer: name: mimo default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + 
logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.mimo_unet @@ -32,21 +32,21 @@ model: class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 format_batch_fn: class_path: torch_uncertainty.transforms.MIMOBatchFormat init_args: @@ -57,9 +57,10 @@ data: root: ./data batch_size: 32 crop_size: 256 + eval_ood: true eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.Adam init_args: @@ -69,8 +70,8 @@ lr_scheduler: class_path: torch.optim.lr_scheduler.MultiStepLR init_args: milestones: - - 20 - - 40 - - 60 - - 80 + - 20 + - 40 + - 60 + - 80 gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad/unet/packed.yaml b/experiments/segmentation/muad/configs/muad/unet/packed.yaml index 7870e9f0..80507c49 100644 --- a/experiments/segmentation/muad/configs/muad/unet/packed.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/packed.yaml @@ -12,10 +12,10 @@ trainer: name: packed default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.packed_unet @@ -34,21 +34,21 @@ model: class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 format_batch_fn: class_path: torch_uncertainty.transforms.RepeatTarget init_args: @@ -57,9 +57,10 @@ data: root: ./data batch_size: 32 crop_size: 256 + eval_ood: true eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.Adam init_args: @@ -69,8 +70,8 @@ lr_scheduler: class_path: torch.optim.lr_scheduler.MultiStepLR init_args: milestones: - - 20 - - 40 - - 60 - - 80 + - 20 + - 40 + - 60 + - 80 gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad/unet/standard.yaml b/experiments/segmentation/muad/configs/muad/unet/standard.yaml index 83a1f8d5..c5eddaa7 100644 --- a/experiments/segmentation/muad/configs/muad/unet/standard.yaml +++ b/experiments/segmentation/muad/configs/muad/unet/standard.yaml @@ -12,10 +12,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.unet @@ -23,6 +23,7 @@ model: in_channels: 3 num_classes: 15 bilinear: true + dropout_rate: 0.5 num_classes: 15 loss: class_path: torch.nn.CrossEntropyLoss @@ -31,28 +32,29 @@ model: 
class_path: torch.Tensor dict_kwargs: data: - - 4.1712 - - 19.4603 - - 3.2345 - - 49.2588 - - 36.2490 - - 34.0272 - - 47.0651 - - 49.7145 - - 12.4178 - - 48.3962 - - 14.3876 - - 32.8862 - - 5.2729 - - 17.8703 - - 50.4984 + - 4.1712 + - 19.4603 + - 3.2345 + - 49.2588 + - 36.2490 + - 34.0272 + - 47.0651 + - 49.7145 + - 12.4178 + - 48.3962 + - 14.3876 + - 32.8862 + - 5.2729 + - 17.8703 + - 50.4984 data: root: ./data batch_size: 32 crop_size: 256 + eval_ood: true eval_size: - - 512 - - 1024 + - 512 + - 1024 optimizer: class_path: torch.optim.Adam init_args: @@ -62,8 +64,8 @@ lr_scheduler: class_path: torch.optim.lr_scheduler.MultiStepLR init_args: milestones: - - 20 - - 40 - - 60 - - 80 + - 20 + - 40 + - 60 + - 80 gamma: 0.5 diff --git a/experiments/segmentation/muad/configs/muad_small/unet/mc_dropout.yaml b/experiments/segmentation/muad/configs/muad_small/unet/mc_dropout.yaml new file mode 100644 index 00000000..d5cea608 --- /dev/null +++ b/experiments/segmentation/muad/configs/muad_small/unet/mc_dropout.yaml @@ -0,0 +1,72 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + max_epochs: 50 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/muad-small/unet + name: standard + default_hp_metric: false + callbacks: + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + model: + class_path: torch_uncertainty.models.mc_dropout + init_args: + model: + class_path: torch_uncertainty.models.segmentation.small_unet + init_args: + in_channels: 3 + num_classes: 15 + bilinear: true + dropout_rate: 0.5 + num_estimators: 4 + num_classes: 15 + loss: + class_path: torch.nn.CrossEntropyLoss + init_args: + weight: + class_path: torch.Tensor + dict_kwargs: + data: + - 4.3817 + - 19.7927 + - 3.3011 + - 48.8031 + - 36.2141 + - 33.0049 + - 47.5130 + - 48.8560 + - 12.4401 + - 48.0600 + - 14.4807 + - 30.8762 + - 4.7467 + - 19.3913 + - 50.4984 +data: + root: ./data + batch_size: 10 + version: small + eval_ood: true + eval_size: + - 256 + - 512 + num_workers: 10 +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1e-3 + weight_decay: 2e-4 +lr_scheduler: + class_path: torch.optim.lr_scheduler.StepLR + init_args: + step_size: 20 + gamma: 0.1 diff --git a/experiments/segmentation/muad/configs/muad_small/unet/standard.yaml b/experiments/segmentation/muad/configs/muad_small/unet/standard.yaml index e614f875..e0bbc0bb 100644 --- a/experiments/segmentation/muad/configs/muad_small/unet/standard.yaml +++ b/experiments/segmentation/muad/configs/muad_small/unet/standard.yaml @@ -12,10 +12,10 @@ trainer: name: standard default_hp_metric: false callbacks: - - class_path: torch_uncertainty.callbacks.TUSegCheckpoint - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: step + - class_path: torch_uncertainty.callbacks.TUSegCheckpoint + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: model: class_path: torch_uncertainty.models.segmentation.small_unet @@ -31,28 +31,29 @@ model: class_path: torch.Tensor dict_kwargs: data: - - 4.3817 - - 19.7927 - - 3.3011 - - 48.8031 - - 36.2141 - - 33.0049 - - 47.5130 - - 48.8560 - - 12.4401 - - 48.0600 - - 14.4807 - - 30.8762 - - 4.7467 - - 19.3913 - - 50.4984 + - 4.3817 + - 19.7927 + - 3.3011 + - 48.8031 + - 36.2141 + - 33.0049 + - 47.5130 + - 48.8560 + - 12.4401 + 
+ - 48.0600
+ - 14.4807
+ - 30.8762
+ - 4.7467
+ - 19.3913
+ - 50.4984
 data:
 root: ./data
 batch_size: 10
 version: small
+ eval_ood: true
 eval_size:
- - 256
- - 512
+ - 256
+ - 512
 num_workers: 10
 optimizer:
 class_path: torch.optim.Adam
diff --git a/experiments/segmentation/muad/configs/segformer.yaml b/experiments/segmentation/muad/configs/segformer.yaml
index a0c110e0..6c29986a 100644
--- a/experiments/segmentation/muad/configs/segformer.yaml
+++ b/experiments/segmentation/muad/configs/segformer.yaml
@@ -15,8 +15,8 @@ data:
 batch_size: 8
 crop_size: 1024
 eval_size:
- - 1024
- - 2048
+ - 1024
+ - 2048
 num_workers: 30
 optimizer:
 lr: 6e-5
diff --git a/torch_uncertainty/datamodules/segmentation/muad.py b/torch_uncertainty/datamodules/segmentation/muad.py
index e64ab24f..30c51e0a 100644
--- a/torch_uncertainty/datamodules/segmentation/muad.py
+++ b/torch_uncertainty/datamodules/segmentation/muad.py
@@ -27,6 +27,7 @@ def __init__(
 batch_size: int,
 version: Literal["full", "small"] = "full",
 eval_batch_size: int | None = None,
+ eval_ood: bool = False,
 crop_size: _size_2_t = 1024,
 eval_size: _size_2_t = (1024, 2048),
 train_transform: nn.Module | None = None,
@@ -45,6 +46,9 @@ def __init__(
 ``full`` or ``small``. Defaults to ``full``.
 eval_batch_size (int | None) : Number of samples per batch during evaluation (val and test). Set to batch_size if None. Defaults to None.
+ eval_ood (bool): Whether to evaluate on the OOD dataset. Defaults to
+ ``False``. If set to ``True``, the OOD dataset will be used for
+ evaluation in addition to the test dataset.
 crop_size (sequence or int, optional): Desired input image and segmentation mask sizes during training. If :attr:`crop_size` is an int instead of sequence like :math:`(H, W)`, a square crop
@@ -137,9 +141,14 @@ def __init__(
 self.dataset = MUAD
 self.version = version
+ self.eval_ood = eval_ood
 self.crop_size = _pair(crop_size)
 self.eval_size = _pair(eval_size)

+ # FIXME: should be the same split names (update huggingface dataset)
+ self.test_split = "test" if version == "small" else "test_id"
+ self.ood_split = "ood" if version == "small" else "test_ood"
+
 if train_transform is not None:
 self.train_transform = train_transform
 else:
@@ -212,6 +221,22 @@ def prepare_data(self) -> None: # coverage: ignore
 self.dataset(
 root=self.root, split="val", version=self.version, target_type="semantic", download=True
 )
+ self.dataset(
+ root=self.root,
+ split=self.test_split,
+ version=self.version,
+ target_type="semantic",
+ download=True,
+ )
+
+ if self.eval_ood:
+ self.dataset(
+ root=self.root,
+ split=self.ood_split,
+ version=self.version,
+ target_type="semantic",
+ download=True,
+ )

 def setup(self, stage: str | None = None) -> None:
 if stage == "fit" or stage is None:
@@ -242,11 +267,26 @@ def setup(self, stage: str | None = None) -> None:
 if stage == "test" or stage is None:
 self.test = self.dataset(
 root=self.root,
- split="val",
+ split=self.test_split,
 version=self.version,
 target_type="semantic",
 transforms=self.test_transform,
 )
+ if self.eval_ood:
+ self.ood = self.dataset(
+ root=self.root,
+ split=self.ood_split,
+ version=self.version,
+ target_type="semantic",
+ transforms=self.test_transform,
+ )

 if stage not in ["fit", "test", None]:
 raise ValueError(f"Stage {stage} is not supported.")
+
+ def test_dataloader(self) -> list[torch.utils.data.DataLoader]:
+ """Returns the list of test dataloaders (the ID test set, plus the OOD set when ``eval_ood`` is set)."""
+ dataloader = [self._data_loader(self.get_test_set(), training=False, shuffle=False)]
+ if self.eval_ood:
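+ # When eval_ood is set, Lightning receives two test dataloaders:
+ # dataloader_idx 0 -> the ID test set, dataloader_idx 1 -> the OOD set.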
+ dataloader.append(self._data_loader(self.get_ood_set(), training=False, shuffle=False))
+ return dataloader
diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py
index 735c1983..35942fb2 100644
--- a/torch_uncertainty/datasets/muad.py
+++ b/torch_uncertainty/datasets/muad.py
@@ -61,7 +61,6 @@ class MUAD(VisionDataset):
 "val": 492,
 "test_id": 551,
 "test_ood": 1668,
- "test_id_no_shadow": 102,
 "test_id_low_adv": 605,
 "test_id_high_adv": 602,
 "test_ood_low_adv": 1552,
@@ -168,7 +167,7 @@ def __init__(
 if split not in self.huggingface_splits[version]:
 raise ValueError(
- f"split must be one of {self.huggingface_splits[version].keys()}. Got {split}."
+ f"split must be one of {self.huggingface_splits[version]}. Got {split}."
 )
 self.split = split
 self.version = version
diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py
index 840b32c6..d201fa68 100644
--- a/torch_uncertainty/metrics/__init__.py
+++ b/torch_uncertainty/metrics/__init__.py
@@ -29,4 +29,9 @@
 SILog,
 ThresholdAccuracy,
)
+from .segmentation import (
+ SegmentationBinaryAUROC,
+ SegmentationBinaryAveragePrecision,
+ SegmentationFPR95,
+)
from .sparsification import AUSE
diff --git a/torch_uncertainty/metrics/segmentation/__init__.py b/torch_uncertainty/metrics/segmentation/__init__.py
new file mode 100644
index 00000000..793c86a0
--- /dev/null
+++ b/torch_uncertainty/metrics/segmentation/__init__.py
@@ -0,0 +1,4 @@
+# ruff: noqa: F401
+from .seg_binary_auroc import SegmentationBinaryAUROC
+from .seg_binary_average_precision import SegmentationBinaryAveragePrecision
+from .seg_fpr95 import SegmentationFPR95
diff --git a/torch_uncertainty/metrics/segmentation/seg_binary_auroc.py b/torch_uncertainty/metrics/segmentation/seg_binary_auroc.py
new file mode 100644
index 00000000..a2a2a89b
--- /dev/null
+++ b/torch_uncertainty/metrics/segmentation/seg_binary_auroc.py
@@ -0,0 +1,46 @@
+from typing import Any
+
+import torch
+from torch import Tensor
+from torchmetrics import Metric
+from torchmetrics.classification import BinaryAUROC
+
+
+class SegmentationBinaryAUROC(Metric):
+ is_differentiable = False
+ higher_is_better = True
+ full_state_update = False
+
+ def __init__(
+ self,
+ max_fpr: float | None = None,
+ thresholds: int | list[float] | Tensor | None = None,
+ ignore_index: int | None = None,
+ validate_args: bool = True,
+ **kwargs: Any,
+ ):
+ """SegmentationBinaryAUROC computes the Area Under the Receiver Operating Characteristic Curve (AUROC)
+ for binary segmentation tasks. It aggregates the AUROC across batches and computes the average AUROC
+ over all batches processed.
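+
+ Each call to ``update`` weights the batch AUROC by ``preds.size(0)``;
+ ``compute`` then returns the size-weighted mean over all updates.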
+ """ + super().__init__(**kwargs) + self.auroc_metric = BinaryAUROC( + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + self.add_state("binary_auroc", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + batch_size = preds.size(0) + auroc = self.auroc_metric(preds, target) + self.binary_auroc += auroc * batch_size + self.total += batch_size + + def compute(self) -> Tensor: + if self.total == 0: + return torch.tensor(0.0, device=self.binary_auroc.device) + return self.binary_auroc / self.total diff --git a/torch_uncertainty/metrics/segmentation/seg_binary_average_precision.py b/torch_uncertainty/metrics/segmentation/seg_binary_average_precision.py new file mode 100644 index 00000000..c68dd325 --- /dev/null +++ b/torch_uncertainty/metrics/segmentation/seg_binary_average_precision.py @@ -0,0 +1,40 @@ +from typing import Any + +import torch +from torch import Tensor +from torchmetrics import Metric +from torchmetrics.classification import BinaryAveragePrecision + + +class SegmentationBinaryAveragePrecision(Metric): + is_differentiable = False + higher_is_better = True + full_state_update = False + + def __init__( + self, + thresholds: int | list[float] | Tensor | None = None, + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, + ): + """SegmentationBinaryAveragePrecision computes the Average Precision (AP) for binary segmentation tasks. + It aggregates the mean AP across batches and computes the average AP over all batches processed. + """ + super().__init__(**kwargs) + self.aupr_metric = BinaryAveragePrecision( + thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args, **kwargs + ) + self.add_state("binary_aupr", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + batch_size = preds.size(0) + aupr = self.aupr_metric(preds, target) + self.binary_aupr += aupr * batch_size + self.total += batch_size + + def compute(self) -> Tensor: + if self.total == 0: + return torch.tensor(0.0, device=self.binary_aupr.device) + return self.binary_aupr / self.total diff --git a/torch_uncertainty/metrics/segmentation/seg_fpr95.py b/torch_uncertainty/metrics/segmentation/seg_fpr95.py new file mode 100644 index 00000000..56ec58fb --- /dev/null +++ b/torch_uncertainty/metrics/segmentation/seg_fpr95.py @@ -0,0 +1,35 @@ +import torch +from torch import Tensor +from torchmetrics import Metric + +from torch_uncertainty.metrics import FPR95 + + +class SegmentationFPR95(Metric): + is_differentiable = False + higher_is_better = False + full_state_update = False + + def __init__(self, pos_label: int, **kwargs): + """FPR95 metric for segmentation tasks. + Compute the mean FPR95 per batch across all batches. + + Args: + pos_label (int): The positive label in the segmentation OOD detection task. + **kwargs: Additional keyword arguments for the FPR95 metric. 
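+
+ Note:
+ Each call to ``update`` weights the batch FPR95 by ``preds.size(0)``; for two
+ updates of sizes 4 and 2, ``compute`` returns ``(fpr95_1 * 4 + fpr95_2 * 2) / 6``.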
+ """
+ super().__init__(**kwargs)
+ self.fpr95_metric = FPR95(pos_label, **kwargs)
+ self.add_state("fpr95", default=torch.tensor(0.0), dist_reduce_fx="sum")
+ self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
+
+ def update(self, preds: Tensor, target: Tensor) -> None:
+ batch_size = preds.size(0)
+ fpr95 = self.fpr95_metric(preds, target)
+ self.fpr95 += fpr95 * batch_size
+ self.total += batch_size
+
+ def compute(self) -> Tensor:
+ if self.total == 0:
+ return torch.tensor(0.0, device=self.fpr95.device)
+ return self.fpr95 / self.total
diff --git a/torch_uncertainty/models/segmentation/unet/batched.py b/torch_uncertainty/models/segmentation/unet/batched.py
index 9267b7a6..54947927 100644
--- a/torch_uncertainty/models/segmentation/unet/batched.py
+++ b/torch_uncertainty/models/segmentation/unet/batched.py
@@ -170,17 +170,23 @@ def __init__(
 def forward(self, x: Tensor) -> Tensor:
 x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)
+ # Downsampling
 x1 = self.inc(x)
 x2 = self.down1(x1)
+ x2 = self.dropout(x2)
 x3 = self.down2(x2)
+ x3 = self.dropout(x3)
 x4 = self.down3(x3)
+ x4 = self.dropout(x4)
 x5 = self.down4(x4)
+ # Upsampling
+ x5 = self.dropout(x5)
 x = self.up1(x5, x4)
+ x = self.dropout(x)
 x = self.up2(x, x3)
 x = self.up3(x, x2)
- x = self.dropout(x)
 x = self.up4(x, x1)
- x = self.dropout(x)
+ # Final output
 return self.outc(x)
diff --git a/torch_uncertainty/models/segmentation/unet/masked.py b/torch_uncertainty/models/segmentation/unet/masked.py
index c206a6e4..649055c7 100644
--- a/torch_uncertainty/models/segmentation/unet/masked.py
+++ b/torch_uncertainty/models/segmentation/unet/masked.py
@@ -213,17 +213,23 @@ def __init__(
 def forward(self, x: Tensor) -> Tensor:
 x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)
+ # Downsampling
 x1 = self.inc(x)
 x2 = self.down1(x1)
+ x2 = self.dropout(x2)
 x3 = self.down2(x2)
+ x3 = self.dropout(x3)
 x4 = self.down3(x3)
+ x4 = self.dropout(x4)
 x5 = self.down4(x4)
+ # Upsampling
+ x5 = self.dropout(x5)
 x = self.up1(x5, x4)
+ x = self.dropout(x)
 x = self.up2(x, x3)
 x = self.up3(x, x2)
- x = self.dropout(x)
 x = self.up4(x, x1)
- x = self.dropout(x)
+ # Final output
 return self.outc(x)
diff --git a/torch_uncertainty/models/segmentation/unet/packed.py b/torch_uncertainty/models/segmentation/unet/packed.py
index f4c113bc..4875c612 100644
--- a/torch_uncertainty/models/segmentation/unet/packed.py
+++ b/torch_uncertainty/models/segmentation/unet/packed.py
@@ -344,19 +344,19 @@ def forward(self, x):
 # Downsampling
 x1 = self.inc(x)
 x2 = self.down1(x1)
+ x2 = self.dropout(x2)
 x3 = self.down2(x2)
+ x3 = self.dropout(x3)
 x4 = self.down3(x3)
+ x4 = self.dropout(x4)
 x5 = self.down4(x4)
 # Upsampling
+ x5 = self.dropout(x5)
 x = self.up1(x5, x4)
 x = self.dropout(x)
 x = self.up2(x, x3)
- x = self.dropout(x)
 x = self.up3(x, x2)
- x = self.dropout(x)
 x = self.up4(x, x1)
- x = self.dropout(x)
-
 # Final output
 return self.outc(x)
diff --git a/torch_uncertainty/models/segmentation/unet/standard.py b/torch_uncertainty/models/segmentation/unet/standard.py
index 0bbd68a9..877aecb6 100644
--- a/torch_uncertainty/models/segmentation/unet/standard.py
+++ b/torch_uncertainty/models/segmentation/unet/standard.py
@@ -157,17 +157,23 @@ def __init__(
 self.dropout = nn.Dropout2d(dropout_rate)

 def forward(self, x: Tensor) -> Tensor:
+ # Downsampling
 x1 = self.inc(x)
 x2 = self.down1(x1)
+ x2 = self.dropout(x2)
 x3 = self.down2(x2)
+ x3 = self.dropout(x3)
 x4 = self.down3(x3)
+ x4 = self.dropout(x4)
 x5 = self.down4(x4)
+ # Upsampling
+ x5 = self.dropout(x5)
 x = self.up1(x5, x4)
+ x = self.dropout(x)
 x = self.up2(x, x3)
 x = self.up3(x, x2)
- x = self.dropout(x)
 x = self.up4(x, x1)
- x = self.dropout(x)
+ # Final output
 return self.outc(x)
diff --git a/torch_uncertainty/models/wrappers/mc_dropout.py b/torch_uncertainty/models/wrappers/mc_dropout.py
index 7e040813..ff430504 100644
--- a/torch_uncertainty/models/wrappers/mc_dropout.py
+++ b/torch_uncertainty/models/wrappers/mc_dropout.py
@@ -1,4 +1,5 @@
 import torch
+from einops import repeat
 from torch import Tensor, nn
 from torch.nn.modules.dropout import _DropoutNd
@@ -83,7 +84,7 @@ def forward(
 if self.training:
 return self.core_model(x)
 if self.on_batch:
- x = x.repeat(self.num_estimators, 1, 1, 1)
+ x = repeat(x, "b ... -> (m b) ...", m=self.num_estimators)
 return self.core_model(x)
 # Else, for loop
 return torch.cat([self.core_model(x) for _ in range(self.num_estimators)], dim=0)
@@ -100,8 +101,11 @@ def mc_dropout(
 Args:
 model (nn.Module): model to wrap
 num_estimators (int): number of estimators to use
 last_layer (bool, optional): whether to apply dropout to the last layer only. Defaults to ``False``.
- on_batch (bool): Increase the batch_size to perform MC-Dropout. Otherwise in a for loop to reduce memory footprint. Defaults to ``true``.
+ on_batch (bool): Increase the batch size to perform MC-Dropout. Otherwise, run a for loop to reduce the memory footprint. Defaults to ``True``.
 last_layer (bool, optional): whether to apply dropout to the last layer only. Defaults to ``False``.
+
+ Warning:
+ Beware that setting :attr:`on_batch` to ``True`` can raise out-of-memory errors, since the input batch is replicated :attr:`num_estimators` times.
 """
 return MCDropout(
 model=model,
diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py
index c0b167b8..f47aec5a 100644
--- a/torch_uncertainty/routines/segmentation.py
+++ b/torch_uncertainty/routines/segmentation.py
@@ -19,12 +19,23 @@
 BrierScore,
 CalibrationError,
 CategoricalNLL,
+ CovAt5Risk,
 MeanIntersectionOverUnion,
+ RiskAt80Cov,
+ SegmentationBinaryAUROC,
+ SegmentationBinaryAveragePrecision,
+ SegmentationFPR95,
)
from torch_uncertainty.models import (
 EPOCH_UPDATE_MODEL,
 STEP_UPDATE_MODEL,
)
+from torch_uncertainty.ood_criteria import (
+ OODCriterionInputType,
+ TUOODCriterion,
+ get_ood_criterion,
+)
+from torch_uncertainty.post_processing import PostProcessing
from torch_uncertainty.utils import csv_writer
from torch_uncertainty.utils.plotting import show
@@ -39,6 +50,9 @@ def __init__(
 eval_shift: bool = False,
 format_batch_fn: nn.Module | None = None,
 metric_subsampling_rate: float = 1e-2,
+ eval_ood: bool = False,
+ ood_criterion: type[TUOODCriterion] | str = "msp",
+ post_processing: PostProcessing | None = None,
 log_plots: bool = False,
 num_samples_to_plot: int = 3,
 num_bins_cal_err: int = 15,
@@ -58,6 +72,13 @@ def __init__(
 batch. Defaults to ``None``.
 metric_subsampling_rate (float, optional): The rate of subsampling for the memory consuming metrics. Defaults to ``1e-2``.
+ eval_ood (bool, optional): Indicates whether to evaluate the OOD
+ performance. Defaults to ``False``.
+ ood_criterion (TUOODCriterion, optional): Criterion for the binary OOD detection task.
+ Defaults to ``"msp"``, which uses the maximum softmax probability (MSP) score.
+ post_processing (PostProcessing, optional): The post-processing
+ technique to use. Defaults to ``None``. Warning: There is no
+ post-processing technique implemented yet for segmentation tasks.
 log_plots (bool, optional): Indicates whether to log figures in the logger.
Defaults to ``False``. num_samples_to_plot (int, optional): Number of segmentation prediction and @@ -102,6 +123,13 @@ def __init__( self.metric_subsampling_rate = metric_subsampling_rate self.log_plots = log_plots self.save_in_csv = save_in_csv + self.ood_criterion = get_ood_criterion(ood_criterion) + self.eval_ood = eval_ood + + self.post_processing = post_processing + if self.post_processing is not None: + self.post_processing.set_model(self.model) + self._init_metrics() if log_plots: @@ -113,17 +141,17 @@ def _init_metrics(self) -> None: seg_metrics = MetricCollection( { "seg/mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), + "seg/mAcc": Accuracy( + task="multiclass", average="macro", num_classes=self.num_classes + ), + "seg/pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), }, - compute_groups=False, + compute_groups=[["seg/mIoU", "seg/mAcc", "seg/pixAcc"]], ) sbsmpl_seg_metrics = MetricCollection( { - "seg/mAcc": Accuracy( - task="multiclass", average="macro", num_classes=self.num_classes - ), "seg/Brier": BrierScore(num_classes=self.num_classes), "seg/NLL": CategoricalNLL(), - "seg/pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), "cal/ECE": CalibrationError( task="multiclass", num_classes=self.num_classes, @@ -137,14 +165,14 @@ def _init_metrics(self) -> None: ), "sc/AURC": AURC(), "sc/AUGRC": AUGRC(), + "sc/Cov@5Risk": CovAt5Risk(), + "sc/Risk@80Cov": RiskAt80Cov(), }, compute_groups=[ - ["seg/mAcc"], ["seg/Brier"], ["seg/NLL"], - ["seg/pixAcc"], ["cal/ECE", "cal/aECE"], - ["sc/AURC", "sc/AUGRC"], + ["sc/AURC", "sc/AUGRC", "sc/Cov@5Risk", "sc/Risk@80Cov"], ], ) @@ -153,6 +181,16 @@ def _init_metrics(self) -> None: self.test_seg_metrics = seg_metrics.clone(prefix="test/") self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") + if self.eval_ood: + ood_metrics = MetricCollection( + { + "AUROC": SegmentationBinaryAUROC(), + "AUPR": SegmentationBinaryAveragePrecision(), + "FPR95": SegmentationFPR95(pos_label=1), + } + ) + self.test_ood_metrics = ood_metrics.clone(prefix="ood/") + def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -178,6 +216,14 @@ def on_validation_start(self) -> None: self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: + if self.post_processing is not None: + with torch.inference_mode(False): + self.post_processing.fit(self.trainer.datamodule.postprocess_dataloader()) + + if self.eval_ood and self.log_plots and isinstance(self.logger, Logger): + self.id_logit_storage = [] + self.ood_logit_storage = [] + if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) @@ -226,13 +272,20 @@ def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: self.val_seg_metrics.update(probs, targets) self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) - def test_step(self, batch: tuple[Tensor, Tensor]) -> None: + def test_step( + self, + batch: tuple[Tensor, Tensor], + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: """Perform a single test step based on the input tensors. Compute the prediction of the model and the value of the metrics on the test batch. Args: batch (tuple[Tensor, Tensor]): the test images and their corresponding targets + batch_idx (int): the index of the batch in the test dataloader. + dataloader_idx (int, optional): the index of the dataloader. Defaults to ``0``. 
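+
+ Note:
+ When ``eval_ood`` is ``True``, dataloader ``0`` carries the in-distribution
+ test set and dataloader ``1`` the OOD set used for the ``ood/`` metrics.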
""" img, targets = batch logits = self.forward(img) @@ -254,12 +307,38 @@ def test_step(self, batch: tuple[Tensor, Tensor]) -> None: _pred = _prb.argmax(dim=0, keepdim=True) self.sample_buffer.append((_img, _pred, _tgt)) + probs_per_est = rearrange(probs_per_est, "b m c h w -> (b h w) m c") probs = rearrange(probs, "b c h w -> (b h w) c") targets = targets.flatten() - valid_mask = (targets != 255) * (targets < self.num_classes) - probs, targets = probs[valid_mask], targets[valid_mask] - self.test_seg_metrics.update(probs, targets) - self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) + valid_mask = targets != 255 + probs, probs_per_est, targets = ( + probs[valid_mask], + probs_per_est[valid_mask], + targets[valid_mask], + ) + id_mask = targets < self.num_classes + ood_mask = targets >= self.num_classes + + if dataloader_idx == 0: + id_probs, _, id_targets = probs[id_mask], probs_per_est[id_mask], targets[id_mask] + self.test_seg_metrics.update(id_probs, id_targets) + self.test_sbsmpl_seg_metrics.update(*self.subsample(id_probs, id_targets)) + + if self.eval_ood and dataloader_idx == 1: + if self.ood_criterion.input_type == OODCriterionInputType.PROB: + ood_scores = self.ood_criterion(probs) + elif self.ood_criterion.input_type == OODCriterionInputType.ESTIMATOR_PROB: + ood_scores = self.ood_criterion(probs_per_est) + else: + raise ValueError( + f"Unsupported input type for OOD criterion: {self.ood_criterion.input_type}" + ) + + labels = torch.zeros_like(targets) + labels[id_mask] = 0 # ID examples + labels[ood_mask] = 1 # OOD examples + + self.test_ood_metrics.update(ood_scores, labels) def on_validation_epoch_end(self) -> None: """Compute and log the values of the collected metrics in `validation_step`.""" @@ -279,6 +358,9 @@ def on_test_epoch_end(self) -> None: """Compute, log, and plot the values of the collected metrics in `test_step`.""" result_dict = self.test_seg_metrics.compute() result_dict |= self.test_sbsmpl_seg_metrics.compute() + if self.eval_ood: + result_dict |= self.test_ood_metrics.compute() + self.log_dict(result_dict, logger=True, sync_dist=True) if isinstance(self.logger, Logger) and self.log_plots: @@ -304,6 +386,8 @@ def on_test_epoch_end(self) -> None: self.test_seg_metrics.reset() self.test_sbsmpl_seg_metrics.reset() + if self.eval_ood: + self.test_ood_metrics.reset() def log_segmentation_plots(self) -> None: """Build and log examples of segmentation plots from the test set."""