diff --git a/README.md b/README.md
index 00d0bc2a..ebfe0957 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ To date, the following post-processing methods have been implemented:
 
 - Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
 - Monte Carlo Batch Normalization - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_batch_norm.html)
+- A wrapper for Laplace approximation using the [Laplace library](https://github.com/aleximmer/Laplace)
 
 ## Tutorials
 
diff --git a/docs/source/api.rst b/docs/source/api.rst
index d4f99acf..a3762b92 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -242,6 +242,17 @@ Post-Processing Methods
 
 .. currentmodule:: torch_uncertainty.post_processing
 
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+    :template: class_inherited.rst
+
+    MCBatchNorm
+    Laplace
+
+Scaling Methods
+^^^^^^^^^^^^^^^
+
 .. autosummary::
     :toctree: generated/
     :nosignatures:
@@ -250,7 +261,6 @@ Post-Processing Methods
     TemperatureScaler
     VectorScaler
     MatrixScaler
-    MCBatchNorm
 
 Datamodules
 -----------
diff --git a/docs/source/references.rst b/docs/source/references.rst
index bd4467c9..b219eb3f 100644
--- a/docs/source/references.rst
+++ b/docs/source/references.rst
@@ -193,6 +193,16 @@ For Monte-Carlo Batch Normalization, consider citing:
 * Authors: *Mathias Teye, Hossein Azizpour, and Kevin Smith*
 * Paper: `ICML 2018 <https://arxiv.org/abs/1802.06455>`__.
 
+Laplace Approximation
+^^^^^^^^^^^^^^^^^^^^^
+
+For Laplace Approximation, consider citing:
+
+**Laplace Redux - Effortless Bayesian Deep Learning**
+
+* Authors: *Erik Daxberger, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, Philipp Hennig*
+* Paper: `NeurIPS 2021 <https://arxiv.org/abs/2106.14806>`__.
+
 Metrics
 -------
 
diff --git a/pyproject.toml b/pyproject.toml
index 0b11a230..822924d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-image = ["scikit-image", "h5py",]
+image = ["scikit-image", "h5py", "webdataset"]
 tabular = ["pandas"]
 dev = [
     "torch_uncertainty[image]",
@@ -63,7 +63,10 @@ docs = [
     "sphinx-design",
     "sphinx-codeautolink",
 ]
-all = ["torch_uncertainty[dev,docs,image,tabular]"]
+all = [
+    "torch_uncertainty[dev,docs,image,tabular]",
+    "laplace-torch",
+]
 
 [project.urls]
 homepage = "https://torch-uncertainty.github.io/"
diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py
new file mode 100644
index 00000000..e3f19d13
--- /dev/null
+++ b/torch_uncertainty/post_processing/laplace.py
@@ -0,0 +1,82 @@
+from importlib import util
+from typing import Literal
+
+from torch import Tensor, nn
+from torch.utils.data import DataLoader, Dataset
+
+# NOTE: the library class is imported under an alias so that the wrapper
+# defined below can itself be called `Laplace` without shadowing it. With a
+# plain `from laplace import Laplace`, the constructor call in `__init__`
+# would resolve to this wrapper class and recurse forever.
+if util.find_spec("laplace"):
+    from laplace import Laplace as LaplaceLib
+
+    laplace_installed = True
+else:
+    laplace_installed = False
+
+
+class Laplace(nn.Module):
+    def __init__(
+        self,
+        model: nn.Module,
+        task: Literal["classification", "regression"],
+        subset_of_weights="last_layer",
+        hessian_structure="kron",
+        pred_type: Literal["glm", "nn"] = "glm",
+        link_approx: Literal[
+            "mc", "probit", "bridge", "bridge_norm"
+        ] = "probit",
+        batch_size: int = 256,
+    ) -> None:
+        """Laplace approximation for uncertainty estimation.
+
+        This class is a wrapper of Laplace classes from the laplace-torch library.
+
+        Args:
+            model (nn.Module): model to be converted.
+            task (Literal["classification", "regression"]): task type.
+            subset_of_weights (str): subset of weights to be considered. Defaults to
+                "last_layer".
+            hessian_structure (str): structure of the Hessian matrix. Defaults to
+                "kron".
+            pred_type (Literal["glm", "nn"], optional): type of posterior predictive,
+                See the Laplace library for more details. Defaults to "glm".
+            link_approx (Literal["mc", "probit", "bridge", "bridge_norm"], optional):
+                how to approximate the classification link function for the `'glm'`.
+                See the Laplace library for more details. Defaults to "probit".
+            batch_size (int, optional): batch size used to iterate over the
+                dataset when fitting the approximation. Defaults to 256.
+
+        Reference:
+            Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021.
+        """
+        super().__init__()
+        if not laplace_installed:
+            raise ImportError(
+                "The laplace-torch library is not installed. Please install it via `pip install laplace-torch`."
+            )
+        self.la = LaplaceLib(
+            model=model,
+            task=task,
+            subset_of_weights=subset_of_weights,
+            hessian_structure=hessian_structure,
+        )
+        self.pred_type = pred_type
+        self.link_approx = link_approx
+        self.batch_size = batch_size
+
+    def fit(self, dataset: Dataset) -> None:
+        """Fit the Laplace approximation on ``dataset``."""
+        # laplace-torch's fit expects an iterable of batches (a DataLoader),
+        # not a raw Dataset.
+        self.la.fit(train_loader=DataLoader(dataset, batch_size=self.batch_size))
+
+    def forward(
+        self,
+        x: Tensor,
+    ) -> Tensor:
+        """Return the posterior predictive for ``x``."""
+        return self.la(
+            x, pred_type=self.pred_type, link_approx=self.link_approx
+        )
diff --git a/torch_uncertainty/post_processing/mc_batch_norm.py b/torch_uncertainty/post_processing/mc_batch_norm.py
index d99fdd7c..66d21889 100644
--- a/torch_uncertainty/post_processing/mc_batch_norm.py
+++ b/torch_uncertainty/post_processing/mc_batch_norm.py
@@ -99,7 +99,7 @@ def _est_forward(self, x: Tensor) -> Tensor:
     def forward(
         self,
         x: Tensor,
-    ) -> tuple[Tensor, Tensor]:
+    ) -> Tensor:
         if self.training:
             return self.model(x)
         if not self.trained: