Skip to content

🐛 Support binary scaling #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions torch_uncertainty/post_processing/calibration/matrix_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
init_b: float = 0,
lr: float = 0.1,
max_iter: int = 200,
eps: float = 1e-8,
device: Literal["cpu", "cuda"] | device | None = None,
) -> None:
"""Matrix scaling post-processing for calibrated probabilities.
Expand All @@ -29,6 +30,7 @@ def __init__(
lr (float, optional): Learning rate for the optimizer. Defaults to 0.1.
max_iter (int, optional): Maximum number of iterations for the
optimizer. Defaults to 200.
eps (float): Small value for stability. Defaults to ``1e-8``.
device (Optional[Literal["cpu", "cuda"]], optional): Device to use
for optimization. Defaults to None.

Expand All @@ -37,7 +39,7 @@ def __init__(
of modern neural networks. In ICML 2017.

"""
super().__init__(model=model, lr=lr, max_iter=max_iter, device=device)
super().__init__(model=model, lr=lr, max_iter=max_iter, eps=eps, device=device)

if not isinstance(num_classes, int):
raise TypeError(f"num_classes must be an integer. Got {num_classes}.")
Expand Down Expand Up @@ -66,14 +68,6 @@ def set_temperature(self, val_w: float, val_b: float) -> None:
)

def _scale(self, logits: Tensor) -> Tensor:
    """Apply the learned affine transformation to the logits.

    Args:
        logits (Tensor): logits to be scaled.

    Returns:
        Tensor: Scaled logits, ``temp_w @ logits + temp_b``.
    """
    # Matrix scaling: a full weight matrix followed by a bias shift.
    transformed = self.temp_w.matmul(logits)
    return transformed + self.temp_b

@property
Expand Down
46 changes: 31 additions & 15 deletions torch_uncertainty/post_processing/calibration/scaler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from abc import abstractmethod
from typing import Literal

import torch
from torch import Tensor, nn, optim
from torch import Tensor, nn
from torch.optim import LBFGS
from torch.utils.data import DataLoader
from tqdm import tqdm

Expand All @@ -18,17 +20,19 @@ def __init__(
model: nn.Module | None = None,
lr: float = 0.1,
max_iter: int = 100,
eps: float = 1e-8,
device: Literal["cpu", "cuda"] | torch.device | None = None,
) -> None:
"""Virtual class for scaling post-processing for calibrated probabilities.

Args:
model (nn.Module): Model to calibrate.
lr (float, optional): Learning rate for the optimizer. Defaults to 0.1.
model (nn.Module | None): Model to calibrate. Defaults to ``None``.
lr (float, optional): Learning rate for the optimizer. Defaults to ``0.1``.
max_iter (int, optional): Maximum number of iterations for the
optimizer. Defaults to 100.
optimizer. Defaults to ``100``.
eps (float): Small value for stability. Defaults to ``1e-8``.
device (Optional[Literal["cpu", "cuda"]], optional): Device to use
for optimization. Defaults to None.
for optimization. Defaults to ``None``.

Reference:
Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. On calibration
Expand All @@ -38,13 +42,17 @@ def __init__(
self.device = device

if lr <= 0:
raise ValueError("Learning rate must be positive.")
raise ValueError(f"Learning rate must be strictly positive. Got {lr}.")
self.lr = lr

if max_iter <= 0:
raise ValueError("Max iterations must be positive.")
raise ValueError(f"Max iterations must be strictly positive. Got {max_iter}.")
self.max_iter = int(max_iter)

if eps <= 0:
raise ValueError(f"Eps must be strictly positive. Got {eps}.")
self.eps = eps

def fit(
self,
dataloader: DataLoader,
Expand All @@ -54,11 +62,12 @@ def fit(
"""Fit the temperature parameters to the calibration data.

Args:
dataloader (DataLoader): Dataloader with the calibration data.
dataloader (DataLoader): Dataloader with the calibration data. If there is no model,
the dataloader should include the confidence score directly and not the logits.
save_logits (bool, optional): Whether to save the logits and
labels. Defaults to False.
labels in memory. Defaults to ``False``.
progress (bool, optional): Whether to show a progress bar.
Defaults to True.
Defaults to ``True``.
"""
if self.model is None or isinstance(self.model, nn.Identity):
logging.warning(
Expand All @@ -75,9 +84,15 @@ def fit(
all_logits.append(logits)
all_labels.append(labels)
all_logits = torch.cat(all_logits).to(self.device)
all_labels = torch.cat(all_labels).to(self.device)
all_labels = torch.cat(all_labels).to(dtype=torch.long).to(self.device)

all_logits = all_logits.clamp(self.eps, 1 - self.eps)
if all_logits.dim() == 2 and all_logits.shape[1] == 1:
all_logits = all_logits.squeeze(1)
if all_logits.dim() == 1:
all_logits = torch.stack([torch.log(1 - all_logits), torch.log(all_logits)], dim=1)

optimizer = optim.LBFGS(self.temperature, lr=self.lr, max_iter=self.max_iter)
optimizer = LBFGS(self.temperature, lr=self.lr, max_iter=self.max_iter)

def calib_eval() -> float:
optimizer.zero_grad()
Expand All @@ -99,6 +114,7 @@ def forward(self, inputs: Tensor) -> Tensor:
)
return self._scale(self.model(inputs))

@abstractmethod
def _scale(self, logits: Tensor) -> Tensor:
"""Scale the logits with the optimal temperature.

Expand All @@ -108,7 +124,7 @@ def _scale(self, logits: Tensor) -> Tensor:
Returns:
Tensor: Scaled logits.
"""
raise NotImplementedError
...

def fit_predict(
self,
Expand All @@ -119,5 +135,5 @@ def fit_predict(
return self(self.logits)

@property
def temperature(self) -> list:
raise NotImplementedError
@abstractmethod
def temperature(self) -> list: ...
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
init_val: float = 1,
lr: float = 0.1,
max_iter: int = 100,
eps: float = 1e-8,
device: Literal["cpu", "cuda"] | torch.device | None = None,
) -> None:
"""Temperature scaling post-processing for calibrated probabilities.
Expand All @@ -24,14 +25,15 @@ def __init__(
lr (float, optional): Learning rate for the optimizer. Defaults to 0.1.
max_iter (int, optional): Maximum number of iterations for the
optimizer. Defaults to 100.
eps (float): Small value for stability. Defaults to ``1e-8``.
device (Optional[Literal["cpu", "cuda"]], optional): Device to use
for optimization. Defaults to None.

Reference:
Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. On calibration
of modern neural networks. In ICML 2017.
"""
super().__init__(model=model, lr=lr, max_iter=max_iter, device=device)
super().__init__(model=model, lr=lr, max_iter=max_iter, eps=eps, device=device)

if init_val <= 0:
raise ValueError(f"Initial temperature value must be positive. Got {init_val}")
Expand All @@ -50,14 +52,6 @@ def set_temperature(self, val: float) -> None:
self.temp = nn.Parameter(torch.ones(1, device=self.device) * val, requires_grad=True)

def _scale(self, logits: Tensor) -> Tensor:
    """Divide the logits by the optimized temperature.

    Args:
        logits (Tensor): logits to be scaled.

    Returns:
        Tensor: Scaled logits.
    """
    # Temperature scaling: a single scalar softens/sharpens all logits.
    temperature = self.temperature[0]
    return logits / temperature

@property
Expand Down
12 changes: 3 additions & 9 deletions torch_uncertainty/post_processing/calibration/vector_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
init_b: float = 0,
lr: float = 0.1,
max_iter: int = 200,
eps: float = 1e-8,
device: Literal["cpu", "cuda"] | torch.device | None = None,
) -> None:
"""Vector scaling post-processing for calibrated probabilities.
Expand All @@ -29,6 +30,7 @@ def __init__(
lr (float, optional): Learning rate for the optimizer. Defaults to 0.1.
max_iter (int, optional): Maximum number of iterations for the
optimizer. Defaults to 200.
eps (float): Small value for stability. Defaults to ``1e-8``.
device (Optional[Literal["cpu", "cuda"]], optional): Device to use
for optimization. Defaults to None.

Expand All @@ -37,7 +39,7 @@ def __init__(
of modern neural networks. In ICML 2017.

"""
super().__init__(model=model, lr=lr, max_iter=max_iter, device=device)
super().__init__(model=model, lr=lr, max_iter=max_iter, eps=eps, device=device)

if not isinstance(num_classes, int):
raise TypeError(f"num_classes must be an integer. Got {num_classes}.")
Expand All @@ -64,14 +66,6 @@ def set_temperature(self, val_w: float, val_b: float) -> None:
)

def _scale(self, logits: torch.Tensor) -> torch.Tensor:
    """Apply the per-class affine rescaling to the logits.

    Args:
        logits (torch.Tensor): logits to be scaled.

    Returns:
        torch.Tensor: Scaled logits, ``temp_w * logits + temp_b``.
    """
    # Vector scaling: element-wise weight and bias, one pair per class.
    rescaled = logits.mul(self.temp_w)
    return rescaled.add(self.temp_b)

@property
Expand Down