diff --git a/src/anomalib/models/image/dfkde/lightning_model.py b/src/anomalib/models/image/dfkde/lightning_model.py
index 7ebcaf22f7..4222d626b3 100644
--- a/src/anomalib/models/image/dfkde/lightning_model.py
+++ b/src/anomalib/models/image/dfkde/lightning_model.py
@@ -104,7 +104,7 @@ def __init__(
             visualizer=visualizer,
         )
 
-        self.model = DfkdeModel(
+        self.model: DfkdeModel = DfkdeModel(
             layers=layers,
             backbone=backbone,
             pre_trained=pre_trained,
@@ -113,8 +113,6 @@ def __init__(
             max_training_points=max_training_points,
         )
 
-        self.embeddings: list[torch.Tensor] = []
-
     @staticmethod
     def configure_optimizers() -> None:  # pylint: disable=arguments-differ
         """DFKDE doesn't require optimization, therefore returns no optimizers."""
@@ -133,18 +131,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
 
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Fit KDE model to collected embeddings from the training set."""
-        embeddings = torch.vstack(self.embeddings)
-
-        logger.info("Fitting a KDE model to the embedding collected from the training set.")
-        self.model.classifier.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform validation by computing anomaly scores.
@@ -167,9 +162,13 @@ def trainer_arguments(self) -> dict[str, Any]:
         """Get DFKDE-specific trainer arguments.
 
         Returns:
-            dict[str, Any]: Dictionary of trainer arguments.
+            dict[str, Any]: Trainer arguments
+                - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single GPU supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py
index f677b668c4..9af1ef4b65 100644
--- a/src/anomalib/models/image/dfkde/torch_model.py
+++ b/src/anomalib/models/image/dfkde/torch_model.py
@@ -89,6 +89,7 @@ def __init__(
             feature_scaling_method=feature_scaling_method,
             max_training_points=max_training_points,
         )
+        self.memory_bank = torch.empty(0)
 
     def get_features(self, batch: torch.Tensor) -> torch.Tensor:
         """Extract features from the pre-trained backbone network.
@@ -141,8 +142,34 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         # 1. apply feature extraction
         features = self.get_features(batch)
 
         if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = features
+            else:
+                new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
             return features
 
         # 2. apply density estimation
         scores = self.classifier(features)
 
         return InferenceBatch(pred_score=scores)
+
+    def fit(self) -> None:
+        """Fit the classifier to the current contents of the memory bank.
+
+        This method is typically called after the memory bank has been populated
+        during training.
+
+        After fitting, the memory bank is cleared to reduce GPU memory usage.
+
+        Raises:
+            ValueError: If the memory bank is empty.
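+
+        Example:
+            A minimal illustrative sketch (the backbone/layer names and the
+            input size below are assumed placeholder values, not requirements
+            of this method):
+
+            >>> import torch
+            >>> model = DfkdeModel(layers=["layer4"], backbone="resnet18")
+            >>> model.train()
+            >>> _ = model(torch.randn(8, 3, 256, 256))  # training-mode forward fills the memory bank
+            >>> model.fit()  # fits the KDE classifier, then clears the bank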
+        """
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot fit the KDE model."
+            raise ValueError(msg)
+
+        # fit the KDE classifier
+        self.classifier.fit(self.memory_bank)
+
+        # clear the memory bank to reduce GPU memory usage
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
diff --git a/src/anomalib/models/image/dfm/lightning_model.py b/src/anomalib/models/image/dfm/lightning_model.py
index 7933f8b3b5..d424c8be19 100644
--- a/src/anomalib/models/image/dfm/lightning_model.py
+++ b/src/anomalib/models/image/dfm/lightning_model.py
@@ -112,7 +112,6 @@ def __init__(
             n_comps=pca_level,
             score_type=score_type,
         )
-        self.embeddings: list[torch.Tensor] = []
         self.score_type = score_type
 
     @staticmethod
@@ -137,8 +136,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
 
-        embedding = self.model.get_features(batch.image).squeeze()
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
@@ -149,11 +147,8 @@ def fit(self) -> None:
         The method aggregates embeddings collected during training and fits
         both the PCA transformation and Gaussian model used for scoring.
         """
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
-        logger.info("Fitting a PCA and a Gaussian model to dataset.")
-        self.model.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Compute predictions for the input batch during validation.
@@ -176,12 +171,13 @@ def trainer_arguments(self) -> dict[str, Any]:
         """Get DFM-specific trainer arguments.
 
         Returns:
-            dict[str, Any]: Dictionary of trainer arguments:
-                - ``gradient_clip_val`` (int): Disable gradient clipping
-                - ``max_epochs`` (int): Train for one epoch only
-                - ``num_sanity_val_steps`` (int): Skip validation sanity checks
+            dict[str, Any]: Trainer arguments
+                - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single GPU supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py
index 40fab96efa..25cbb2c36a 100644
--- a/src/anomalib/models/image/dfm/torch_model.py
+++ b/src/anomalib/models/image/dfm/torch_model.py
@@ -33,6 +33,7 @@
 
 from anomalib.data import InferenceBatch
 from anomalib.models.components import PCA, DynamicBufferMixin, TimmFeatureExtractor
+from anomalib.utils.deprecation import deprecate
 
 
 class SingleClassGaussian(DynamicBufferMixin):
@@ -153,18 +154,29 @@ def __init__(
             layers=[layer],
         ).eval()
 
-    def fit(self, dataset: torch.Tensor) -> None:
+        self.memory_bank = torch.empty(0)
+
+    @deprecate(args={"dataset": None}, since="2.1.0")
+    def fit(self, dataset: torch.Tensor | None = None) -> None:
         """Fit PCA and Gaussian model to dataset.
 
         Args:
-            dataset (torch.Tensor): Input dataset with shape
-                ``(n_samples, n_features)``.
+            dataset (torch.Tensor, optional): **[DEPRECATED]**
+                Input dataset with shape ``(n_samples, n_features)``.
+                This argument is deprecated and will be ignored. Instead,
+                the method now uses ``self.memory_bank`` as the input dataset.
         """
-        self.pca_model.fit(dataset)
+        if dataset is not None:
+            del dataset
+
+        self.pca_model.fit(self.memory_bank)
 
         if self.score_type == "nll":
-            features_reduced = self.pca_model.transform(dataset)
+            features_reduced = self.pca_model.transform(self.memory_bank)
             self.gaussian_model.fit(features_reduced.T)
 
+        # clear the memory bank to reduce GPU memory usage
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
+
     def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
         """Compute anomaly scores.
@@ -194,7 +206,7 @@ def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
 
         return (score, None) if self.score_type == "nll" else (score, score_map)
 
-    def get_features(self, batch: torch.Tensor) -> torch.Tensor:
+    def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Size]:
         """Extract features from the pretrained network.
 
         Args:
@@ -202,17 +214,16 @@ def get_features(self, batch: torch.Tensor) -> torch.Tensor:
                 ``(batch_size, channels, height, width)``.
 
         Returns:
-            Union[torch.Tensor, Tuple[torch.Tensor, torch.Size]]: Features during
-                training, or tuple of (features, feature_shapes) during inference.
+            tuple[torch.Tensor, torch.Size]: Tuple of ``(features, feature_shapes)``.
         """
-        self.feature_extractor.eval()
-        features = self.feature_extractor(batch)[self.layer]
-        batch_size = len(features)
-        if self.pooling_kernel_size > 1:
-            features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
-        feature_shapes = features.shape
-        features = features.view(batch_size, -1).detach()
-        return features if self.training else (features, feature_shapes)
+        with torch.no_grad():
+            features = self.feature_extractor(batch)[self.layer]
+            batch_size = len(features)
+            if self.pooling_kernel_size > 1:
+                features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
+            feature_shapes = features.shape
+            features = features.view(batch_size, -1)
+        return features, feature_shapes
 
     def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Compute anomaly predictions from input images.
@@ -227,6 +238,15 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
             ``InferenceBatch`` with prediction scores and anomaly maps.
         """
         feature_vector, feature_shapes = self.get_features(batch)
+
+        if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = feature_vector
+            else:
+                new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
+            return feature_vector
+
         pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes)
         if anomaly_map is not None:
             anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False)
diff --git a/src/anomalib/models/image/padim/lightning_model.py b/src/anomalib/models/image/padim/lightning_model.py
index 977410302c..8660b88ee4 100644
--- a/src/anomalib/models/image/padim/lightning_model.py
+++ b/src/anomalib/models/image/padim/lightning_model.py
@@ -131,9 +131,6 @@ def __init__(
             n_features=n_features,
         )
 
-        self.stats: list[torch.Tensor] = []
-        self.embeddings: list[torch.Tensor] = []
-
     @staticmethod
     def configure_optimizers() -> None:
         """PADIM doesn't require optimization, therefore returns no optimizers."""
@@ -154,19 +151,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
         """
         del args, kwargs  # These variables are not used.
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
 
     def fit(self) -> None:
         """Fit a Gaussian to the embedding collected from the training set."""
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
-        logger.info("Fitting a Gaussian to the embedding collected from the training set.")
-        self.stats = self.model.gaussian.fit(embeddings)
+        self.model.fit()
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Perform a validation step of PADIM.
@@ -190,16 +183,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
 
     @property
     def trainer_arguments(self) -> dict[str, int | float]:
-        """Return PADIM trainer arguments.
-
-        Since the model does not require training, we limit the max_epochs to 1.
-        Since we need to run training epoch before validation, we also set the
-        sanity steps to 0.
+        """Get default trainer arguments for PADIM.
 
         Returns:
-            dict[str, int | float]: Dictionary of trainer arguments
+            dict[str, int | float]: Trainer arguments
+                - ``max_epochs``: ``1`` (single pass through training data)
+                - ``val_check_interval``: ``1.0`` (run validation once per training epoch)
+                - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+                - ``devices``: ``1`` (only single GPU supported)
         """
-        return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0}
+        return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py
index dacc44fde4..d7367149a0 100644
--- a/src/anomalib/models/image/padim/torch_model.py
+++ b/src/anomalib/models/image/padim/torch_model.py
@@ -147,6 +147,7 @@ def __init__(
 
         self.anomaly_map_generator = AnomalyMapGenerator()
         self.gaussian = MultiVariateGaussian()
+        self.memory_bank = torch.empty(0)
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Forward-pass image-batch (N, C, H, W) into model to extract features.
@@ -182,6 +183,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
             embeddings = self.tiler.untile(embeddings)
 
         if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = embeddings
+            else:
+                new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
             return embeddings
 
         anomaly_map = self.anomaly_map_generator(
@@ -217,3 +223,23 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
         # subsample embeddings
         idx = self.idx.to(embeddings.device)
         return torch.index_select(embeddings, 1, idx)
+
+    def fit(self) -> None:
+        """Fit a Gaussian model to the current contents of the memory bank.
+
+        This method is typically called after the memory bank has been filled
+        during training.
+
+        After fitting, the memory bank is cleared to free GPU memory before
+        validation or testing.
+
+        Raises:
+            ValueError: If the memory bank is empty.
+        """
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot fit the Gaussian model."
+            raise ValueError(msg)
+
+        # fit the Gaussian model
+        self.gaussian.fit(self.memory_bank)
+
+        # clear the memory bank to reduce GPU memory usage
+        self.memory_bank = torch.empty(0).to(self.memory_bank)
diff --git a/src/anomalib/models/image/patchcore/lightning_model.py b/src/anomalib/models/image/patchcore/lightning_model.py
index d41d8e6ad0..4e104954fe 100644
--- a/src/anomalib/models/image/patchcore/lightning_model.py
+++ b/src/anomalib/models/image/patchcore/lightning_model.py
@@ -159,7 +159,6 @@ def __init__(
             num_neighbors=num_neighbors,
         )
         self.coreset_sampling_ratio = coreset_sampling_ratio
-        self.embeddings: list[torch.Tensor] = []
 
     @classmethod
     def configure_pre_processor(
@@ -235,9 +234,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
             ``fit()``.
         """
         del args, kwargs  # These variables are not used.
-
-        embedding = self.model(batch.image)
-        self.embeddings.append(embedding)
+        _ = self.model(batch.image)
 
         # Return a dummy loss tensor
         return torch.tensor(0.0, requires_grad=True, device=self.device)
@@ -245,14 +242,10 @@ def fit(self) -> None:
         """Apply subsampling to the embedding collected from the training set.
 
-        This method:
-        1. Aggregates embeddings from all training batches
-        2. Applies coreset subsampling to reduce memory requirements
+        This method applies coreset subsampling to the embeddings stored in the
+        model's memory bank, reducing the memory requirements of the model.
         """
-        logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
-
-        logger.info("Applying core-set subsampling to get the embedding.")
-        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)
+        self.model.subsample_embedding(self.coreset_sampling_ratio)
 
     def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
         """Generate predictions for a batch of images.
@@ -286,8 +279,9 @@ def trainer_arguments(self) -> dict[str, Any]:
             - ``gradient_clip_val``: ``0`` (no gradient clipping needed)
             - ``max_epochs``: ``1`` (single pass through training data)
             - ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
+            - ``devices``: ``1`` (only single GPU supported)
         """
-        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
+        return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}
 
     @property
     def learning_type(self) -> LearningType:
diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py
index cea234e227..1970e6dcc2 100644
--- a/src/anomalib/models/image/patchcore/torch_model.py
+++ b/src/anomalib/models/image/patchcore/torch_model.py
@@ -42,6 +42,7 @@
 
 from anomalib.data import InferenceBatch
 from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor
+from anomalib.utils import deprecate
 
 from .anomaly_map import AnomalyMapGenerator
 
@@ -123,9 +124,8 @@ def __init__(
         ).eval()
         self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
         self.anomaly_map_generator = AnomalyMapGenerator()
-
-        self.register_buffer("memory_bank", torch.empty(0))
         self.memory_bank: torch.Tensor
+        self.register_buffer("memory_bank", torch.empty(0))
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Process input tensor through the model.
@@ -169,7 +169,18 @@
         embedding = self.reshape_embedding(embedding)
 
         if self.training:
+            if self.memory_bank.size(0) == 0:
+                self.memory_bank = embedding
+            else:
+                new_bank = torch.cat((self.memory_bank, embedding), dim=0).to(self.memory_bank)
+                self.memory_bank = new_bank
             return embedding
+
+        # Ensure memory bank is not empty
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot provide anomaly scores."
+            raise ValueError(msg)
+
         # apply nearest neighbor search
         patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1)
         # reshape to batch dimension
@@ -239,25 +250,34 @@ def reshape_embedding(embedding: torch.Tensor) -> torch.Tensor:
         embedding_size = embedding.size(1)
         return embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
 
-    def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
-        """Subsample embeddings using coreset selection.
+    @deprecate(args={"embeddings": None}, since="2.1.0", reason="Use the default memory bank instead.")
+    def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor | None = None) -> None:
+        """Subsample the memory bank's embeddings using coreset selection.
 
         Uses k-center-greedy coreset subsampling to select a representative
         subset of patch embeddings to store in the memory bank.
 
         Args:
-            embedding (torch.Tensor): Embedding tensor to subsample from.
             sampling_ratio (float): Fraction of embeddings to keep, in range
                 (0,1].
+            embeddings (torch.Tensor, optional): **[DEPRECATED]**
+                This argument is deprecated and will be removed in a future release.
+                Use the default behavior (i.e., rely on ``self.memory_bank``) instead.
 
         Example:
-            >>> embedding = torch.randn(1000, 512)
-            >>> model.subsample_embedding(embedding, sampling_ratio=0.1)
+            >>> model.memory_bank = torch.randn(1000, 512)
+            >>> model.subsample_embedding(sampling_ratio=0.1)
             >>> model.memory_bank.shape
             torch.Size([100, 512])
         """
+        if embeddings is not None:
+            del embeddings
+
+        if self.memory_bank.size(0) == 0:
+            msg = "Memory bank is empty. Cannot perform coreset selection."
+            raise ValueError(msg)
+
         # Coreset Subsampling
-        sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio)
-        coreset = sampler.sample_coreset()
+        sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio)
+        coreset = sampler.sample_coreset().to(self.memory_bank)
         self.memory_bank = coreset
 
     @staticmethod
diff --git a/src/anomalib/utils/deprecation.py b/src/anomalib/utils/deprecation.py
index 25717b8a50..0f195bea8d 100644
--- a/src/anomalib/utils/deprecation.py
+++ b/src/anomalib/utils/deprecation.py
@@ -3,23 +3,16 @@
 
 """Deprecation utilities for anomalib.
 
-This module provides utilities for marking functions and classes as deprecated.
+This module provides utilities for marking functions, classes, and arguments as deprecated.
 The utilities help maintain backward compatibility while encouraging users to
 migrate to newer APIs.
 
 Example:
-    >>> from anomalib.utils.deprecation import deprecated
-
-    >>> # Deprecation without any information
+    >>> # Basic deprecation without any information
    >>> @deprecate
     ... def old_function():
     ...     pass
 
-    >>> # Deprecation with replacement function
-    >>> @deprecate(use="new_function")
-    ... def old_function():
-    ...     pass
-
     >>> # Deprecation with version info and replacement class
     >>> @deprecate(since="2.1.0", remove="2.5.0", use="NewClass")
     ... class OldClass:
@@ -34,6 +27,43 @@
     >>> @deprecate(since="2.1.0")
     ... def yet_another_function():
     ...     pass
+
+    >>> # Deprecation with reason and custom warning category
+    >>> @deprecate(since="2.1.0", reason="Performance improvements available", warning_category=FutureWarning)
+    ... def old_slow_function():
+    ...     pass
+
+    >>> # Deprecation of classes
+    >>> @deprecate(since="2.1.0", remove="3.0.0")
+    ... class MyClass:
+    ...     def __init__(self):
+    ...         pass
+
+    >>> # Deprecation of specific argument(s)
+    >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0")
+    ... def my_function(new_param=None, old_param=None):
+    ...     # Handle the mapping logic yourself
+    ...     if old_param is not None and new_param is None:
+    ...         new_param = old_param
+    ...     # Rest of function logic
+    ...     pass
+
+    >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0")
+    ... def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None):
+    ...     # Handle each mapping individually
+    ...     if old_param is not None and new_param is None:
+    ...         new_param = old_param
+    ...     if two_param is not None and two_new_param is None:
+    ...         two_new_param = two_param
+    ...     # Rest of function logic
+    ...     pass
+
+    >>> @deprecate(args={"old_arg": "new_arg"}, since="2.1.0", remove="3.0.0")
+    ... class MyClass:
+    ...     def __init__(self, new_arg=None, old_arg=None):
+    ...         if new_arg is None and old_arg is not None:
+    ...             new_arg = old_arg
+    ...         pass
 """
 
 import inspect
@@ -45,6 +75,157 @@
 _T = TypeVar("_T")
 
 
+def _build_deprecation_msg(
+    name: str,
+    since: str | None,
+    remove: str | None,
+    use: str | None,
+    reason: str | None,
+    prefix: str = "",
+) -> str:
+    """Construct a standardized deprecation warning message.
+
+    Args:
+        name: Name of the deprecated item (function, class, or argument).
+        since: Version when the item was deprecated.
+        remove: Version when the item will be removed.
+        use: Suggested replacement name.
+        reason: Additional reason for deprecation.
+        prefix: Optional prefix (e.g., 'Argument ') for clarity.
+
+    Returns:
+        A complete deprecation message string.
+    """
+    msg = f"{prefix}{name} is deprecated"
+    if since:
+        msg += f" since v{since}"
+    if remove:
+        msg += f" and will be removed in v{remove}"
+    if use:
+        msg += f". Use {use} instead"
+    if reason:
+        msg += f". {reason}"
+    return msg + "."
+
+
+def _warn_deprecated_arguments(
+    args_map: dict[str, str | None],
+    bound_args: inspect.BoundArguments,
+    since: str | None,
+    remove: str | None,
+    warning_category: type[Warning],
+) -> None:
+    """Warn about deprecated keyword arguments.
+
+    Args:
+        args_map: Mapping of deprecated argument names to new names.
+        bound_args: Bound arguments of the function or method.
+        since: Version when the argument was deprecated.
+        remove: Version when the argument will be removed.
+        warning_category: Type of warning to emit.
+    """
+    for old_arg, new_arg in args_map.items():
+        if old_arg in bound_args.arguments:
+            msg = _build_deprecation_msg(
+                old_arg,
+                since,
+                remove,
+                use=new_arg,
+                reason=None,
+                prefix="Argument ",
+            )
+            warnings.warn(msg, warning_category, stacklevel=3)
+
+
+def _wrap_deprecated_init(
+    cls: type,
+    since: str | None,
+    remove: str | None,
+    use: str | None,
+    reason: str | None,
+    args_map: dict[str, str | None] | None,
+    warning_category: type[Warning],
+) -> type:
+    """Wrap a class constructor to emit a deprecation warning when instantiated.
+
+    Args:
+        cls: The class being deprecated.
+        since: Version when deprecation began.
+        remove: Version when the class will be removed.
+        use: Suggested replacement class.
+        reason: Additional reason for deprecation.
+        args_map: Mapping of deprecated argument names to new names.
+        warning_category: The type of warning to emit.
+
+    Returns:
+        The decorated class with a warning-emitting __init__.
+    """
+    original_init = inspect.getattr_static(cls, "__init__")
+    sig = inspect.signature(original_init)
+
+    @wraps(original_init)
+    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
+        if args_map:
+            # Bind without applying defaults so that only arguments explicitly
+            # passed by the caller trigger a deprecation warning.
+            bound_args = sig.bind_partial(self, *args, **kwargs)
+            _warn_deprecated_arguments(args_map, bound_args, since, remove, warning_category)
+        else:
+            msg = _build_deprecation_msg(cls.__name__, since, remove, use, reason)
+            warnings.warn(msg, warning_category, stacklevel=2)
+
+        original_init(self, *args, **kwargs)
+
+    setattr(cls, "__init__", new_init)  # noqa: B010
+    return cls
+
+
+def _wrap_deprecated_function(
+    func: Callable,
+    since: str | None,
+    remove: str | None,
+    use: str | None,
+    reason: str | None,
+    args_map: dict[str, str | None] | None,
+    warning_category: type[Warning],
+) -> Callable:
+    """Wrap a function to emit a deprecation warning when called.
+
+    Also handles deprecated keyword arguments, warning when old arguments are used.
+
+    Args:
+        func: The function to wrap.
+        since: Version when the function was deprecated.
+        remove: Version when the function will be removed.
+        use: Suggested replacement function.
+        reason: Additional reason for deprecation.
+        args_map: Mapping of deprecated argument names to new names.
+        warning_category: The type of warning to emit.
+
+    Returns:
+        The wrapped function that emits deprecation warnings when called.
+    """
+    sig = inspect.signature(func)
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        if args_map:
+            # Argument deprecation: bind without applying defaults so that only
+            # explicitly passed arguments trigger a warning.
+            bound_args = sig.bind_partial(*args, **kwargs)
+            _warn_deprecated_arguments(args_map, bound_args, since, remove, warning_category)
+        else:
+            # Function deprecation: warn on every call.
+            msg = _build_deprecation_msg(func.__name__, since, remove, use, reason)
+            warnings.warn(msg, warning_category, stacklevel=2)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 @overload
 def deprecate(__obj: _T) -> _T: ...
 
 
 @overload
 def deprecate(
@@ -56,6 +237,7 @@ def deprecate(
     remove: str | None = None,
     use: str | None = None,
     reason: str | None = None,
+    args: dict[str, str | None] | None = None,
-    warning_category: type[Warning] = FutureWarning,
+    warning_category: type[Warning] = DeprecationWarning,
 ) -> Callable[[_T], _T]: ...
 
 
@@ -67,25 +249,31 @@ def deprecate(
     remove: str | None = None,
     use: str | None = None,
     reason: str | None = None,
-    warning_category: type[Warning] = FutureWarning,
+    args: dict[str, str | None] | None = None,
+    warning_category: type[Warning] = DeprecationWarning,
 ) -> Any:
-    """Mark a function or class as deprecated.
+    """Mark a function, class, or keyword argument as deprecated.
 
-    This decorator will cause a warning to be emitted when the function is called or class is instantiated.
+    This decorator will cause a warning to be emitted when the function is called,
+    the class is instantiated, or a deprecated keyword argument is used.
+    If ``args`` is provided, no function-level deprecation warning is emitted.
 
     Args:
         __obj: The function or class to be deprecated.
             If provided, the decorator is used without arguments.
             If None, the decorator is used with arguments.
-        since: Version when the function/class was deprecated (e.g. "2.1.0").
+        since: Version when the function/class/argument was deprecated (e.g. "2.1.0").
             If not provided, no deprecation version will be shown in the warning message.
-        remove: Version when the function/class will be removed.
+        remove: Version when the function/class/argument will be removed.
             If not provided, no removal version will be shown in the warning message.
-        use: Name of the replacement function/class.
+        use: Name of the replacement function/class/argument.
             If not provided, no replacement suggestion will be shown in the warning message.
         reason: Additional reason for the deprecation.
             If not provided, no reason will be shown in the warning message.
-        warning_category: Type of warning to emit (default: FutureWarning).
+        args: Mapping of deprecated argument names to their replacements.
+            If not provided, no argument deprecation warning will be shown.
+            If provided, no function-level deprecation warning is emitted.
+        warning_category: Type of warning to emit (default: DeprecationWarning).
             Can be DeprecationWarning, FutureWarning, PendingDeprecationWarning, etc.
 
     Returns:
@@ -116,49 +304,35 @@ def deprecate(
     >>> @deprecate(since="2.1.0", reason="Performance improvements available", warning_category=FutureWarning)
     ... def old_slow_function():
     ...     pass
+
+    >>> # Deprecation of specific argument(s)
+    >>> @deprecate(args={"old_param": "new_param"}, since="2.1.0", remove="3.0.0")
+    ... def my_function(new_param=None, old_param=None):
+    ...     # Handle the mapping logic yourself
+    ...     if old_param is not None and new_param is None:
+    ...         new_param = old_param
+    ...     # Rest of function logic
+    ...     pass
+
+    >>> @deprecate(args={"old_param": "new_param", "two_param": "two_new_param"}, since="2.1.0")
+    ... def my_args_function(one_param, two_new_param=None, new_param=None, two_param=None, old_param=None):
+    ...     # Handle each mapping individually
+    ...     if old_param is not None and new_param is None:
+    ...         new_param = old_param
+    ...     if two_param is not None and two_new_param is None:
+    ...         two_new_param = two_param
+    ...     # Rest of function logic
+    ...     pass
+
+    >>> @deprecate(args={"old_arg": "new_arg"}, since="2.1.0", remove="3.0.0")
+    ... class MyClass:
+    ...     def __init__(self, new_arg=None, old_arg=None):
+    ...         if new_arg is None and old_arg is not None:
+    ...             new_arg = old_arg
     """
 
     def _deprecate_impl(obj: Any) -> Any:  # noqa: ANN401
         if isinstance(obj, type):
-            # Handle class deprecation
-            original_init = inspect.getattr_static(obj, "__init__")
-
-            @wraps(original_init)
-            def new_init(self: Any, *args: Any, **kwargs: Any) -> None:  # noqa: ANN401
-                msg = f"{obj.__name__} is deprecated"
-                if since:
-                    msg += f" since v{since}"
-                if remove:
-                    msg += f" and will be removed in v{remove}"
-                if use:
-                    msg += f". Use {use} instead"
-                if reason:
-                    msg += f". {reason}"
-                msg += "."
-                warnings.warn(msg, warning_category, stacklevel=2)
-                original_init(self, *args, **kwargs)
-
-            setattr(obj, "__init__", new_init)  # noqa: B010
-            return obj
-
-        # Handle function deprecation
-        @wraps(obj)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
-            msg = f"{obj.__name__} is deprecated"
-            if since:
-                msg += f" since v{since}"
-            if remove:
-                msg += f" and will be removed in v{remove}"
-            if use:
-                msg += f". Use {use} instead"
-            if reason:
-                msg += f". {reason}"
-            msg += "."
- warnings.warn(msg, warning_category, stacklevel=2) - return obj(*args, **kwargs) - - return wrapper + return _wrap_deprecated_init(obj, since, remove, use, reason, args, warning_category) + return _wrap_deprecated_function(obj, since, remove, use, reason, args, warning_category) - if __obj is None: - return _deprecate_impl - return _deprecate_impl(__obj) + return _deprecate_impl if __obj is None else _deprecate_impl(__obj) diff --git a/tests/unit/utils/test_deprecate.py b/tests/unit/utils/test_deprecate.py new file mode 100644 index 0000000000..936e8d03c3 --- /dev/null +++ b/tests/unit/utils/test_deprecate.py @@ -0,0 +1,81 @@ +# Copyright (C) 2022-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the deprecation utility function.""" + +import pytest + +from anomalib.utils import deprecate + + +def test_deprecated_function_warns() -> None: + """Test simple function deprecation warning.""" + + @deprecate(since="1.0", remove="2.0", use="new_func") + def old_func() -> str: + return "hello" + + with pytest.warns(DeprecationWarning, match=r"(?=.*1\.0)(?=.*2\.0)(?=.*new_func)"): + assert old_func() == "hello" + + +def test_deprecated_class_warns_on_init() -> None: + """Test deprecated class emits warning on instantiation.""" + + @deprecate(since="1.0", remove="2.0", use="NewClass") + class OldClass: + def __init__(self) -> None: + self.val = 42 + + with pytest.warns(DeprecationWarning, match=r"(?=.*1\.0)(?=.*2\.0)(?=.*NewClass)"): + instance = OldClass() + assert instance.val == 42 + + +def test_deprecated_function_arg_warns_with_none_replacement() -> None: + """Test deprecated argument triggers warning when replacement is None.""" + + @deprecate(args={"old_param": None}, since="1.0", remove="2.0") + def func(new_param: int | None = None, old_param: int | None = None) -> int: + return new_param if new_param is not None else old_param + + with pytest.warns(DeprecationWarning, match=r"(?=.*old_param)(?=.*1\.0)(?=.*2\.0)"): + assert func(old_param=10) == 10 + + +def test_deprecated_function_arg_warns() -> None: + """Test deprecated argument triggers warning with correct match.""" + + @deprecate(args={"old_param": "new_param"}, since="1.0", remove="2.0") + def func(new_param: int | None = None, old_param: int | None = None) -> int: + return new_param if new_param is not None else old_param + + with pytest.warns(DeprecationWarning, match=r"(?=.*old_param)(?=.*new_param)(?=.*1\.0)(?=.*2\.0)"): + assert func(old_param=10) == 10 + + +def test_deprecated_multiple_args_warns() -> None: + """Test multiple deprecated arguments trigger appropriate warnings.""" + + @deprecate(args={"old1": "new1", "old2": "new2"}, since="1.0") + def func(new1=None, new2=None, old1=None, old2=None): # noqa: ANN001, ANN202 + return new1 or old1, new2 or old2 + + with pytest.warns(DeprecationWarning, match=r"(?=.*old1)(?=.*new1)(?=.*1\.0)"): + func(old1="a") + + with pytest.warns(DeprecationWarning, match=r"(?=.*old2)(?=.*new2)(?=.*1\.0)"): + func(old2="b") + + +def test_deprecated_class_arg_warns() -> None: + """Test deprecated constructor argument triggers warning.""" + + @deprecate(args={"legacy_arg": "modern_arg"}, since="1.0", remove="2.0") + class MyClass: + def __init__(self, modern_arg=None, legacy_arg=None) -> None: # noqa: ANN001 + self.value = modern_arg if modern_arg is not None else legacy_arg + + with pytest.warns(DeprecationWarning, match=r"(?=.*legacy_arg)(?=.*modern_arg)(?=.*1\.0)(?=.*2\.0)"): + obj = MyClass(legacy_arg="deprecated") + assert obj.value == "deprecated"
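+
+
+# Hypothetical extra check (a sketch, not part of the original suite): the
+# ``warning_category`` override should also apply to plain function
+# deprecation, since the decorator forwards the category to ``warnings.warn``.
+def test_deprecated_function_custom_warning_category() -> None:
+    """Test that a custom warning category is honoured for function deprecation."""
+
+    @deprecate(since="1.0", warning_category=FutureWarning)
+    def old_func() -> int:
+        return 1
+
+    with pytest.warns(FutureWarning, match="old_func"):
+        assert old_func() == 1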