From 8b4fca4da8a1b1bff2fbc2fe9131d019cdaf588e Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Wed, 2 Jul 2025 17:00:31 +0100 Subject: [PATCH 01/16] initial commit: Don't re-init buffer, inplace fill Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/components/base/memory_bank_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/models/components/base/memory_bank_module.py b/src/anomalib/models/components/base/memory_bank_module.py index 07eae880bc..bf3a0d6d82 100644 --- a/src/anomalib/models/components/base/memory_bank_module.py +++ b/src/anomalib/models/components/base/memory_bank_module.py @@ -73,7 +73,7 @@ def on_validation_start(self) -> None: """ if not self._is_fitted: self.fit() - self._is_fitted = torch.tensor([True], device=self.device) + self._is_fitted.data.fill_(value=True) def on_train_epoch_end(self) -> None: """Ensure memory bank is fitted after training. @@ -82,4 +82,4 @@ def on_train_epoch_end(self) -> None: """ if not self._is_fitted: self.fit() - self._is_fitted = torch.tensor([True], device=self.device) + self._is_fitted.data.fill_(value=True) From 0c603c375fd33b3c2f1ba1408210fcc9a07866ef Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Wed, 2 Jul 2025 20:22:26 +0100 Subject: [PATCH 02/16] implement patchcore memory savings Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- .../models/image/patchcore/lightning_model.py | 13 ++------ .../models/image/patchcore/torch_model.py | 32 +++++++++++++------ 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py index ac673b85b2..8180d8497f 100644 --- a/src/anomalib/models/image/patchcore/lightning_model.py +++ b/src/anomalib/models/image/patchcore/lightning_model.py @@ -159,7 +159,6 @@ def __init__( num_neighbors=num_neighbors, ) self.coreset_sampling_ratio = coreset_sampling_ratio - self.embeddings: list[torch.Tensor] = [] @classmethod def configure_pre_processor( @@ -235,9 +234,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: ``fit()``. """ del args, kwargs # These variables are not used. - - embedding = self.model(batch.image) - self.embeddings.append(embedding) + _ = self.model(batch.image) # Return a dummy loss tensor return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: """Apply subsampling to the embedding collected from the training set. This method: - 1. Aggregates embeddings from all training batches - 2. Applies coreset subsampling to reduce memory requirements + 1. Applies coreset subsampling to reduce memory requirements """ - logger.info("Aggregating the embedding extracted from the training set.") - embeddings = torch.vstack(self.embeddings) logger.info("Applying core-set subsampling to get the embedding.") - self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio) + self.model.subsample_embedding(self.coreset_sampling_ratio) def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Generate predictions for a batch of images. 
diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 2a777a81ab..2ee7c86779 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -123,9 +123,7 @@ def __init__( ).eval() self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator() - - self.register_buffer("memory_bank", torch.Tensor()) - self.memory_bank: torch.Tensor + self.register_buffer("memory_bank", torch.empty(0, self.feature_extractor.out_dims[1])) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Process input tensor through the model. @@ -169,7 +167,19 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embedding = self.reshape_embedding(embedding) if self.training: + # Grow memory bank while preserving buffer registration + if self.memory_bank.numel() > 0: + new_bank = torch.cat((self.memory_bank, embedding), dim=0) + self.memory_bank.data = new_bank + else: + self.memory_bank.data = embedding return embedding + + # Ensure memory bank is not empty + if self.memory_bank.size(0) == 0: + msg = "Memory bank is empty. Cannot perform coreset selection." + raise ValueError(msg) + # apply nearest neighbor search patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1) # reshape to batch dimension @@ -239,26 +249,28 @@ def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor: embedding_size = embedding.size(1) return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) - def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None: - """Subsample embeddings using coreset selection. + def subsample_embedding(self, sampling_ratio: float) -> None: + """Subsample the memory_banks embeddings using coreset selection. Uses k-center-greedy coreset subsampling to select a representative subset of patch embeddings to store in the memory bank. Args: - embedding (torch.Tensor): Embedding tensor to subsample from. sampling_ratio (float): Fraction of embeddings to keep, in range (0,1]. Example: - >>> embedding = torch.randn(1000, 512) - >>> model.subsample_embedding(embedding, sampling_ratio=0.1) + >>> model.memory_bank = torch.randn(1000, 512) + >>> model.subsample_embedding(sampling_ratio=0.1) >>> model.memory_bank.shape torch.Size([100, 512]) """ + if self.memory_bank.size(0) == 0: + msg = "Memory bank is empty. Cannot perform coreset selection." 
+ raise ValueError(msg) # Coreset Subsampling - sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) + sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) coreset = sampler.sample_coreset() - self.memory_bank = coreset + self.memory_bank.data = coreset.clone() @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From 73fd2eebcc3d2a1a19b751a894186c346837b5bf Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:26:19 +0100 Subject: [PATCH 03/16] deprecation warning for internal subsample_embedding of patchcore Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/patchcore/torch_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 2ee7c86779..12ff9d9a3d 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -33,6 +33,7 @@ # Copyright (C) 2022-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import warnings from collections.abc import Sequence from typing import TYPE_CHECKING @@ -249,7 +250,7 @@ def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor: embedding_size = embedding.size(1) return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) - def subsample_embedding(self, sampling_ratio: float) -> None: + def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = None) -> None: """Subsample the memory_banks embeddings using coreset selection. Uses k-center-greedy coreset subsampling to select a representative subset of patch embeddings to store in the memory bank. Args: sampling_ratio (float): Fraction of embeddings to keep, in range (0,1]. + embeddings (torch.Tensor): **[DEPRECATED]** + This argument is deprecated and will be removed in a future release. + Use the default behavior (i.e., rely on `self.memory_bank`) instead. Example: >>> model.memory_bank = torch.randn(1000, 512) >>> model.subsample_embedding(sampling_ratio=0.1) >>> model.memory_bank.shape torch.Size([100, 512]) """ + if embeddings is not None: + warnings.warn( + "The 'embeddings' argument is deprecated and will be removed in a future release. " + "Please use the default memory bank instead.", + DeprecationWarning, + stacklevel=2, + ) if self.memory_bank.size(0) == 0: msg = "Memory bank is empty. Cannot perform coreset selection." 
raise ValueError(msg) From f2dca3141846359810f0c05359e89786d72a8885 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 3 Jul 2025 14:33:55 +0100 Subject: [PATCH 04/16] PaDiM gpu memory usage reduction Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- .../models/image/padim/lightning_model.py | 11 ++----- .../models/image/padim/torch_model.py | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index c8f24b0b26..c890992f2e 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -131,9 +131,6 @@ def __init__( n_features=n_features, ) - self.stats: list[torch.Tensor] = [] - self.embeddings: list[torch.Tensor] = [] - @staticmethod def configure_optimizers() -> None: """PADIM doesn't require optimization, therefore returns no optimizers.""" @@ -154,19 +151,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: """ del args, kwargs # These variables are not used. - embedding = self.model(batch.image) - self.embeddings.append(embedding) + _ = self.model(batch.image) # Return a dummy loss tensor return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: """Fit a Gaussian to the embedding collected from the training set.""" - logger.info("Aggregating the embedding extracted from the training set.") - embeddings = torch.vstack(self.embeddings) - logger.info("Fitting a Gaussian to the embedding collected from the training set.") - self.stats = self.model.gaussian.fit(embeddings) + self.model.fit_gaussian() def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a validation step of PADIM. diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py index 317299bdc6..3cc07bfefd 100644 --- a/src/anomalib/models/image/padim/torch_model.py +++ b/src/anomalib/models/image/padim/torch_model.py @@ -147,6 +147,7 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator() self.gaussian = MultiVariateGaussian() + self.memory_bank = torch.empty(0, self.feature_extractor.out_dims[1]) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Forward-pass image-batch (N, C, H, W) into model to extract features. @@ -182,6 +183,12 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embeddings = self.tiler.untile(embeddings) if self.training: + # Grow memory bank + if self.memory_bank.numel() > 0: + new_bank = torch.cat((self.memory_bank, embeddings), dim=0) + self.memory_bank.data = new_bank + else: + self.memory_bank.data = embeddings return embeddings anomaly_map = self.anomaly_map_generator( @@ -217,3 +224,25 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor: # subsample embeddings idx = self.idx.to(embeddings.device) return torch.index_select(embeddings, 1, idx) + + def fit_gaussian(self) -> None: + """Fits a Gaussian model to the current contents of the memory bank. + + This method is typically called after the memory bank has been filled during training. + It uses the stored features to fit a Gaussian distribution, which may be used for + downstream tasks such as anomaly detection, coreset selection, or probabilistic modeling. 
+ + After fitting, the memory bank is cleared to free GPU memory before validation or testing. + + Raises: + ValueError: If the memory bank is empty. + """ + if self.memory_bank.size(0) == 0: + msg = "Memory bank is empty. Cannot fit Gaussian." + raise ValueError(msg) + + # fit gaussian + self.gaussian.fit(self.memory_bank) + + # clear memory bank, reduces gpu size + self.memory_bank = torch.empty(0, self.feature_extractor.out_dims[1]) From 0275ce9616c3196186cc1b70e62f6c6bdf8299f1 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:22:29 +0100 Subject: [PATCH 05/16] refactor memory bank tensor assignment for readability. dfkde refactor Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- .../models/image/dfkde/lightning_model.py | 9 ++---- .../models/image/dfkde/torch_model.py | 28 +++++++++++++++++++ .../models/image/padim/torch_model.py | 17 ++++++----- .../models/image/patchcore/torch_model.py | 18 ++++++------ 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py index dc8591eb54..2f4bd1aeea 100644 --- a/src/anomalib/models/image/dfkde/lightning_model.py +++ b/src/anomalib/models/image/dfkde/lightning_model.py @@ -113,8 +113,6 @@ def __init__( max_training_points=max_training_points, ) - self.embeddings: list[torch.Tensor] = [] @staticmethod def configure_optimizers() -> None: # pylint: disable=arguments-differ """DFKDE doesn't require optimization, therefore returns no optimizers.""" @@ -133,18 +131,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: """ del args, kwargs # These variables are not used. - embedding = self.model(batch.image) - self.embeddings.append(embedding) + _ = self.model(batch.image) # Return a dummy loss tensor return torch.tensor(0.0, requires_grad=True, device=self.device) def fit(self) -> None: """Fit KDE model to collected embeddings from the training set.""" - embeddings = torch.vstack(self.embeddings) - logger.info("Fitting a KDE model to the embedding collected from the training set.") - self.model.classifier.fit(embeddings) + self.model.fit_classifier() def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform validation by computing anomaly scores. diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py index bb7f9c28fe..b637a26f0c 100644 --- a/src/anomalib/models/image/dfkde/torch_model.py +++ b/src/anomalib/models/image/dfkde/torch_model.py @@ -89,6 +89,8 @@ def __init__( feature_scaling_method=feature_scaling_method, max_training_points=max_training_points, ) + self.embedding_dim = sum(self.feature_extractor.out_dims) + self.memory_bank = torch.empty(0, self.embedding_dim) def get_features(self, batch: torch.Tensor) -> torch.Tensor: """Extract features from the pre-trained backbone network. @@ -141,8 +143,34 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: # 1. apply feature extraction features = self.get_features(batch) if self.training: + with torch.no_grad(): + if self.memory_bank.numel() > 0: + self.memory_bank = torch.cat((self.memory_bank, features), dim=0) + else: + self.memory_bank = features return features # 2. 
apply density estimation scores = self.classifier(features) return InferenceBatch(pred_score=scores) + + def fit_classifier(self) -> None: + """Fits the classifier using the current contents of the memory bank. + + This method is typically called after the memory bank has been populated + during training. + + After fitting, the memory bank is cleared to reduce GPU memory usage. + + Raises: + ValueError: If the memory bank is empty. + """ + if self.memory_bank.size(0) == 0: + msg = "Memory bank is empty. Cannot fit classifier." + raise ValueError(msg) + + # fit classifier + self.classifier.fit(self.memory_bank) + + # clear memory bank, reduces gpu size + self.memory_bank = torch.empty(0, self.embedding_dim) diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py index 3cc07bfefd..5bc443dea8 100644 --- a/src/anomalib/models/image/padim/torch_model.py +++ b/src/anomalib/models/image/padim/torch_model.py @@ -147,7 +147,8 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator() self.gaussian = MultiVariateGaussian() - self.memory_bank = torch.empty(0, self.feature_extractor.out_dims[1]) + self.embedding_dim = sum(self.feature_extractor.out_dims) + self.memory_bank = torch.empty(0, self.embedding_dim) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Forward-pass image-batch (N, C, H, W) into model to extract features. @@ -184,12 +185,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: if self.training: # Grow memory bank - if self.memory_bank.numel() > 0: - new_bank = torch.cat((self.memory_bank, embeddings), dim=0) - self.memory_bank.data = new_bank - else: - self.memory_bank.data = embeddings + with torch.no_grad(): + if self.memory_bank.numel() > 0: + self.memory_bank = torch.cat((self.memory_bank, embeddings), dim=0) + else: + self.memory_bank = embeddings return embeddings anomaly_map = self.anomaly_map_generator( @@ -229,8 +230,6 @@ def fit_gaussian(self) -> None: """Fits a Gaussian model to the current contents of the memory bank. This method is typically called after the memory bank has been filled during training. - It uses the stored features to fit a Gaussian distribution, which may be used for - downstream tasks such as anomaly detection, coreset selection, or probabilistic modeling. After fitting, the memory bank is cleared to free GPU memory before validation or testing. Raises: ValueError: If the memory bank is empty. """ if self.memory_bank.size(0) == 0: msg = "Memory bank is empty. Cannot fit Gaussian." raise ValueError(msg) # fit gaussian self.gaussian.fit(self.memory_bank) # clear memory bank, reduces gpu size - self.memory_bank = torch.empty(0, self.feature_extractor.out_dims[1]) + self.memory_bank = torch.empty(0, self.embedding_dim) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 12ff9d9a3d..5a52086c14 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -124,7 +124,9 @@ def __init__( ).eval() self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator() - self.register_buffer("memory_bank", torch.empty(0, self.feature_extractor.out_dims[1])) + embedding_dim = sum(self.feature_extractor.out_dims) + self.memory_bank: torch.Tensor + self.register_buffer("memory_bank", torch.empty(0, embedding_dim)) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Process input tensor through the model. 
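The accumulate-then-fit life cycle introduced across these patches is the same for every model touched: forward() appends per-batch features to a memory_bank tensor while training, and fit() consumes the bank and then empties it. A minimal self-contained sketch of that pattern (illustrative only; DensityModel and the fixed 128-dim feature size are hypothetical stand-ins, not anomalib API):

import torch
from torch import nn


class DensityModel(nn.Module):
    """Sketch of the memory-bank life cycle used by the models in this PR."""

    def __init__(self) -> None:
        super().__init__()
        self.memory_bank = torch.empty(0, 128)  # grows batch-by-batch while training

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        if self.training:
            # accumulate on the model instead of keeping a Python list in the Lightning module
            if self.memory_bank.numel() > 0:
                self.memory_bank = torch.cat((self.memory_bank, features), dim=0)
            else:
                self.memory_bank = features
        return features

    def fit(self) -> None:
        if self.memory_bank.size(0) == 0:
            msg = "Memory bank is empty. Cannot fit."
            raise ValueError(msg)
        # ... fit a Gaussian / KDE / coreset on self.memory_bank here ...
        self.memory_bank = torch.empty(0, 128)  # clear the bank to free GPU memory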
@@ -168,12 +170,12 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embedding = self.reshape_embedding(embedding) if self.training: - # Grow memory bank while preserving buffer registration - if self.memory_bank.numel() > 0: - new_bank = torch.cat((self.memory_bank, embedding), dim=0) - self.memory_bank.data = new_bank - else: - self.memory_bank.data = embedding + with torch.no_grad(): + if self.memory_bank.numel() > 0: + new_bank = torch.cat((self.memory_bank, embedding), dim=0) + self.memory_bank.resize_(new_bank.size()).copy_(new_bank) + else: + self.memory_bank.resize_(embedding.size()).copy_(embedding) return embedding # Ensure memory bank is not empty @@ -281,7 +283,7 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = # Coreset Subsampling sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) coreset = sampler.sample_coreset() - self.memory_bank.data = coreset.clone() + self.memory_bank = coreset @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From a0bdbf2ffe96a7604f9a520b59879c04a2bd5281 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:49:40 +0100 Subject: [PATCH 06/16] unify memory bank models trainer arguments. refactor dfm to fit new memory bank framework Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- .../models/image/dfkde/lightning_model.py | 8 ++- .../models/image/dfm/lightning_model.py | 20 +++---- src/anomalib/models/image/dfm/torch_model.py | 59 ++++++++++++++----- .../models/image/padim/lightning_model.py | 14 ++--- .../models/image/patchcore/lightning_model.py | 3 +- 5 files changed, 67 insertions(+), 37 deletions(-) diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py index 2f4bd1aeea..0c03415c49 100644 --- a/src/anomalib/models/image/dfkde/lightning_model.py +++ b/src/anomalib/models/image/dfkde/lightning_model.py @@ -162,9 +162,13 @@ def trainer_arguments(self) -> dict[str, Any]: """Get DFKDE-specific trainer arguments. Returns: - dict[str, Any]: Dictionary of trainer arguments. + dict[str, Any]: Trainer arguments + - ``gradient_clip_val``: ``0`` (no gradient clipping needed) + - ``max_epochs``: ``1`` (single pass through training data) + - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) + - ``devices``: ``1`` (only single gpu supported) """ - return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1} @property def learning_type(self) -> LearningType: diff --git a/src/anomalib/models/image/dfm/lightning_model.py b/src/anomalib/models/image/dfm/lightning_model.py index bd0038dd77..9e128e00db 100644 --- a/src/anomalib/models/image/dfm/lightning_model.py +++ b/src/anomalib/models/image/dfm/lightning_model.py @@ -112,7 +112,6 @@ def __init__( n_comps=pca_level, score_type=score_type, ) - self.embeddings: list[torch.Tensor] = [] self.score_type = score_type @staticmethod @@ -137,8 +136,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: """ del args, kwargs # These variables are not used. 
- embedding = self.model.get_features(batch.image).squeeze() - self.embeddings.append(embedding) + _ = self.model(batch.image) # Return a dummy loss tensor return torch.tensor(0.0, requires_grad=True, device=self.device) @@ -149,11 +147,8 @@ def fit(self) -> None: The method aggregates embeddings collected during training and fits both the PCA transformation and Gaussian model used for scoring. """ - logger.info("Aggregating the embedding extracted from the training set.") - embeddings = torch.vstack(self.embeddings) - logger.info("Fitting a PCA and a Gaussian model to dataset.") - self.model.fit(embeddings) + self.model.fit() def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Compute predictions for the input batch during validation. @@ -176,12 +171,13 @@ def trainer_arguments(self) -> dict[str, Any]: """Get DFM-specific trainer arguments. Returns: - dict[str, Any]: Dictionary of trainer arguments: - - ``gradient_clip_val`` (int): Disable gradient clipping - - ``max_epochs`` (int): Train for one epoch only - - ``num_sanity_val_steps`` (int): Skip validation sanity checks + dict[str, Any]: Trainer arguments + - ``gradient_clip_val``: ``0`` (no gradient clipping needed) + - ``max_epochs``: ``1`` (single pass through training data) + - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) + - ``devices``: ``1`` (only single gpu supported) """ - return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1} @property def learning_type(self) -> LearningType: diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index 3e9aa8228a..664a28bae3 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -26,6 +26,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import warnings import torch from torch import nn @@ -153,18 +154,37 @@ def __init__( layers=[layer], ).eval() - def fit(self, dataset: torch.Tensor) -> None: + self.embedding_dim = sum(self.feature_extractor.out_dims) + self.memory_bank = torch.empty(0, self.embedding_dim) + + def fit(self, dataset: torch.Tensor = None) -> None: """Fit PCA and Gaussian model to dataset. Args: - dataset (torch.Tensor): Input dataset with shape - ``(n_samples, n_features)``. + dataset (torch.Tensor, optional): **[DEPRECATED]** + Input dataset with shape ``(n_samples, n_features)``. + This argument is deprecated and will be ignored. Instead, + the method now uses `self.memory_bank` as the input dataset. """ - self.pca_model.fit(dataset) + if dataset is not None: + warnings.warn( + "The 'dataset' argument is deprecated and will be removed in a future release. " + "The method now uses 'self.memory_bank' as the input dataset.", + DeprecationWarning, + stacklevel=2, + ) + + # Use self.memory_bank as the dataset + dataset_to_use = self.memory_bank + + self.pca_model.fit(dataset_to_use) if self.score_type == "nll": - features_reduced = self.pca_model.transform(dataset) + features_reduced = self.pca_model.transform(dataset_to_use) self.gaussian_model.fit(features_reduced.T) + # clear memory bank, reduces GPU size + self.memory_bank = torch.empty(0, self.embedding_dim) + def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: """Compute anomaly scores. @@ -202,17 +222,16 @@ def get_features(self, batch: torch.Tensor) -> torch.Tensor: ``(batch_size, channels, height, width)``. 
Returns: - Union[torch.Tensor, Tuple[torch.Tensor, torch.Size]]: Features during - training, or tuple of (features, feature_shapes) during inference. + tuple of (features, feature_shapes). """ - self.feature_extractor.eval() - features = self.feature_extractor(batch)[self.layer] - batch_size = len(features) - if self.pooling_kernel_size > 1: - features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size) - feature_shapes = features.shape - features = features.view(batch_size, -1).detach() - return features if self.training else (features, feature_shapes) + with torch.no_grad(): + features = self.feature_extractor(batch)[self.layer] + batch_size = len(features) + if self.pooling_kernel_size > 1: + features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size) + feature_shapes = features.shape + features = features.view(batch_size, -1) + return features, feature_shapes def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: """Compute anomaly predictions from input images. @@ -227,6 +246,16 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: ``InferenceBatch`` with prediction scores and anomaly maps. """ feature_vector, feature_shapes = self.get_features(batch) + + if self.training: + # Grow memory bank + with torch.no_grad(): + if self.memory_bank.numel() > 0: + self.memory_bank = torch.cat((self.memory_bank, feature_vector), dim=0) + else: + self.memory_bank = feature_vector + return feature_vector + pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes) if anomaly_map is not None: anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False) diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index c890992f2e..57291dc8be 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -183,16 +183,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: @property def trainer_arguments(self) -> dict[str, int | float]: - """Return PADIM trainer arguments. - - Since the model does not require training, we limit the max_epochs to 1. - Since we need to run training epoch before validation, we also set the - sanity steps to 0. + """Get default trainer arguments for Padim. 
Returns: dict[str, Any]: Trainer arguments - ``max_epochs``: ``1`` (single pass through training data) - ``val_check_interval``: ``1.0`` (run validation once per training epoch) - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) - ``devices``: ``1`` (only single gpu supported) """ - return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0} + return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 0} @property def learning_type(self) -> LearningType: diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py index 8180d8497f..eec92feea5 100644 --- a/src/anomalib/models/image/patchcore/lightning_model.py +++ b/src/anomalib/models/image/patchcore/lightning_model.py @@ -279,8 +279,9 @@ def trainer_arguments(self) -> dict[str, Any]: - ``gradient_clip_val``: ``0`` (no gradient clipping needed) - ``max_epochs``: ``1`` (single pass through training data) - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) + - ``devices``: ``1`` (only single gpu supported) """ - return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0} + return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1} @property def learning_type(self) -> LearningType: From a3867328469936f8c7d9fbc8555651bb81a90422 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 3 Jul 2025 16:00:34 +0100 Subject: [PATCH 07/16] bugfix: padim device count set to zero Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/padim/lightning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py index 57291dc8be..0ad220b2c1 100644 --- a/src/anomalib/models/image/padim/lightning_model.py +++ b/src/anomalib/models/image/padim/lightning_model.py @@ -192,7 +192,7 @@ def trainer_arguments(self) -> dict[str, int | float]: - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks) - ``devices``: ``1`` (only single gpu supported) """ - return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 0} + return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 1} @property def learning_type(self) -> LearningType: From 572dc5e19dc8de541b3820c66ce7b28976482d9c Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Fri, 4 Jul 2025 15:34:15 +0100 Subject: [PATCH 08/16] ensure buffer is not replaced but instead resized and filled Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/patchcore/torch_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 5a52086c14..3024a09e9a 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -283,7 +283,7 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = # Coreset Subsampling sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) coreset = sampler.sample_coreset() - self.memory_bank = coreset + 
self.memory_bank.resize_(coreset.size()).copy_(coreset) @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From c714b85642cbdabb7b7c1b958134c2e60cc68379 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:09:22 +0100 Subject: [PATCH 09/16] give memory bank type Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- .../models/image/dfkde/torch_model.py | 15 ++++++------ src/anomalib/models/image/dfm/torch_model.py | 23 ++++++++----------- .../models/image/padim/torch_model.py | 18 +++++++-------- .../models/image/patchcore/torch_model.py | 20 ++++++++-------- 4 files changed, 33 insertions(+), 43 deletions(-) diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py index b637a26f0c..8ff4aa1ec0 100644 --- a/src/anomalib/models/image/dfkde/torch_model.py +++ b/src/anomalib/models/image/dfkde/torch_model.py @@ -89,8 +89,7 @@ def __init__( feature_scaling_method=feature_scaling_method, max_training_points=max_training_points, ) - self.embedding_dim = sum(self.feature_extractor.out_dims) - self.memory_bank = torch.empty(0, self.embedding_dim) + self.memory_bank = torch.Tensor() def get_features(self, batch: torch.Tensor) -> torch.Tensor: """Extract features from the pre-trained backbone network. @@ -143,11 +142,11 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: # 1. apply feature extraction features = self.get_features(batch) if self.training: - with torch.no_grad(): - if self.memory_bank.numel() > 0: - self.memory_bank = torch.cat((self.memory_bank, features), dim=0) - else: - self.memory_bank = features + if self.memory_bank.size(0) == 0: + self.memory_bank = features + else: + new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank) + self.memory_bank = new_bank return features # 2. apply density estimation @@ -173,4 +172,4 @@ def fit_classifier(self) -> None: self.classifier.fit(self.memory_bank) # clear memory bank, reduces gpu size - self.memory_bank = torch.empty(0, self.embedding_dim) + self.memory_bank = torch.Tensor().to(self.memory_bank) diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index 664a28bae3..379256c5b7 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -154,8 +154,7 @@ def __init__( layers=[layer], ).eval() - self.embedding_dim = sum(self.feature_extractor.out_dims) - self.memory_bank = torch.empty(0, self.embedding_dim) + self.memory_bank = torch.Tensor() def fit(self, dataset: torch.Tensor = None) -> None: """Fit PCA and Gaussian model to dataset. Args: dataset (torch.Tensor, optional): **[DEPRECATED]** Input dataset with shape ``(n_samples, n_features)``. This argument is deprecated and will be ignored. Instead, the method now uses `self.memory_bank` as the input dataset. """ if dataset is not None: warnings.warn( "The 'dataset' argument is deprecated and will be removed in a future release. " "The method now uses 'self.memory_bank' as the input dataset.", DeprecationWarning, stacklevel=2, ) - # Use self.memory_bank as the dataset - dataset_to_use = self.memory_bank - - self.pca_model.fit(dataset_to_use) + self.pca_model.fit(self.memory_bank) if self.score_type == "nll": - features_reduced = self.pca_model.transform(dataset_to_use) + features_reduced = self.pca_model.transform(self.memory_bank) self.gaussian_model.fit(features_reduced.T) # clear memory bank, reduces GPU size - self.memory_bank = torch.empty(0, self.embedding_dim) + self.memory_bank = torch.Tensor().to(self.memory_bank) def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: """Compute anomaly scores. 
@@ -248,12 +244,11 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: feature_vector, feature_shapes = self.get_features(batch) if self.training: - # Grow memory bank - with torch.no_grad(): - if self.memory_bank.numel() > 0: - self.memory_bank = torch.cat((self.memory_bank, feature_vector), dim=0) - else: - self.memory_bank = feature_vector + if self.memory_bank.size(0) == 0: + self.memory_bank = feature_vector + else: + new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank) + self.memory_bank = new_bank return feature_vector pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes) diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py index 5bc443dea8..6f88a57a09 100644 --- a/src/anomalib/models/image/padim/torch_model.py +++ b/src/anomalib/models/image/padim/torch_model.py @@ -147,8 +147,7 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator() self.gaussian = MultiVariateGaussian() - self.embedding_dim = sum(self.feature_extractor.out_dims) - self.memory_bank = torch.empty(0, self.embedding_dim) + self.memory_bank = torch.Tensor() def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Forward-pass image-batch (N, C, H, W) into model to extract features. @@ -184,12 +183,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embeddings = self.tiler.untile(embeddings) if self.training: - if self.memory_bank.size(0) == 0: - self.memory_bank = embeddings - else: - new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank) - self.memory_bank = new_bank + if self.memory_bank.size(0) == 0: + self.memory_bank = embeddings + else: + new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank) + self.memory_bank = new_bank return embeddings anomaly_map = self.anomaly_map_generator( @@ -243,5 +241,5 @@ def fit_gaussian(self) -> None: # fit gaussian self.gaussian.fit(self.memory_bank) - # clear memory bank, reduces gpu size - self.memory_bank = torch.empty(0, self.embedding_dim) + # clear memory bank, reduces gpu usage + self.memory_bank = torch.Tensor().to(self.memory_bank) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 3024a09e9a..8a65f01abc 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -124,9 +124,8 @@ def __init__( ).eval() self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator() - embedding_dim = sum(self.feature_extractor.out_dims) self.memory_bank: torch.Tensor - self.register_buffer("memory_bank", torch.empty(0, embedding_dim)) + self.register_buffer("memory_bank", torch.Tensor()) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Process input tensor through the model. 
@@ -170,17 +169,16 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embedding = self.reshape_embedding(embedding) if self.training: - with torch.no_grad(): - if self.memory_bank.numel() > 0: - new_bank = torch.cat((self.memory_bank, embedding), dim=0) - self.memory_bank.resize_(new_bank.size()).copy_(new_bank) - else: - self.memory_bank.resize_(embedding.size()).copy_(embedding) + if self.memory_bank.size(0) == 0: + self.memory_bank = embedding + else: + new_bank = torch.cat((self.memory_bank, embedding), dim=0).to(self.memory_bank) + self.memory_bank = new_bank return embedding # Ensure memory bank is not empty if self.memory_bank.size(0) == 0: - msg = "Memory bank is empty. Cannot perform coreset selection." + msg = "Memory bank is empty. Cannot provide anomaly scores" raise ValueError(msg) # apply nearest neighbor search @@ -282,8 +280,8 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = raise ValueError(msg) # Coreset Subsampling sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) - coreset = sampler.sample_coreset() - self.memory_bank.resize_(coreset.size()).copy_(coreset) + coreset = sampler.sample_coreset().type_as(self.memory_bank) + self.memory_bank = coreset @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From 16fabb31e8cde7e69be153dc8e718d058eaa86ef Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:13:09 +0100 Subject: [PATCH 10/16] revert memory bank mixin back Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/components/base/memory_bank_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/models/components/base/memory_bank_module.py b/src/anomalib/models/components/base/memory_bank_module.py index bf3a0d6d82..07eae880bc 100644 --- a/src/anomalib/models/components/base/memory_bank_module.py +++ b/src/anomalib/models/components/base/memory_bank_module.py @@ -73,7 +73,7 @@ def on_validation_start(self) -> None: """ if not self._is_fitted: self.fit() - self._is_fitted.data.fill_(value=True) + self._is_fitted = torch.tensor([True], device=self.device) def on_train_epoch_end(self) -> None: """Ensure memory bank is fitted after training. 
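The reverts in PATCH 09 and PATCH 10 — dropping resize_()/copy_() and .data.fill_() in favor of plain reassignment — rest on a PyTorch detail: once a name has been registered with register_buffer, nn.Module.__setattr__ routes any later tensor assignment to that name back into the module's buffer registry, so reassignment does not demote the buffer to a plain attribute. A standalone check of that behavior (BankModule is an illustrative stand-in, not anomalib code):

import torch
from torch import nn


class BankModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # registered buffers are saved in state_dict and follow .to(device)/.cuda()
        self.register_buffer("memory_bank", torch.Tensor())


module = BankModule()
module.memory_bank = torch.randn(10, 4)  # plain reassignment, no in-place resize
assert "memory_bank" in dict(module.named_buffers())  # still a registered buffer
assert module.state_dict()["memory_bank"].shape == (10, 4)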
@@ -82,4 +82,4 @@ def on_train_epoch_end(self) -> None: """ if not self._is_fitted: self.fit() - self._is_fitted.data.fill_(value=True) + self._is_fitted = torch.tensor([True], device=self.device) From b292713160a286c8ed0414b35a294c88f192e4ce Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:54:47 +0100 Subject: [PATCH 11/16] to memory bank type and device Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/patchcore/torch_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 8a65f01abc..0f5a1c742e 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -280,7 +280,7 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = raise ValueError(msg) # Coreset Subsampling sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) - coreset = sampler.sample_coreset().type_as(self.memory_bank) + coreset = sampler.sample_coreset().to(self.memory_bank) self.memory_bank = coreset @staticmethod From 585c81a6558773f0ee0bc08b509afc00eca62c75 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:41:43 +0100 Subject: [PATCH 12/16] update return type, add comment about deprecation, duplicated copyright Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/dfm/torch_model.py | 2 +- src/anomalib/models/image/patchcore/torch_model.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index de0930eaa5..a9e4986379 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -210,7 +210,7 @@ def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: return (score, None) if self.score_type == "nll" else (score, score_map) - def get_features(self, batch: torch.Tensor) -> torch.Tensor: + def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Size]: """Extract features from the pretrained network. 
Args: diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 11d18aebf4..175e06baf0 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -33,9 +33,6 @@ Coreset subsampling using k-center-greedy approach """ -# Copyright (C) 2022-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - import warnings from collections.abc import Sequence from typing import TYPE_CHECKING From 9267dd19883192c07055043c4009402a3bc72a07 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:21:18 +0100 Subject: [PATCH 13/16] new deprecation wrapper, now handles arg deprecation Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/utils/deprecation.py | 222 +++++++++++++++++++++++------- 1 file changed, 175 insertions(+), 47 deletions(-) diff --git a/src/anomalib/utils/deprecation.py b/src/anomalib/utils/deprecation.py index 25717b8a50..e9986da6d4 100644 --- a/src/anomalib/utils/deprecation.py +++ b/src/anomalib/utils/deprecation.py @@ -3,7 +3,7 @@ """Deprecation utilities for anomalib. -This module provides utilities for marking functions and classes as deprecated. +This module provides utilities for marking functions, classes and args as deprecated. The utilities help maintain backward compatibility while encouraging users to migrate to newer APIs. @@ -34,6 +34,25 @@ >>> @deprecate(since="2.1.0") ... def yet_another_function(): ... pass + + >>> # Deprecation of specific argument(s) + >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0") + ... def my_function(new_param=None, old_param=None): + ... # Handle the mapping logic yourself + ... if old_param is not None and new_param is None: + ... new_param = old_param + ... # Rest of function logic + ... pass + + >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0") + >>> def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None): + ... # Handle each mapping individually + ... if old_param is not None and new_param is None: + ... new_param = old_param + ... if two_param is not None and two_new_param is None: + ... two_new_param = two_param + ... # Rest of function logic + ... pass """ import inspect @@ -56,6 +75,7 @@ def deprecate( remove: str | None = None, use: str | None = None, reason: str | None = None, + args: dict[str, str] | None = None, warning_category: type[Warning] = FutureWarning, ) -> Callable[[_T], _T]: ... @@ -67,25 +87,30 @@ def deprecate( remove: str | None = None, use: str | None = None, reason: str | None = None, - warning_category: type[Warning] = FutureWarning, + args: dict[str, str] | None = None, + warning_category: type[Warning] = DeprecationWarning, ) -> Any: - """Mark a function or class as deprecated. + """Mark a function, class, or keyword argument as deprecated. - This decorator will cause a warning to be emitted when the function is called or class is instantiated. + This decorator will cause a warning to be emitted when the function is called, + class is instantiated, or deprecated keyword arguments are used. + If args are passed, no function deprecation will be shown. Args: __obj: The function or class to be deprecated. If provided, the decorator is used without arguments. If None, the decorator is used with arguments. - since: Version when the function/class was deprecated (e.g. 
"2.1.0"). + since: Version when the function/class/arg was deprecated (e.g. "2.1.0"). If not provided, no deprecation version will be shown in the warning message. - remove: Version when the function/class will be removed. + remove: Version when the function/class/arg will be removed. If not provided, no removal version will be shown in the warning message. - use: Name of the replacement function/class. + use: Name of the replacement function/class/arg. If not provided, no replacement suggestion will be shown in the warning message. reason: Additional reason for the deprecation. If not provided, no reason will be shown in the warning message. - warning_category: Type of warning to emit (default: FutureWarning). + args: Mapping of deprecated argument names to their replacements. + If not provided, no argument deprecation warning will be shown. + warning_category: Type of warning to emit (default: DeprecationWarning). Can be DeprecationWarning, FutureWarning, PendingDeprecationWarning, etc. Returns: @@ -116,49 +141,152 @@ def deprecate( >>> @deprecate(since="2.1.0", reason="Performance improvements available", warning_category=FutureWarning) ... def old_slow_function(): ... pass + + >>> # Deprecation of specific argument(s) + >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0") + ... def my_function(new_param=None, old_param=None): + ... # Handle the mapping logic yourself + ... if old_param is not None and new_param is None: + ... new_param = old_param + ... # Rest of function logic + ... pass + + >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0") + >>> def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None): + ... # Handle each mapping individually + ... if old_param is not None and new_param is None: + ... new_param = old_param + ... if two_param is not None and two_new_param is None: + ... two_new_param = two_param + ... # Rest of function logic + ... pass """ def _deprecate_impl(obj: Any) -> Any: # noqa: ANN401 if isinstance(obj, type): - # Handle class deprecation - original_init = inspect.getattr_static(obj, "__init__") - - @wraps(original_init) - def new_init(self: Any, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 - msg = f"{obj.__name__} is deprecated" - if since: - msg += f" since v{since}" - if remove: - msg += f" and will be removed in v{remove}" - if use: - msg += f". Use {use} instead" - if reason: - msg += f". {reason}" - msg += "." - warnings.warn(msg, warning_category, stacklevel=2) - original_init(self, *args, **kwargs) - - setattr(obj, "__init__", new_init) # noqa: B010 - return obj - - # Handle function deprecation - @wraps(obj) - def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 - msg = f"{obj.__name__} is deprecated" - if since: - msg += f" since v{since}" - if remove: - msg += f" and will be removed in v{remove}" - if use: - msg += f". Use {use} instead" - if reason: - msg += f". {reason}" - msg += "." + return _wrap_deprecated_init(obj, since, remove, use, reason, warning_category) + return _wrap_deprecated_function(obj, since, remove, use, reason, args, warning_category) + + return _deprecate_impl if __obj is None else _deprecate_impl(__obj) + + +def _build_deprecation_msg( + name: str, + since: str | None, + remove: str | None, + use: str | None, + reason: str | None, + prefix: str = "", +) -> str: + """Construct a standardized deprecation warning message. + + Args: + name: Name of the deprecated item (function, class, or argument). 
+ since: Version when the item was deprecated. + remove: Version when the item will be removed. + use: Suggested replacement name. + reason: Additional reason for deprecation. + prefix: Optional prefix (e.g., 'Argument ') for clarity. + + Returns: + A complete deprecation message string. + """ + msg = f"{prefix}{name} is deprecated" + if since: + msg += f" since v{since}" + if remove: + msg += f" and will be removed in v{remove}" + if use: + msg += f". Use {use} instead" + if reason: + msg += f". {reason}" + return msg + "." + + +def _wrap_deprecated_init( + cls: type, + since: str | None, + remove: str | None, + use: str | None, + reason: str | None, + warning_category: type[Warning], +) -> type: + """Wrap a class constructor to emit a deprecation warning when instantiated. + + Args: + cls: The class being deprecated. + since: Version when deprecation began. + remove: Version when the class will be removed. + use: Suggested replacement class. + reason: Additional reason for deprecation. + warning_category: The type of warning to emit. + + Returns: + The decorated class with a warning-emitting __init__. + """ + original_init = inspect.getattr_static(cls, "__init__") + + @wraps(original_init) + def new_init(self: Any, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 + msg = _build_deprecation_msg(cls.__name__, since, remove, use, reason) + warnings.warn(msg, warning_category, stacklevel=2) + original_init(self, *args, **kwargs) + + setattr(cls, "__init__", new_init) # noqa: B010 + return cls + + +def _wrap_deprecated_function( + func: Callable, + since: str | None, + remove: str | None, + use: str | None, + reason: str | None, + args_map: dict[str, str] | None, + warning_category: type[Warning], +) -> Callable: + """Wrap a function to emit a deprecation warning when called. + + Also handles deprecated keyword arguments, warning when old arguments are used. + + Args: + func: The function to wrap. + since: Version when the function was deprecated. + remove: Version when the function will be removed. + use: Suggested replacement function. + reason: Additional reason for deprecation. + args_map: Mapping of deprecated argument names to new names. + warning_category: The type of warning to emit. + + Returns: + The wrapped function that emits deprecation warnings when called. 
+ """ + sig = inspect.signature(func) + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + + # Check for deprecated arguments and warn (but don't replace) + if args_map: + for old_arg, new_arg in args_map.items(): + if old_arg in bound_args.arguments: + msg = _build_deprecation_msg( + old_arg, + since, + remove, + use=new_arg, + reason=None, + prefix="Argument ", + ) + warnings.warn(msg, warning_category, stacklevel=2) + + # If args_map exists, this is argument deprecation, not function deprecation + if not args_map: + msg = _build_deprecation_msg(func.__name__, since, remove, use, reason) warnings.warn(msg, warning_category, stacklevel=2) - return obj(*args, **kwargs) - return wrapper + return func(*bound_args.args, **bound_args.kwargs) - if __obj is None: - return _deprecate_impl - return _deprecate_impl(__obj) + return wrapper From d56a7bc44521b40cc8557b5f48e4c8e3f16d7c87 Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:54:02 +0100 Subject: [PATCH 14/16] add flexible tests for new deprecation warning Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/utils/deprecation.py | 304 +++++++++++++++++------------ tests/unit/utils/test_deprecate.py | 70 +++++++ 2 files changed, 245 insertions(+), 129 deletions(-) create mode 100644 tests/unit/utils/test_deprecate.py diff --git a/src/anomalib/utils/deprecation.py b/src/anomalib/utils/deprecation.py index e9986da6d4..baa91030ce 100644 --- a/src/anomalib/utils/deprecation.py +++ b/src/anomalib/utils/deprecation.py @@ -8,18 +8,11 @@ migrate to newer APIs. Example: - >>> from anomalib.utils.deprecation import deprecated - - >>> # Deprecation without any information + >>> # Basic deprecation without any information >>> @deprecate ... def old_function(): ... pass - >>> # Deprecation with replacement function - >>> @deprecate(use="new_function") - ... def old_function(): - ... pass - >>> # Deprecation with version info and replacement class >>> @deprecate(since="2.1.0", remove="2.5.0", use="NewClass") ... class OldClass: @@ -35,6 +28,17 @@ ... def yet_another_function(): ... pass + >>> # Deprecation with reason and custom warning category + >>> @deprecate(since="2.1.0", reason="Performance improvements available", warning_category=FutureWarning) + ... def old_slow_function(): + ... pass + + >>> # Deprecation of classes + >>> @deprecate(since="2.1.0", remove="3.0.0") + ... class MyClass: + ... def __init__(self): + ... pass + >>> # Deprecation of specific argument(s) >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0") ... def my_function(new_param=None, old_param=None): @@ -44,8 +48,8 @@ ... # Rest of function logic ... pass - >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0") - >>> def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None): + >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0") + >>> def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None): ... # Handle each mapping individually ... if old_param is not None and new_param is None: ... new_param = old_param @@ -53,6 +57,13 @@ ... two_new_param = two_param ... # Rest of function logic ... 
pass + + >>> @deprecate(args={"old_arg": "new_arg"}, since="2.1.0", remove="3.0.0") + >>> class MyClass: + ... def __init__(self, new_arg=None, old_arg=None): + ... if new_arg is None and old_arg is not None: + ... new_arg = old_arg + ... pass """ import inspect @@ -64,112 +75,6 @@ _T = TypeVar("_T") -@overload -def deprecate(__obj: _T) -> _T: ... - - -@overload -def deprecate( - *, - since: str | None = None, - remove: str | None = None, - use: str | None = None, - reason: str | None = None, - args: dict[str, str] | None = None, - warning_category: type[Warning] = FutureWarning, -) -> Callable[[_T], _T]: ... - - -def deprecate( - __obj: Any = None, - *, - since: str | None = None, - remove: str | None = None, - use: str | None = None, - reason: str | None = None, - args: dict[str, str] | None = None, - warning_category: type[Warning] = DeprecationWarning, -) -> Any: - """Mark a function, class, or keyword argument as deprecated. - - This decorator will cause a warning to be emitted when the function is called, - class is instantiated, or deprecated keyword arguments are used. - If args are passed, no function deprecation will be shown. - - Args: - __obj: The function or class to be deprecated. - If provided, the decorator is used without arguments. - If None, the decorator is used with arguments. - since: Version when the function/class/arg was deprecated (e.g. "2.1.0"). - If not provided, no deprecation version will be shown in the warning message. - remove: Version when the function/class/arg will be removed. - If not provided, no removal version will be shown in the warning message. - use: Name of the replacement function/class/arg. - If not provided, no replacement suggestion will be shown in the warning message. - reason: Additional reason for the deprecation. - If not provided, no reason will be shown in the warning message. - args: Mapping of deprecated argument names to their replacements. - If not provided, no argument deprecation warning will be shown. - warning_category: Type of warning to emit (default: DeprecationWarning). - Can be DeprecationWarning, FutureWarning, PendingDeprecationWarning, etc. - - Returns: - Decorated function/class that emits a deprecation warning when used. - - Example: - >>> # Basic deprecation without any information - >>> @deprecate - ... def old_function(): - ... pass - - >>> # Deprecation with version info and replacement class - >>> @deprecate(since="2.1.0", remove="2.5.0", use="NewClass") - ... class OldClass: - ... pass - - >>> # Deprecation with only removal version - >>> @deprecate(remove="2.5.0") - ... def another_function(): - ... pass - - >>> # Deprecation with only deprecation version - >>> @deprecate(since="2.1.0") - ... def yet_another_function(): - ... pass - - >>> # Deprecation with reason and custom warning category - >>> @deprecate(since="2.1.0", reason="Performance improvements available", warning_category=FutureWarning) - ... def old_slow_function(): - ... pass - - >>> # Deprecation of specific argument(s) - >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0") - ... def my_function(new_param=None, old_param=None): - ... # Handle the mapping logic yourself - ... if old_param is not None and new_param is None: - ... new_param = old_param - ... # Rest of function logic - ... pass - - >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0") - >>> def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None): - ... 
+            warnings.warn(msg, warning_category, stacklevel=3)
+
+
 def _wrap_deprecated_init(
     cls: type,
     since: str | None,
     remove: str | None,
     use: str | None,
     reason: str | None,
+    args_map: dict[str, str] | None,
     warning_category: type[Warning],
 ) -> type:
     """Wrap a class constructor to emit a deprecation warning when instantiated.
@@ -219,17 +154,26 @@
         remove: Version when the class will be removed.
         use: Suggested replacement class.
         reason: Additional reason for deprecation.
+        args_map: Mapping of deprecated argument names to new names.
         warning_category: The type of warning to emit.
 
     Returns:
         The decorated class with a warning-emitting __init__.
     """
     original_init = inspect.getattr_static(cls, "__init__")
+    sig = inspect.signature(original_init)
 
     @wraps(original_init)
     def new_init(self: Any, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
-        msg = _build_deprecation_msg(cls.__name__, since, remove, use, reason)
-        warnings.warn(msg, warning_category, stacklevel=2)
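+        # Bind the incoming arguments to the original __init__ signature so that
+        # deprecated arguments can be looked up by name.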
""" original_init = inspect.getattr_static(cls, "__init__") + sig = inspect.signature(original_init) @wraps(original_init) def new_init(self: Any, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 - msg = _build_deprecation_msg(cls.__name__, since, remove, use, reason) - warnings.warn(msg, warning_category, stacklevel=2) + bound_args = sig.bind_partial(self, *args, **kwargs) + bound_args.apply_defaults() + + if args_map: + _warn_deprecated_arguments(args_map, bound_args, since, remove, warning_category) + else: + msg = _build_deprecation_msg(cls.__name__, since, remove, use, reason) + warnings.warn(msg, warning_category, stacklevel=2) + original_init(self, *args, **kwargs) setattr(cls, "__init__", new_init) # noqa: B010 @@ -270,17 +214,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 # Check for deprecated arguments and warn (but don't replace) if args_map: - for old_arg, new_arg in args_map.items(): - if old_arg in bound_args.arguments: - msg = _build_deprecation_msg( - old_arg, - since, - remove, - use=new_arg, - reason=None, - prefix="Argument ", - ) - warnings.warn(msg, warning_category, stacklevel=2) + _warn_deprecated_arguments(args_map, bound_args, since, remove, warning_category) # If args_map exists, this is argument deprecation, not function deprecation if not args_map: @@ -290,3 +224,115 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 return func(*bound_args.args, **bound_args.kwargs) return wrapper + + +@overload +def deprecate(__obj: _T) -> _T: ... + + +@overload +def deprecate( + *, + since: str | None = None, + remove: str | None = None, + use: str | None = None, + reason: str | None = None, + args: dict[str, str] | None = None, + warning_category: type[Warning] = FutureWarning, +) -> Callable[[_T], _T]: ... + + +def deprecate( + __obj: Any = None, + *, + since: str | None = None, + remove: str | None = None, + use: str | None = None, + reason: str | None = None, + args: dict[str, str] | None = None, + warning_category: type[Warning] = DeprecationWarning, +) -> Any: + """Mark a function, class, or keyword argument as deprecated. + + This decorator will cause a warning to be emitted when the function is called, + class is instantiated, or deprecated keyword arguments are used. + If args are passed, no function deprecation will be shown. + + Args: + __obj: The function or class to be deprecated. + If provided, the decorator is used without arguments. + If None, the decorator is used with arguments. + since: Version when the function/class/arg was deprecated (e.g. "2.1.0"). + If not provided, no deprecation version will be shown in the warning message. + remove: Version when the function/class/arg will be removed. + If not provided, no removal version will be shown in the warning message. + use: Name of the replacement function/class/arg. + If not provided, no replacement suggestion will be shown in the warning message. + reason: Additional reason for the deprecation. + If not provided, no reason will be shown in the warning message. + args: Mapping of deprecated argument names to their replacements. + If not provided, no argument deprecation warning will be shown. + If provided, function deprecation will not be show. + warning_category: Type of warning to emit (default: DeprecationWarning). + Can be DeprecationWarning, FutureWarning, PendingDeprecationWarning, etc. + + Returns: + Decorated function/class that emits a deprecation warning when used. + + Example: + >>> # Basic deprecation without any information + >>> @deprecate + ... 
+        >>> @deprecate(args={"old_arg": "new_arg"}, since="2.1.0", remove="3.0.0")
+        ... class MyClass:
+        ...     def __init__(self, new_arg=None, old_arg=None):
+        ...         if new_arg is None and old_arg is not None:
+        ...             new_arg = old_arg
+    """
+
+    def _deprecate_impl(obj: Any) -> Any:  # noqa: ANN401
+        if isinstance(obj, type):
+            return _wrap_deprecated_init(obj, since, remove, use, reason, args, warning_category)
+        return _wrap_deprecated_function(obj, since, remove, use, reason, args, warning_category)
+
+    return _deprecate_impl if __obj is None else _deprecate_impl(__obj)
diff --git a/tests/unit/utils/test_deprecate.py b/tests/unit/utils/test_deprecate.py
new file mode 100644
index 0000000000..5b4e98589f
--- /dev/null
+++ b/tests/unit/utils/test_deprecate.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2022-2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for the deprecation utility function."""
+
+import pytest
+
+from anomalib.utils import deprecate
+
+
+def test_deprecated_function_warns() -> None:
+    """Test simple function deprecation warning."""
+
+    @deprecate(since="1.0", remove="2.0", use="new_func")
+    def old_func() -> str:
+        return "hello"
+
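+    # The (?=...) lookaheads assert that each fragment appears somewhere in the
+    # warning message, regardless of order.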
match=r"(?=.*old_param)(?=.*new_param)(?=.*1\.0)(?=.*2\.0)"): + assert func(old_param=10) == 10 + + +def test_deprecated_multiple_args_warns() -> None: + """Test multiple deprecated arguments trigger appropriate warnings.""" + + @deprecate(args={"old1": "new1", "old2": "new2"}, since="1.0") + def func(new1=None, new2=None, old1=None, old2=None): # noqa: ANN001, ANN202 + return new1 or old1, new2 or old2 + + with pytest.warns(DeprecationWarning, match=r"(?=.*old1)(?=.*new1)(?=.*1\.0)"): + func(old1="a") + + with pytest.warns(DeprecationWarning, match=r"(?=.*old2)(?=.*new2)(?=.*1\.0)"): + func(old2="b") + + +def test_deprecated_class_arg_warns() -> None: + """Test deprecated constructor argument triggers warning.""" + + @deprecate(args={"legacy_arg": "modern_arg"}, since="1.0", remove="2.0") + class MyClass: + def __init__(self, modern_arg=None, legacy_arg=None) -> None: # noqa: ANN001 + self.value = modern_arg if modern_arg is not None else legacy_arg + + with pytest.warns(DeprecationWarning, match=r"(?=.*legacy_arg)(?=.*modern_arg)(?=.*1\.0)(?=.*2\.0)"): + obj = MyClass(legacy_arg="deprecated") + assert obj.value == "deprecated" From 9072bbdee4bf3514989b2c0849f357215c5f596b Mon Sep 17 00:00:00 2001 From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Date: Tue, 15 Jul 2025 12:12:59 +0100 Subject: [PATCH 15/16] update warnings to deprecations. Resolve comments re-fit vs fit_guassian. Test for None replacements for args Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --- src/anomalib/models/image/dfkde/lightning_model.py | 2 +- src/anomalib/models/image/dfkde/torch_model.py | 6 +++--- src/anomalib/models/image/dfm/torch_model.py | 14 +++++--------- src/anomalib/models/image/padim/lightning_model.py | 2 +- src/anomalib/models/image/padim/torch_model.py | 6 +++--- src/anomalib/models/image/patchcore/torch_model.py | 13 +++++-------- src/anomalib/utils/deprecation.py | 10 +++++----- tests/unit/utils/test_deprecate.py | 11 +++++++++++ 8 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py index 0b57e3c9ab..99ac990e1d 100644 --- a/src/anomalib/models/image/dfkde/lightning_model.py +++ b/src/anomalib/models/image/dfkde/lightning_model.py @@ -139,7 +139,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None: def fit(self) -> None: """Fit KDE model to collected embeddings from the training set.""" logger.info("Fitting a KDE model to the embedding collected from the training set.") - self.model.fit_classifier() + self.model.fit() def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform validation by computing anomaly scores. diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py index c797889b4a..9af1ef4b65 100644 --- a/src/anomalib/models/image/dfkde/torch_model.py +++ b/src/anomalib/models/image/dfkde/torch_model.py @@ -89,7 +89,7 @@ def __init__( feature_scaling_method=feature_scaling_method, max_training_points=max_training_points, ) - self.memory_bank = torch.Tensor() + self.memory_bank = torch.empty(0) def get_features(self, batch: torch.Tensor) -> torch.Tensor: """Extract features from the pre-trained backbone network. 
+        self.memory_bank = torch.empty(0)
 
     def get_features(self, batch: torch.Tensor) -> torch.Tensor:
         """Extract features from the pre-trained backbone network.
@@ -153,7 +153,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
             scores = self.classifier(features)
             return InferenceBatch(pred_score=scores)
 
-    def fit_classifier(self) -> None:
+    def fit(self) -> None:
         """Fits the classifier using the current contents of the memory bank.
 
         This method is typically called after the memory bank has been populated
@@ -172,4 +172,4 @@ def fit_classifier(self) -> None:
         self.classifier.fit(self.memory_bank)
 
         # clear memory bank, reduces gpu size
-        self.memory_bank = torch.Tensor().to(self.memory_bank)
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py
index a9e4986379..25cbb2c36a 100644
--- a/src/anomalib/models/image/dfm/torch_model.py
+++ b/src/anomalib/models/image/dfm/torch_model.py
@@ -26,7 +26,6 @@
 """
 
 import math
-import warnings
 
 import torch
 from torch import nn
@@ -34,6 +33,7 @@
 
 from anomalib.data import InferenceBatch
 from anomalib.models.components import PCA, DynamicBufferMixin, TimmFeatureExtractor
+from anomalib.utils.deprecation import deprecate
 
 
 class SingleClassGaussian(DynamicBufferMixin):
@@ -154,8 +154,9 @@ def __init__(
             layers=[layer],
         ).eval()
 
-        self.memory_bank = torch.Tensor()
+        self.memory_bank = torch.empty(0)
 
+    @deprecate(args={"dataset": None}, since="2.1.0")
     def fit(self, dataset: torch.Tensor = None) -> None:
         """Fit PCA and Gaussian model to dataset.
 
@@ -166,12 +167,7 @@ def fit(self, dataset: torch.Tensor = None) -> None:
             the method now uses `self.memory_bank` as the input dataset.
         """
         if dataset is not None:
-            warnings.warn(
-                "The 'dataset' argument is deprecated and will be removed in a future release. "
-                "The method now uses 'self.memory_bank' as the input dataset.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
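+            # The @deprecate decorator above has already warned; just discard the value.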
+            del dataset
 
         self.pca_model.fit(self.memory_bank)
         if self.score_type == "nll":
@@ -179,7 +175,7 @@ def fit(self, dataset: torch.Tensor = None) -> None:
             self.gaussian_model.fit(features_reduced.T)
 
         # clear memory bank, reduces GPU size
-        self.memory_bank = torch.Tensor().to(self.memory_bank)
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
 
     def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
         """Compute anomaly scores.
diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py
index 7e76873467..8660b88ee4 100644
--- a/src/anomalib/models/image/padim/lightning_model.py
+++ b/src/anomalib/models/image/padim/lightning_model.py
@@ -159,7 +159,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
     def fit(self) -> None:
         """Fit a Gaussian to the embedding collected from the training set."""
         logger.info("Fitting a Gaussian to the embedding collected from the training set.")
-        self.model.fit_gaussian()
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform a validation step of PADIM.
diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py
index 0d08a3990a..d7367149a0 100644
--- a/src/anomalib/models/image/padim/torch_model.py
+++ b/src/anomalib/models/image/padim/torch_model.py
@@ -147,7 +147,7 @@ def __init__(
         self.anomaly_map_generator = AnomalyMapGenerator()
 
         self.gaussian = MultiVariateGaussian()
-        self.memory_bank = torch.Tensor()
+        self.memory_bank = torch.empty(0)
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Forward-pass image-batch (N, C, H, W) into model to extract features.
@@ -224,7 +224,7 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
         idx = self.idx.to(embeddings.device)
         return torch.index_select(embeddings, 1, idx)
 
-    def fit_gaussian(self) -> None:
+    def fit(self) -> None:
         """Fits a Gaussian model to the current contents of the memory bank.
 
         This method is typically called after the memory bank has been filled during training.
@@ -242,4 +242,4 @@ def fit_gaussian(self) -> None:
         self.gaussian.fit(self.memory_bank)
 
         # clear memory bank, reduces gpu usage
-        self.memory_bank = torch.Tensor().to(self.memory_bank)
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py
index 175e06baf0..1970e6dcc2 100644
--- a/src/anomalib/models/image/patchcore/torch_model.py
+++ b/src/anomalib/models/image/patchcore/torch_model.py
@@ -33,7 +33,6 @@
     Coreset subsampling using k-center-greedy approach
 """
 
-import warnings
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
@@ -43,6 +42,7 @@
 
 from anomalib.data import InferenceBatch
 from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor
+from anomalib.utils import deprecate
 
 from .anomaly_map import AnomalyMapGenerator
 
@@ -125,7 +125,7 @@ def __init__(
         self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
         self.anomaly_map_generator = AnomalyMapGenerator()
         self.memory_bank: torch.Tensor
-        self.register_buffer("memory_bank", torch.Tensor())
+        self.register_buffer("memory_bank", torch.empty(0))
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Process input tensor through the model.
@@ -250,6 +250,7 @@ def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor:
         embedding_size = embedding.size(1)
         return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
 
+    @deprecate(args={"embeddings": None}, since="2.1.0", reason="Use the default memory bank instead.")
     def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = None) -> None:
         """Subsample the memory_banks embeddings using coreset selection.
 
@@ -269,12 +270,8 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor =
             torch.Size([100, 512])
         """
         if embeddings is not None:
-            warnings.warn(
-                "The 'embeddings' argument is deprecated and will be removed in a future release."
-                "Please use the default memory bank instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
+            del embeddings
+
         if self.memory_bank.size(0) == 0:
             msg = "Memory bank is empty. Cannot perform coreset selection."
             raise ValueError(msg)
diff --git a/src/anomalib/utils/deprecation.py b/src/anomalib/utils/deprecation.py
index baa91030ce..0f195bea8d 100644
--- a/src/anomalib/utils/deprecation.py
+++ b/src/anomalib/utils/deprecation.py
@@ -109,7 +109,7 @@ def _build_deprecation_msg(
 
 
 def _warn_deprecated_arguments(
-    args_map: dict[str, str],
+    args_map: dict[str, str | None],
     bound_args: inspect.BoundArguments,
     since: str | None,
     remove: str | None,
@@ -143,7 +143,7 @@ def _wrap_deprecated_init(
     remove: str | None,
     use: str | None,
     reason: str | None,
-    args_map: dict[str, str] | None,
+    args_map: dict[str, str | None] | None,
     warning_category: type[Warning],
 ) -> type:
     """Wrap a class constructor to emit a deprecation warning when instantiated.
@@ -186,7 +186,7 @@ def _wrap_deprecated_function(
     remove: str | None,
     use: str | None,
     reason: str | None,
-    args_map: dict[str, str] | None,
+    args_map: dict[str, str | None] | None,
     warning_category: type[Warning],
 ) -> Callable:
     """Wrap a function to emit a deprecation warning when called.
@@ -237,7 +237,7 @@ def deprecate(
     remove: str | None = None,
     use: str | None = None,
     reason: str | None = None,
-    args: dict[str, str] | None = None,
+    args: dict[str, str | None] | None = None,
     warning_category: type[Warning] = FutureWarning,
 ) -> Callable[[_T], _T]: ...
 
@@ -249,7 +249,7 @@ def deprecate(
     remove: str | None = None,
     use: str | None = None,
     reason: str | None = None,
-    args: dict[str, str] | None = None,
+    args: dict[str, str | None] | None = None,
     warning_category: type[Warning] = DeprecationWarning,
 ) -> Any:
     """Mark a function, class, or keyword argument as deprecated.
diff --git a/tests/unit/utils/test_deprecate.py b/tests/unit/utils/test_deprecate.py
index 5b4e98589f..936e8d03c3 100644
--- a/tests/unit/utils/test_deprecate.py
+++ b/tests/unit/utils/test_deprecate.py
@@ -32,6 +32,17 @@ def __init__(self) -> None:
     assert instance.val == 42
 
 
+def test_deprecated_function_arg_warns_with_none_replacement() -> None:
+    """Test deprecated argument triggers warning when replacement is None."""
+
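+    # A None replacement marks the argument as deprecated with no substitute,
+    # so the warning names only the old argument.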
+    @deprecate(args={"old_param": None}, since="1.0", remove="2.0")
+    def func(new_param: int | None = None, old_param: int | None = None) -> int:
+        return new_param if new_param is not None else old_param
+
+    with pytest.warns(DeprecationWarning, match=r"(?=.*old_param)(?=.*1\.0)(?=.*2\.0)"):
+        assert func(old_param=10) == 10
+
+
 def test_deprecated_function_arg_warns() -> None:
     """Test deprecated argument triggers warning with correct match."""
 

From f123996141ad7256f0d870ebfddb45a57de9ea4b Mon Sep 17 00:00:00 2001
From: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>
Date: Tue, 15 Jul 2025 12:29:41 +0100
Subject: [PATCH 16/16] add type hint for dfkde

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>
---
 src/anomalib/models/image/dfkde/lightning_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py
index 99ac990e1d..4222d626b3 100644
--- a/src/anomalib/models/image/dfkde/lightning_model.py
+++ b/src/anomalib/models/image/dfkde/lightning_model.py
@@ -104,7 +104,7 @@ def __init__(
             visualizer=visualizer,
         )
 
-        self.model = DfkdeModel(
+        self.model: DfkdeModel = DfkdeModel(
             layers=layers,
             backbone=backbone,
             pre_trained=pre_trained,