diff --git a/.gitignore b/.gitignore index 659ceed3..9ad40cf4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,11 +6,10 @@ lightning_logs/ auto_tutorials_source/*.png docs/*/generated/ docs/*/auto_tutorials/ -*.pth -*.ckpt -*.out docs/source/sg_execution_times.rst -test +**/*.pth +**/*.ckpt +**/*.out **/*.csv pyrightconfig.json diff --git a/experiments/depth/kitti/bts.py b/experiments/depth/kitti/bts.py index 3fdfae82..d69b870e 100644 --- a/experiments/depth/kitti/bts.py +++ b/experiments/depth/kitti/bts.py @@ -1,16 +1,16 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch.optim.lr_scheduler import PolynomialLR from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.depth import BTSBaseline from torch_uncertainty.datamodules.depth import KITTIDataModule -from torch_uncertainty.utils.learning_rate import PolyLR class BTSCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.AdamW) - parser.add_lr_scheduler_args(PolyLR) + parser.add_lr_scheduler_args(PolynomialLR) def cli_main() -> BTSCLI: diff --git a/experiments/depth/nyu/bts.py b/experiments/depth/nyu/bts.py index d8aac63b..0e419abc 100644 --- a/experiments/depth/nyu/bts.py +++ b/experiments/depth/nyu/bts.py @@ -1,16 +1,16 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch.optim.lr_scheduler import PolynomialLR from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.depth import BTSBaseline from torch_uncertainty.datamodules.depth import NYUv2DataModule -from torch_uncertainty.utils.learning_rate import PolyLR class BTSCLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.AdamW) - parser.add_lr_scheduler_args(PolyLR) + parser.add_lr_scheduler_args(PolynomialLR) def cli_main() -> BTSCLI: diff --git a/experiments/segmentation/camvid/deeplab.py b/experiments/segmentation/camvid/deeplab.py index efebd51b..44f47e46 100644 --- a/experiments/segmentation/camvid/deeplab.py +++ b/experiments/segmentation/camvid/deeplab.py @@ -1,16 +1,16 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch.optim.lr_scheduler import PolynomialLR from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.segmentation import DeepLabBaseline from torch_uncertainty.datamodules.segmentation import CamVidDataModule -from torch_uncertainty.utils.learning_rate import PolyLR class DeepLabV3CLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.SGD) - parser.add_lr_scheduler_args(PolyLR) + parser.add_lr_scheduler_args(PolynomialLR) def cli_main() -> DeepLabV3CLI: diff --git a/experiments/segmentation/cityscapes/deeplab.py b/experiments/segmentation/cityscapes/deeplab.py index 0b074a74..cc865b9d 100644 --- a/experiments/segmentation/cityscapes/deeplab.py +++ b/experiments/segmentation/cityscapes/deeplab.py @@ -1,16 +1,16 @@ import torch from lightning.pytorch.cli import LightningArgumentParser +from torch.optim.lr_scheduler import PolynomialLR from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.segmentation import DeepLabBaseline from torch_uncertainty.datamodules.segmentation import CityscapesDataModule -from torch_uncertainty.utils.learning_rate import PolyLR class DeepLabV3CLI(TULightningCLI): def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_optimizer_args(torch.optim.SGD) - parser.add_lr_scheduler_args(PolyLR) + parser.add_lr_scheduler_args(PolynomialLR) def cli_main() -> DeepLabV3CLI: diff --git a/tests/losses/test_bayesian.py b/tests/losses/test_bayesian.py index afca2e53..678631fe 100644 --- a/tests/losses/test_bayesian.py +++ b/tests/losses/test_bayesian.py @@ -42,7 +42,7 @@ def test_prob_regression_training_step(self): ) inputs = torch.randn(1, 10) targets = torch.randn(1, 4) - routine.training_step((inputs, targets), 0) + routine.training_step((inputs, targets)) def test_training_step(self): model = BayesLinear(10, 4) @@ -62,7 +62,7 @@ def test_training_step(self): inputs = torch.randn(1, 10) targets = torch.randn(1, 4) - routine.training_step((inputs, targets), 0) + routine.training_step((inputs, targets)) def test_failures(self): model = BayesLinear(1, 1) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 2bed36d2..8fb208ab 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -61,6 +61,7 @@ def __init__( last_layer_dropout: bool = False, width_multiplier: float = 1.0, groups: int = 1, + conv_bias: bool = False, scale: float | None = None, alpha: int | None = None, gamma: int = 1, @@ -128,6 +129,8 @@ def __init__( last_layer_dropout (bool): whether to apply dropout to the last layer only. groups (int, optional): Number of groups in convolutions. Defaults to ``1``. + conv_bias (bool, optional): Whether to include bias in the convolutional + layers. Defaults to ``False``. scale (float, optional): Expansion factor affecting the width of the estimators. Only used if :attr:`version` is ``"masked"``. Defaults to ``None``. @@ -174,7 +177,7 @@ def __init__( """ params = { "arch": arch, - "conv_bias": False, + "conv_bias": conv_bias, "dropout_rate": dropout_rate, "groups": groups, "width_multiplier": width_multiplier, diff --git a/torch_uncertainty/datasets/classification/cub.py b/torch_uncertainty/datasets/classification/cub.py index 1dfa1f0b..079d20ec 100644 --- a/torch_uncertainty/datasets/classification/cub.py +++ b/torch_uncertainty/datasets/classification/cub.py @@ -3,6 +3,7 @@ from pathlib import Path import torch +from einops import rearrange from torch import Tensor from torchvision.datasets import ImageFolder from torchvision.datasets.utils import check_integrity, download_and_extract_archive @@ -20,6 +21,7 @@ def __init__( train: bool = True, transform: Callable | None = None, target_transform: Callable | None = None, + return_attributes: bool = False, download: bool = False, ): """The Caltech-UCSD Birds-200-2011 dataset. @@ -32,9 +34,12 @@ def __init__( returns a transformed version. E.g, transforms.RandomCrop. Defaults to None. target_transform (callable, optional): A function/transform that takes in the target and transforms it. Defaults to None. + return_attributes (bool, optional): If True, returns the attributes instead of the images. + Defaults to False. download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. - Defaults to + Defaults to False. + Reference: Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S. Caltech-UCSD Birds 200. @@ -52,20 +57,80 @@ def __init__( super().__init__(Path(root) / "CUB_200_2011" / "images", transform, target_transform) training_idx = self._load_train_idx() + self.attributes, self.uncertainties = self._load_attributes() + self.attribute_names = self._load_attribute_names() + self.classnames = self._load_classnames() + self.samples = [sample for i, sample in enumerate(self.samples) if training_idx[i] == train] self._labels = [label for i, label in enumerate(self.targets) if training_idx[i] == train] + self.attributes = rearrange( + torch.masked_select(self.attributes, training_idx.unsqueeze(-1) == train), + "(n c) -> n c", + c=312, + ) + self.uncertainties = rearrange( + torch.masked_select(self.uncertainties, training_idx.unsqueeze(-1) == train), + "(n c) -> n c", + c=312, + ) + + if return_attributes: + self.samples = zip(self.attributes, [sam[1] for sam in self.samples], strict=False) + self.loader = torch.nn.Identity() + + def _load_classnames(self) -> list[str]: + """Load the classnames of the dataset. + + Returns: + list[str]: the list containing the names of the 200 classes. + """ with Path(self.folder_root / "CUB_200_2011" / "classes.txt").open("r") as f: - self.class_names = [ + return [ line.split(" ")[1].split(".")[1].replace("\n", "").replace("_", " ") for line in f ] def _load_train_idx(self) -> Tensor: - is_training_img = [] + """Load the index of the training data to make the split. + + Returns: + Tensor: whether the images belong to the training or test split. + """ with (self.folder_root / "CUB_200_2011" / "train_test_split.txt").open("r") as f: - is_training_img = [int(line.split(" ")[1]) for line in f] - return torch.as_tensor(is_training_img) + return torch.as_tensor([int(line.split(" ")[1]) for line in f]) + + def _load_attributes(self) -> tuple[Tensor, Tensor]: + """Load the attributes associated to each image. + + Returns: + tuple[Tensor, Tensor]: The presence of the 312 attributes along with their uncertainty. + The uncertainty is 0 for certain samples and 1 for non-visible attributes. + """ + attributes, uncertainty = [], [] + with (self.folder_root / "CUB_200_2011" / "attributes" / "image_attribute_labels.txt").open( + "r" + ) as f: + for line in f: + attributes.append(int(line.split(" ")[2])) + uncertainty.append(1 - (int(line.split(" ")[3]) - 1) / 3) + return rearrange(torch.as_tensor(attributes), "(n c) -> n c", c=312), rearrange( + torch.as_tensor(uncertainty), "(n c) -> n c", c=312 + ) + + def _load_attribute_names(self) -> list[str]: + """Load the names of the attributes. + + Returns: + list[str]: The list of the names of the 312 attributes. + """ + with (self.folder_root / "attributes.txt").open("r") as f: + return [line.split(" ")[1].replace("\n", "").replace("_", " ") for line in f] def _check_integrity(self) -> bool: + """Check the integrity of the dataset. + + Returns: + bool: True when the md5 of the archive corresponds. + """ fpath = self.folder_root / self.filename return check_integrity( fpath, @@ -73,6 +138,7 @@ def _check_integrity(self) -> bool: ) def _download(self): + """Download the dataset from caltec.edu.""" if self._check_integrity(): logging.info("Files already downloaded and verified") return diff --git a/torch_uncertainty/datasets/regression/uci_regression.py b/torch_uncertainty/datasets/regression/uci_regression.py index 560d2136..a2c50bc3 100644 --- a/torch_uncertainty/datasets/regression/uci_regression.py +++ b/torch_uncertainty/datasets/regression/uci_regression.py @@ -117,17 +117,17 @@ class UCIRegression(Dataset): "4e6727f462779e2d396e8f7d2ddb79a3", ] urls = [ - "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/" "housing.data", - "https://archive.ics.uci.edu/static/public/165/concrete+compressive+" "strength.zip", + "https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data", + "https://archive.ics.uci.edu/static/public/165/concrete+compressive+strength.zip", "https://archive.ics.uci.edu/static/public/242/energy+efficiency.zip", - "https://archive.ics.uci.edu/static/public/374/appliances+energy+" "prediction.zip", - "https://www.openml.org/data/get_csv/3626/dataset_2175_kin8nm.arff", - "https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-" "plant/master/data.csv", - "https://archive.ics.uci.edu/static/public/294/combined+cycle+power+" "plant.zip", + "https://archive.ics.uci.edu/static/public/374/appliances+energy+prediction.zip", + "https://zenodo.org/records/14645866/files/kin8nm.csv", + "https://raw.githubusercontent.com/luishpinto/cm-naval-propulsion-plant/master/data.csv", + "https://archive.ics.uci.edu/static/public/294/combined+cycle+power+plant.zip", "https://archive.ics.uci.edu/static/public/265/physicochemical+" "properties+of+protein+tertiary+structure.zip", "https://archive.ics.uci.edu/static/public/186/wine+quality.zip", - "https://archive.ics.uci.edu/static/public/243/yacht+" "hydrodynamics.zip", + "https://archive.ics.uci.edu/static/public/243/yacht+hydrodynamics.zip", ] def __init__( diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 8d131fa8..bb616c58 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -178,6 +178,7 @@ def __init__( self.ood_logit_storage = None def _init_metrics(self) -> None: + """Initialize the metrics depending on the exact task.""" task = "binary" if self.binary_cls else "multiclass" metrics_dict = { @@ -264,6 +265,12 @@ def _init_metrics(self) -> None: self.test_grouping_loss = grouping_loss.clone(prefix="test/") def _init_mixup(self, mixup_params: dict | None) -> Callable: + """Setup the optional mixup augmentation based on the :attr:`mixup_params` dict. + + Args: + mixup_params (dict | None): the detailed parameters of the mixup augmentation. None if + unused. + """ if mixup_params is None: mixup_params = {} mixup_params = MIXUP_PARAMS | mixup_params @@ -312,6 +319,14 @@ def _init_mixup(self, mixup_params: dict | None) -> Callable: return Identity() def _apply_mixup(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: + """Apply the mixup augmentation on a :attr:`batch` of images. + + Args: + batch (tuple[Tensor, Tensor]): the images and the corresponding targets. + + Returns: + tuple[Tensor, Tensor]: the images and the corresponding targets transformed with mixup. + """ if not self.is_ensemble: if self.mixup_params["mixtype"] == "kernel_warping": if self.mixup_params["dist_sim"] == "emb": @@ -328,18 +343,28 @@ def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: + """Put the hyperparameters in tensorboard.""" if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, ) def on_validation_start(self) -> None: + """Prepare the validation step. + + Update the model's wrapper and the batchnorms if needed. + """ if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: + """Prepare the test step. + + Setup the post-processing dataset and fit the post-processing method if needed, prepares + the storage lists for logit plotting and update the batchnorms if needed. + """ if self.post_processing is not None: calibration_dataset = ( self.trainer.datamodule.val_dataloader().dataset @@ -357,11 +382,11 @@ def on_test_start(self) -> None: self.model.bn_update(self.trainer.train_dataloader, device=self.device) def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: - """Forward pass of the model. + """Forward pass of the inner model. Args: - inputs (Tensor): Input tensor. - save_feats (bool, optional): Whether to store the features or + inputs (Tensor): input tensor. + save_feats (bool, optional): whether to store the features or not. Defaults to ``False``. Note: @@ -378,7 +403,15 @@ def forward(self, inputs: Tensor, save_feats: bool = False) -> Tensor: logits = self.model(inputs) return logits - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT: + """Perform a single training step based on the input tensors. + + Args: + batch (tuple[Tensor, Tensor]): the training data and their corresponding targets + + Returns: + Tensor: the loss corresponding to this training step. + """ batch = self._apply_mixup(batch) inputs, target = self.format_batch_fn(batch) @@ -400,7 +433,14 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: + """Perform a single validation step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the validation batch. + + Args: + batch (tuple[Tensor, Tensor]): the validation data and their corresponding targets + """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) logits = rearrange(logits, "(m b) c -> b m c", b=targets.size(0)) @@ -422,6 +462,17 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: + """Perform a single test step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the test batch. Also + handle OOD and distribution-shifted images. + + Args: + batch (tuple[Tensor, Tensor]): the test data and their corresponding targets. + batch_idx (int): the number of the current batch (unused). + dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution and 2 if + distribution-shifted. + """ inputs, targets = batch logits = self.forward(inputs, save_feats=self.eval_grouping_loss) logits = rearrange(logits, "(n b) c -> b n c", b=targets.size(0)) @@ -500,6 +551,7 @@ def test_step( self.test_shift_ens_metrics.update(probs_per_est) def on_validation_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `validation_step`.""" res_dict = self.val_cls_metrics.compute() self.log_dict(res_dict, logger=True, sync_dist=True) self.log( @@ -516,6 +568,7 @@ def on_validation_epoch_end(self) -> None: self.val_grouping_loss.reset() def on_test_epoch_end(self) -> None: + """Compute, log, and plot the values of the collected metrics in `test_step`.""" # already logged result_dict = self.test_cls_metrics.compute() @@ -615,6 +668,11 @@ def on_test_epoch_end(self) -> None: self.save_results_to_csv(result_dict) def save_results_to_csv(self, results: dict[str, float]) -> None: + """Save the metric results in a csv. + + Args: + results (dict[str, float]): the dictionary containing all the values of the metrics. + """ if self.logger is not None: csv_writer( Path(self.logger.log_dir) / "results.csv", @@ -633,6 +691,19 @@ def _classification_routine_checks( post_processing: PostProcessing | None, format_batch_fn: nn.Module | None, ) -> None: + """Check the domains of the routine's parameters. + + Args: + model (nn.Module): the model used to make classification predictions. + num_classes (int): the number of classes in the dataset. + is_ensemble (bool): whether the model is an ensemble or a single model. + ood_criterion (str): the criterion for the binary OOD detection task. + eval_grouping_loss (bool): whether to evaluate the grouping loss. + num_calibration_bins (int): the number of bins for the evaluation of the calibration. + mixup_params (dict | None): the dictionary to setup the mixup augmentation. + post_processing (PostProcessing | None): the post-processing module. + format_batch_fn (nn.Module | None): the function for formatting the batch for ensembles. + """ if ood_criterion not in [ "msp", "logit", diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py index dc6f8a85..3fa450f2 100644 --- a/torch_uncertainty/routines/pixel_regression.py +++ b/torch_uncertainty/routines/pixel_regression.py @@ -108,7 +108,10 @@ def __init__( self.optim_recipe = optim_recipe self.format_batch_fn = format_batch_fn + self._init_metrics() + def _init_metrics(self) -> None: + """Initialize the metrics depending on the exact task.""" depth_metrics = MetricCollection( { "reg/SILog": SILog(), @@ -138,18 +141,27 @@ def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: + """Put the hyperparameters in tensorboard.""" if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, ) def on_validation_start(self) -> None: + """Prepare the validation step. + + Update the model's wrapper and the batchnorms if needed. + """ if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: + """Prepare the test step. + + Update the batchnorms if needed. + """ if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) @@ -174,7 +186,15 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = pred.squeeze(-1) return pred - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT: + """Perform a single training step based on the input tensors. + + Args: + batch (tuple[Tensor, Tensor]): the training data and their corresponding targets + + Returns: + Tensor: the loss corresponding to this training step. + """ inputs, target = self.format_batch_fn(batch) if self.one_dim_depth: target = target.unsqueeze(1) @@ -201,6 +221,14 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU return loss def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]: + """Get the prediction and handle predicted eventual distribution parameters. + + Args: + inputs (Tensor): the input data. + + Returns: + tuple[Tensor, Distribution | None]: the prediction as a Tensor and a distribution. + """ batch_size = inputs.size(0) preds = self.model(inputs) @@ -220,6 +248,15 @@ def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | Non return preds.mean(dim=1), None def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + """Perform a single validation step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the validation batch. + + Args: + batch (tuple[Tensor, Tensor]): the validation images and their corresponding targets. + batch_idx (int): the id of the batch. Optionally plot images and the predictions with + the first batch. + """ inputs, targets = batch if self.one_dim_depth: targets = targets.unsqueeze(1) @@ -245,6 +282,16 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: + """Perform a single test step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the test batch. Also + handle OOD and distribution-shifted images. + + Args: + batch (tuple[Tensor, Tensor]): the test data and their corresponding targets. + batch_idx (int): the number of the current batch (unused). + dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution. + """ if dataloader_idx != 0: raise NotImplementedError( "Depth OOD detection not implemented yet. Raise an issue " "if needed." @@ -272,6 +319,7 @@ def test_step( self.test_prob_metrics.update(dist, targets, padding_mask) def on_validation_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `validation_step`.""" res_dict = self.val_metrics.compute() self.log_dict(res_dict, logger=True, sync_dist=True) self.log( @@ -290,6 +338,7 @@ def on_validation_epoch_end(self) -> None: self.val_prob_metrics.reset() def on_test_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `test_step`.""" self.log_dict( self.test_metrics.compute(), sync_dist=True, @@ -354,6 +403,13 @@ def colorize( def _depth_routine_checks(output_dim: int, num_image_plot: int, log_plots: bool) -> None: + """Check the domains of the routine's parameters. + + Args: + output_dim (int): the dimension of the output of the regression task. + num_image_plot (int): the number of images to plot at evaluation time. + log_plots (bool): whether to plot images and predictions during evaluation. + """ if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") if num_image_plot < 1 and log_plots: diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 29bdae7c..07a43bb8 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -94,7 +94,11 @@ def __init__( self.optim_recipe = optim_recipe self.format_batch_fn = format_batch_fn + self.one_dim_regression = output_dim == 1 + self._init_metrics() + def _init_metrics(self) -> None: + """Initialize the metrics depending on the exact task.""" reg_metrics = MetricCollection( { "reg/MAE": MeanAbsoluteError(), @@ -112,28 +116,35 @@ def __init__( self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/") self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/") - self.one_dim_regression = output_dim == 1 - def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: + """Put the hyperparameters in tensorboard.""" if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, ) def on_validation_start(self) -> None: + """Prepare the validation step. + + Update the model's wrapper and the batchnorms if needed. + """ if self.needs_epoch_update and not self.trainer.sanity_checking: self.model.update_wrapper(self.current_epoch) if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) def on_test_start(self) -> None: + """Prepare the test step. + + Update the batchnorms if needed. + """ if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) - def forward(self, inputs: Tensor) -> Tensor | Distribution: + def forward(self, inputs: Tensor) -> Tensor | dict[str, Tensor]: """Forward pass of the routine. The forward pass automatically squeezes the output if the regression @@ -143,7 +154,8 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: inputs (Tensor): The input tensor. Returns: - Tensor: The output tensor. + Tensor | dict[str, Tensor]: The output tensor or the parameters of the output + distribution. """ pred = self.model(inputs) if self.probabilistic: @@ -164,7 +176,15 @@ def forward(self, inputs: Tensor) -> Tensor | Distribution: pred = pred.squeeze(-1) return pred - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: + def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT: + """Perform a single training step based on the input tensors. + + Args: + batch (tuple[Tensor, Tensor]): the training data and their corresponding targets + + Returns: + Tensor: the loss corresponding to this training step. + """ inputs, targets = self.format_batch_fn(batch) if self.one_dim_regression: @@ -190,6 +210,14 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU return loss def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | None]: + """Get the prediction and handle predicted eventual distribution parameters. + + Args: + inputs (Tensor): the input data. + + Returns: + tuple[Tensor, Distribution | None]: the prediction as a Tensor and a distribution. + """ batch_size = inputs.size(0) preds = self.model(inputs) @@ -208,7 +236,14 @@ def evaluation_forward(self, inputs: Tensor) -> tuple[Tensor, Distribution | Non preds = rearrange(preds, "(m b) c -> b m c", b=batch_size) return preds.mean(dim=1), None - def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: + """Perform a single validation step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the validation batch. + + Args: + batch (tuple[Tensor, Tensor]): the validation data and their corresponding targets. + """ inputs, targets = batch if self.one_dim_regression: targets = targets.unsqueeze(-1) @@ -224,6 +259,16 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: + """Perform a single test step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the test batch. Also + handle OOD and distribution-shifted images. + + Args: + batch (tuple[Tensor, Tensor]): the test data and their corresponding targets. + batch_idx (int): the number of the current batch (unused). + dataloader_idx (int): 0 if in-distribution, 1 if out-of-distribution. + """ if dataloader_idx != 0: raise NotImplementedError( "Regression OOD detection not implemented yet. Raise an issue " "if needed." @@ -239,6 +284,7 @@ def test_step( self.test_prob_metrics.update(dist, targets) def on_validation_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `validation_step`.""" res_dict = self.val_metrics.compute() self.log_dict(res_dict, logger=True, sync_dist=True) self.log( @@ -254,6 +300,7 @@ def on_validation_epoch_end(self) -> None: self.val_prob_metrics.reset() def on_test_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `test_step`.""" self.log_dict( self.test_metrics.compute(), ) @@ -267,5 +314,10 @@ def on_test_epoch_end(self) -> None: def _regression_routine_checks(output_dim: int) -> None: + """Check the domains of the routine's parameters. + + Args: + output_dim (int): the dimension of the output of the regression task. + """ if output_dim < 1: raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index 2ef10d73..ba08e367 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -70,9 +70,9 @@ def __init__( """ super().__init__() _segmentation_routine_checks( - num_classes, - metric_subsampling_rate, - num_calibration_bins, + num_classes=num_classes, + metric_subsampling_rate=metric_subsampling_rate, + num_calibration_bins=num_calibration_bins, ) if eval_shift: raise NotImplementedError( @@ -81,6 +81,7 @@ def __init__( self.model = model self.num_classes = num_classes + self.num_calibration_bins = num_calibration_bins self.loss = loss self.needs_epoch_update = isinstance(model, EPOCH_UPDATE_MODEL) self.needs_step_update = isinstance(model, STEP_UPDATE_MODEL) @@ -92,30 +93,38 @@ def __init__( self.format_batch_fn = format_batch_fn self.metric_subsampling_rate = metric_subsampling_rate self.log_plots = log_plots + self._init_metrics() - # metrics + if log_plots: + self.num_samples_to_plot = num_samples_to_plot + self.sample_buffer = [] + + def _init_metrics(self) -> None: + """Initialize the metrics depending on the exact task.""" seg_metrics = MetricCollection( { - "seg/mIoU": MeanIntersectionOverUnion(num_classes=num_classes), + "seg/mIoU": MeanIntersectionOverUnion(num_classes=self.num_classes), }, compute_groups=False, ) sbsmpl_seg_metrics = MetricCollection( { - "seg/mAcc": Accuracy(task="multiclass", average="macro", num_classes=num_classes), - "seg/Brier": BrierScore(num_classes=num_classes), + "seg/mAcc": Accuracy( + task="multiclass", average="macro", num_classes=self.num_classes + ), + "seg/Brier": BrierScore(num_classes=self.num_classes), "seg/NLL": CategoricalNLL(), - "seg/pixAcc": Accuracy(task="multiclass", num_classes=num_classes), + "seg/pixAcc": Accuracy(task="multiclass", num_classes=self.num_classes), "cal/ECE": CalibrationError( task="multiclass", - num_classes=num_classes, - num_bins=num_calibration_bins, + num_classes=self.num_classes, + num_bins=self.num_calibration_bins, ), "cal/aECE": CalibrationError( task="multiclass", adaptive=True, - num_bins=num_calibration_bins, - num_classes=num_classes, + num_classes=self.num_classes, + num_bins=self.num_calibration_bins, ), "sc/AURC": AURC(), "sc/AUGRC": AUGRC(), @@ -135,10 +144,6 @@ def __init__( self.test_seg_metrics = seg_metrics.clone(prefix="test/") self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") - if log_plots: - self.num_samples_to_plot = num_samples_to_plot - self.sample_buffer = [] - def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe @@ -146,7 +151,10 @@ def forward(self, inputs: Tensor) -> Tensor: """Forward pass of the model. Args: - inputs (torch.Tensor): Input tensor. + inputs (Tensor): input tensor. + + Returns: + Tensor: the prediction of the model. """ return self.model(inputs) @@ -164,9 +172,16 @@ def on_test_start(self) -> None: if hasattr(self.model, "need_bn_update"): self.model.bn_update(self.trainer.train_dataloader, device=self.device) - def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OUTPUT: - img, target = batch - img, target = self.format_batch_fn((img, target)) + def training_step(self, batch: tuple[Tensor, Tensor]) -> STEP_OUTPUT: + """Perform a single training step based on the input tensors. + + Args: + batch (tuple[Tensor, Tensor]): the training images and their corresponding targets + + Returns: + Tensor: the loss corresponding to this training step. + """ + img, target = self.format_batch_fn(batch) logits = self.forward(img) target = F.resize(target, logits.shape[-2:], interpolation=F.InterpolationMode.NEAREST) logits = rearrange(logits, "b c h w -> (b h w) c") @@ -178,7 +193,14 @@ def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> STEP_OU self.log("train_loss", loss, prog_bar=True, logger=True) return loss - def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + def validation_step(self, batch: tuple[Tensor, Tensor]) -> None: + """Perform a single validation step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the validation batch. + + Args: + batch (tuple[Tensor, Tensor]): the validation images and their corresponding targets + """ img, targets = batch logits = self.forward(img) targets = F.resize( @@ -195,7 +217,14 @@ def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: self.val_seg_metrics.update(probs, targets) self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) - def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: + def test_step(self, batch: tuple[Tensor, Tensor]) -> None: + """Perform a single test step based on the input tensors. + + Compute the prediction of the model and the value of the metrics on the test batch. + + Args: + batch (tuple[Tensor, Tensor]): the test images and their corresponding targets + """ img, targets = batch logits = self.forward(img) targets = F.resize( @@ -224,6 +253,7 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, targets)) def on_validation_epoch_end(self) -> None: + """Compute and log the values of the collected metrics in `validation_step`.""" res_dict = self.val_seg_metrics.compute() self.log_dict(res_dict, logger=True, sync_dist=True) self.log( @@ -237,6 +267,7 @@ def on_validation_epoch_end(self) -> None: self.val_sbsmpl_seg_metrics.reset() def on_test_epoch_end(self) -> None: + """Compute, log, and plot the values of the collected metrics in `test_step`.""" self.log_dict(self.test_seg_metrics.compute(), sync_dist=True) self.log_dict(self.test_sbsmpl_seg_metrics.compute(), sync_dist=True) if isinstance(self.logger, Logger) and self.log_plots: @@ -255,7 +286,7 @@ def on_test_epoch_end(self) -> None: self.log_segmentation_plots() def log_segmentation_plots(self) -> None: - """Builds and logs examples of segmentation plots from the test set.""" + """Build and log examples of segmentation plots from the test set.""" for i, (img, pred, tgt) in enumerate(self.sample_buffer): pred = pred == torch.arange(self.num_classes, device=pred.device)[:, None, None] tgt = tgt == torch.arange(self.num_classes, device=tgt.device)[:, None, None] @@ -278,6 +309,15 @@ def log_segmentation_plots(self) -> None: ) def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + """Select a random sample of the data to compute the loss onto. + + Args: + pred (Tensor): the prediction tensor. + target (Tensor): the target tensor. + + Returns: + Tuple[Tensor, Tensor]: the subsampled prediction and target tensors. + """ total_size = target.size(0) num_samples = max(1, int(total_size * self.metric_subsampling_rate)) indices = torch.randperm(total_size, device=pred.device)[:num_samples] @@ -289,6 +329,13 @@ def _segmentation_routine_checks( metric_subsampling_rate: float, num_calibration_bins: int, ) -> None: + """Check the domains of the routine's parameters. + + Args: + num_classes (int): the number of classes in the dataset. + metric_subsampling_rate (float): the rate of subsampling to compute the metrics. + num_calibration_bins (int): the number of bins for the evaluation of the calibration. + """ if num_classes < 2: raise ValueError(f"num_classes must be at least 2, got {num_classes}.") diff --git a/torch_uncertainty/utils/evaluation_loop.py b/torch_uncertainty/utils/evaluation_loop.py index d079cae4..f3088103 100644 --- a/torch_uncertainty/utils/evaluation_loop.py +++ b/torch_uncertainty/utils/evaluation_loop.py @@ -27,7 +27,7 @@ def _add_row(table: Table, metric_name: str, value: Tensor) -> None: if metric_name in PERCENTAGE_METRICS: value = value * 100 - table.add_row(metric_name, f"{value.item():.2f}%") + table.add_row(metric_name, f"{value.item():.3f}%") else: table.add_row(metric_name, f"{value.item():.5f}") diff --git a/torch_uncertainty/utils/learning_rate.py b/torch_uncertainty/utils/learning_rate.py deleted file mode 100644 index b30602ce..00000000 --- a/torch_uncertainty/utils/learning_rate.py +++ /dev/null @@ -1,29 +0,0 @@ -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler - - -class PolyLR(LRScheduler): - def __init__( - self, - optimizer: Optimizer, - total_iters: int, - power: float = 0.9, - last_epoch: int = -1, - min_lr: float = 1e-6, - ) -> None: - self.power = power - self.total_iters = total_iters - self.min_lr = min_lr - super().__init__(optimizer, last_epoch) - - def get_lr(self) -> list[float]: - return self._get_closed_form_lr() - - def _get_closed_form_lr(self) -> list[float]: - return [ - max( - base_lr * (1 - self.last_epoch / self.total_iters) ** self.power, - self.min_lr, - ) - for base_lr in self.base_lrs - ] diff --git a/torch_uncertainty/utils/to_hub_format.py b/torch_uncertainty/utils/to_hub_format.py index 1e1136d3..8db9bcea 100644 --- a/torch_uncertainty/utils/to_hub_format.py +++ b/torch_uncertainty/utils/to_hub_format.py @@ -24,7 +24,7 @@ raise ValueError("File does not exist") dtype = torch.float16 if args.fp16 else torch.float32 -model = torch.load(args.path)["state_dict"] +model = torch.load(args.path, weights_only=True)["state_dict"] model = {key.replace("model.", ""): val.to(device="cpu", dtype=dtype) for key, val in model.items()} output_name = args.name