From c495a84a997b84ccb26ab33e85750163dc135adb Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 25 Jun 2025 21:41:35 +0200 Subject: [PATCH 01/12] stash changes Signed-off-by: Ashwin Vaidya --- src/anomalib/models/image/uninet/LICENSE | 29 ++++ src/anomalib/models/image/uninet/__init__.py | 32 +++++ .../models/image/uninet/anomaly_map.py | 72 ++++++++++ src/anomalib/models/image/uninet/dfs.py | 127 +++++++++++++++++ .../models/image/uninet/lightning_model.py | 133 ++++++++++++++++++ src/anomalib/models/image/uninet/loss.py | 104 ++++++++++++++ .../models/image/uninet/torch_model.py | 108 ++++++++++++++ third-party-programs.txt | 4 + 8 files changed, 609 insertions(+) create mode 100644 src/anomalib/models/image/uninet/LICENSE create mode 100644 src/anomalib/models/image/uninet/__init__.py create mode 100644 src/anomalib/models/image/uninet/anomaly_map.py create mode 100644 src/anomalib/models/image/uninet/dfs.py create mode 100644 src/anomalib/models/image/uninet/lightning_model.py create mode 100644 src/anomalib/models/image/uninet/loss.py create mode 100644 src/anomalib/models/image/uninet/torch_model.py diff --git a/src/anomalib/models/image/uninet/LICENSE b/src/anomalib/models/image/uninet/LICENSE new file mode 100644 index 0000000000..26a7ae014a --- /dev/null +++ b/src/anomalib/models/image/uninet/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2025 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original UniNet implementation by Shun Wei, Jielin Jiang, and Xiaolong Xu. + +Original license: +----------------- + + MIT License + + Copyright (c) 2025 Shun Wei + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. \ No newline at end of file diff --git a/src/anomalib/models/image/uninet/__init__.py b/src/anomalib/models/image/uninet/__init__.py new file mode 100644 index 0000000000..ba86f076d3 --- /dev/null +++ b/src/anomalib/models/image/uninet/__init__.py @@ -0,0 +1,32 @@ +"""UniNet Model for anomaly detection. + +This module implements anomaly detection using the UniNet model. It is a model designed for diverse domains and is suited for both supervised and unsupervised anomaly detection. It also focuses on multi-class anomaly detection. 
+ +Example: + >>> from anomalib.models import UniNet + >>> from anomalib.data import MVTecAD + >>> from anomalib.engine import Engine + + >>> # Initialize model and data + >>> datamodule = MVTecAD() + >>> model = UniNet() + + >>> # Train the model + >>> engine = Engine() + >>> engine.train(model=model, datamodule=datamodule) + + >>> # Get predictions + >>> engine.predict(model=model, datamodule=datamodule) + +See Also: + - :class:`UniNet`: Main model class for UniNet-based anomaly detection + - :class:`UniNetModel`: PyTorch implementation of the UniNet model +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import UniNet +from .torch_model import UniNetModel + +__all__ = ["UniNet", "UniNetModel"] diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/anomaly_map.py new file mode 100644 index 0000000000..64bc59b859 --- /dev/null +++ b/src/anomalib/models/image/uninet/anomaly_map.py @@ -0,0 +1,72 @@ +"""Utilities for computing anomaly maps.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.ndimage import gaussian_filter + + +def weighted_decision_mechanism( + batch_size: int, + output_list: list[list[torch.Tensor]], + alpha: float, + beta: float, + out_size: int = 25, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute anomaly maps using weighted decision mechanism. + + Args: + batch_size (int): Batch size. + output_list (list[list[torch.Tensor]]): List of output tensors. + alpha (float): Alpha parameter. Used for controlling the upper limit + beta (float): Beta parameter. Used for controlling the lower limit + out_size (int): Output size. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Anomaly score and anomaly map. + """ + total_weights_list = [] + for i in range(batch_size): + low_similarity_list = [torch.max(output_list[j][i]) for j in range(len(output_list))] + probs = F.softmax(torch.tensor(low_similarity_list), dim=0) + weight_list = [] # set P consists of L high probability values, where L ranges from n-1 to n+1 + for idx, prob in enumerate(probs): + weight_list.append(low_similarity_list[idx].cpu().numpy()) if prob > torch.mean(probs) else None + # TODO: see if all the numpy operations can be replaced with torch operations + weight = np.max([np.mean(weight_list) * alpha, beta]) + total_weights_list.append(weight) + + assert len(total_weights_list) == batch_size, "The number of weights do not match the number of samples" + + anomaly_map_lists = [[] for _ in output_list] + for idx, output in enumerate(output_list): + cat_output = torch.cat(output, dim=0) + anomaly_map = torch.unsqueeze(cat_output, dim=1) # Bx1xhxw + # Bx256x256 + anomaly_map_lists[idx] = F.interpolate(anomaly_map, out_size, mode="bilinear", align_corners=True)[:, 0, :, :] + + anomaly_map = sum(anomaly_map_lists) + + anomaly_score = [] + for idx in range(batch_size): + top_k = int(out_size * out_size * total_weights_list[idx]) + assert top_k >= 1 / (out_size * out_size), "weight can not be smaller than 1 / (H * W)!" 
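+        # NOTE: top_k is the adaptive pixel budget for this sample: the weight
+        # (an alpha-scaled mean of the retained high-probability maxima, floored
+        # at beta) times the number of pixels in the resized map. The image-level
+        # score is then read from the top_k most anomalous pixels below.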
+ + single_anomaly_score_exp = anomaly_map[idx] + single_anomaly_score_exp = torch.tensor(gaussian_filter(single_anomaly_score_exp.cpu().numpy(), sigma=4)) + assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size, ( + "something wrong with the last dimension of reshaped map!" + ) + single_map = single_anomaly_score_exp.reshape(1, -1) + single_anomaly_score = np.sort(single_map.topk(top_k, dim=-1)[0].detach().cpu().numpy(), axis=1) + anomaly_score.append(single_anomaly_score) + return anomaly_score, anomaly_map.detach().cpu().numpy() diff --git a/src/anomalib/models/image/uninet/dfs.py b/src/anomalib/models/image/uninet/dfs.py new file mode 100644 index 0000000000..73a0ec3c85 --- /dev/null +++ b/src/anomalib/models/image/uninet/dfs.py @@ -0,0 +1,127 @@ +"""Domain Related Feature Selection.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F +from torch import nn + + +class DomainRelatedFeatureSelection(nn.Module): + """Domain Related Feature Selection. + + It is used to select the domain-related features from the source and target features. + + Args: + num_channels (int): Number of channels in the features. Defaults to 256. + """ + + def __init__(self, num_channels: int = 256) -> None: + super().__init__() + self.num_channels = num_channels + self.theta1 = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) + self.theta2 = nn.Parameter(torch.zeros(1, num_channels * 2, 1, 1)) + self.theta3 = nn.Parameter(torch.zeros(1, num_channels * 4, 1, 1)) + + def _get_theta(self, idx: int) -> torch.Tensor: + match idx: + case 0: + return self.theta1 + case 1: + return self.theta2 + case 2: + return self.theta3 + case _: + msg = f"Invalid index: {idx}" + raise ValueError(msg) + + def forward( + self, + source_features: list[torch.Tensor], + target_features: list[torch.Tensor], + learnable: bool = True, + conv: bool = False, + maximize: bool = True, + ) -> list[torch.Tensor]: + """Domain related feature selection. + + Args: + source_features (list[torch.Tensor]): Source features. + target_features (list[torch.Tensor]): Target features. + learnable (bool): Whether to use learnable theta. Theta controls the domain-related feature selection. + Defaults to True. + conv (bool): Whether to use convolutional domain-related feature selection. + Defaults to False. + maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the + max value from the target feature. Defaults to True. + + Returns: + list[torch.Tensor]: Domain related features. 
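+
+        Example:
+            >>> # Illustrative shapes only; channel counts follow the default
+            >>> # ``num_channels=256`` configuration of this module.
+            >>> dfs = DomainRelatedFeatureSelection()
+            >>> sources = [torch.randn(2, 256, 64, 64), torch.randn(2, 512, 32, 32), torch.randn(2, 1024, 16, 16)]
+            >>> targets = [torch.randn(2, 256, 64, 64), torch.randn(2, 512, 32, 32), torch.randn(2, 1024, 16, 16)]
+            >>> selected = dfs(sources, targets, learnable=False)
+            >>> len(selected)
+            3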
+        """
+        # TODO: maybe move learnable to constructor
+        features = []
+        for idx, (source_feature, target_feature) in enumerate(zip(source_features, target_features, strict=True)):
+            theta = 1
+            if learnable:
+                # to avoid losing local weight, theta should stay bounded away from zero
+                if idx < 3:
+                    theta = torch.clamp(torch.sigmoid(self._get_theta(idx + 1)) * 1.0 + 0.5, max=1)
+                else:
+                    theta = torch.clamp(torch.sigmoid(self._get_theta(idx - 2)) * 1.0 + 0.5, max=1)
+
+            b, c, h, w = source_feature.shape
+            if not conv:
+                prior_flat = target_feature.view(b, c, -1)
+                if maximize:
+                    prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0]
+                    prior_flat = prior_flat - prior_flat_
+                weights = F.softmax(prior_flat, dim=-1)
+                weights = weights.view(b, c, h, w)
+
+                global_inf = target_feature.mean(dim=(-2, -1), keepdim=True)
+
+                inter_weights = weights * (theta + global_inf)
+
+                x_ = source_feature * inter_weights
+                features.append(x_)
+            else:
+                # generated in a convolutional way
+                raise NotImplementedError("convolutional feature selection is not implemented yet")
+
+        return features
+
+
+def domain_related_feature_selection(
+    source_features: list[torch.Tensor], target_features: list[torch.Tensor], maximize: bool = True
+) -> list[torch.Tensor]:
+    """Domain related feature selection.
+
+    Args:
+        source_features (list[torch.Tensor]): Source features.
+        target_features (list[torch.Tensor]): Target features.
+        maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the
+            max value from the target feature. Defaults to True.
+    """
+    features_list = []
+    theta = 1
+    for source_feature, target_feature in zip(source_features, target_features, strict=True):
+        b, c, h, w = source_feature.shape
+        prior_flat = target_feature.view(b, c, -1)
+        if maximize:
+            prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0]
+            prior_flat = prior_flat - prior_flat_
+        weights = F.softmax(prior_flat, dim=-1)
+        weights = weights.view(b, c, h, w)
+
+        global_inf = target_feature.mean(dim=(-2, -1), keepdim=True)
+        inter_weights = weights * (theta + global_inf)
+        source_feature_weighted = source_feature * inter_weights
+        features_list.append(source_feature_weighted)
+    return features_list
diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py
new file mode 100644
index 0000000000..a95440be12
--- /dev/null
+++ b/src/anomalib/models/image/uninet/lightning_model.py
@@ -0,0 +1,133 @@
+"""Lightning model for UniNet."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.nn.functional as F
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from anomalib.data import Batch, InferenceBatch
+from anomalib.models.components import AnomalibModule
+
+from .anomaly_map import weighted_decision_mechanism
+from .dfs import DomainRelatedFeatureSelection, domain_related_feature_selection
+from .loss import UniNetLoss
+from .torch_model import UniNetModel
+
+
+class UniNet(AnomalibModule):
+    """UniNet model for anomaly detection.
+
+    Args:
+        dfs (DomainRelatedFeatureSelection | None): Domain related feature selection module.
+ """ + + def __init__(self, dfs: DomainRelatedFeatureSelection | None = None) -> None: + # TODO: maybe DFS shouldn't be passed as argument as not sure when it will be the case it needs to be passed + self.dfs = dfs + self.loss = UniNetLoss(temperature=2.0) + self.bce_loss = torch.nn.BCEWithLogitsLoss() + self.model = UniNetModel(student=, bottleneck=, source_teacher=, target_teacher=) + + def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform a training step of UniNet.""" + del args, kwargs # These variables are not used. + source_target_features, student_features, predictions = self.model(batch.image) + student_features = self._feature_selection(source_target_features, student_features) + loss = self._compute_loss(student_features, source_target_features, predictions, batch.gt_label, batch.gt_mask) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation step of UniNet.""" + del args, kwargs # These variables are not used. + source_target_features, student_features, _ = self.model(batch.image) + output_list = [[] for _ in range(self.model.num_teachers * 3)] + for idx, (target_feature, student_feature) in enumerate( + zip(source_target_features, student_features, strict=False), + ): + output = 1 - F.cosine_similarity(target_feature, student_feature) + output_list[idx].append(output) + + anomaly_score, anomaly_map = weighted_decision_mechanism( + batch_size=batch.image.shape[0], + output_list=output_list, + alpha=0.01, + beta=0.00003, + ) + predictions = InferenceBatch( + pred_score=anomaly_score, + anomaly_map=anomaly_map, + ) + return batch.update(**predictions._asdict()) + + def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]: + """Configure optimizers for training. + + Returns: + tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Optimizers for student and target teacher. + """ + # TODO: refactor this + return [ + torch.optim.AdamW( + self.model.student.parameters() + self.model.bottleneck.parameters() + self.dfs.parameters(), + lr=5e-3, + betas=(0.9, 0.999), + weight_decay=1e-5, + ), + torch.optim.AdamW(self.target_teacher.parameters(), lr=1e-6, betas=(0.9, 0.999), weight_decay=1e-5), + ] + + def _feature_selection( + self, + source_features: list[torch.Tensor], + target_features: list[torch.Tensor], + maximize: bool = True, + ) -> list[torch.Tensor]: + """Feature selection. + + Args: + source_features (list[torch.Tensor]): Source features. + target_features (list[torch.Tensor]): Target features. + maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the + max value from the target feature. Defaults to True. + """ + if self.dfs is not None: + selected_features = self.dfs(source_features, target_features, maximize=maximize) + else: + selected_features = domain_related_feature_selection(source_features, target_features, maximize=maximize) + return selected_features + + def _compute_loss( + self, + student_features: list[torch.Tensor], + teacher_features: list[torch.Tensor], + predictions: torch.Tensor | None = None, + label: int | None = None, + mask: torch.Tensor | None = None, + stop_gradient: bool = False, + ) -> torch.Tensor: + """Compute the loss. + + Args: + student_features (list[torch.Tensor]): Student features. + teacher_features (list[torch.Tensor]): Teacher features. + predictions (torch.Tensor): Predictions. 
+ label (int | None): Label for the prediction. + mask (torch.Tensor | None): Mask for the prediction. + stop_gradient (bool): Whether to stop the gradient into teacher features. + """ + # TODO: figure out the dimension of predictions and update the docstring + if mask is not None: + mask_ = mask + else: + assert label is not None, "Label is required when mask is not provided" + mask_ = label + + loss = self.loss(student_features, teacher_features, mask=mask_, stop_gradient=stop_gradient) + if predictions is not None and label is not None: + loss += self.bce_loss(predictions[0], label.float()) + self.bce_loss(predictions[1], label.float()) + + return loss diff --git a/src/anomalib/models/image/uninet/loss.py b/src/anomalib/models/image/uninet/loss.py new file mode 100644 index 0000000000..b5430a392c --- /dev/null +++ b/src/anomalib/models/image/uninet/loss.py @@ -0,0 +1,104 @@ +"""Loss functions for UniNet.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F +from torch import nn + + +class UniNetLoss(nn.Module): + """Loss function for UniNet. + + Args: + lambda_weight (float): Hyperparameter for balancing the loss. Defaults to 0.7. + temperature (float): Temperature for contrastive learning. Defaults to 2.0. + """ + + def __init__(self, lambda_weight: float = 0.7, temperature: float = 2.0) -> None: + super().__init__() + self.lambda_weight = lambda_weight + self.temperature = temperature + + def forward( + self, + student_features: list[torch.Tensor], + teacher_features: list[torch.Tensor], + margin: int = 1, + mask: torch.Tensor | None = None, + stop_gradient: bool = False, + ) -> torch.Tensor: + """Compute the loss. + + Args: + student_features (list[torch.Tensor]): Student features. + teacher_features (list[torch.Tensor]): Teacher features. + margin (int): Hyperparameter for controlling the boundary. + mask (torch.Tensor | None): Mask for the prediction. + stop_gradient (bool): Whether to stop the gradient into teacher features. 
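+
+        Note:
+            Per feature level, the total is
+            ``lambda_weight * cosine + (1 - lambda_weight) * contrastive + margin``,
+            where the margin terms push normal similarities above ``margin`` and
+            abnormal similarities below it.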
+        """
+        loss = 0.0
+        # initialize all partial terms up front so that supervised batches without
+        # normal (or abnormal) pixels do not leave them undefined below
+        contrastive_loss = margin_loss_n = margin_loss_a = 0.0
+
+        for idx in range(len(student_features)):
+            student_feature = student_features[idx]
+            teacher_feature = teacher_features[idx].detach() if stop_gradient else teacher_features[idx]
+
+            n, c, h, w = student_feature.shape
+            student_feature = student_feature.view(n, c, -1).transpose(1, 2)  # (N, H*W, C)
+            teacher_feature = teacher_feature.view(n, c, -1).transpose(1, 2)  # (N, H*W, C)
+
+            student_feature_normalized = F.normalize(student_feature, p=2, dim=2)
+            teacher_feature_normalized = F.normalize(teacher_feature, p=2, dim=2)
+
+            cosine_loss = 1 - F.cosine_similarity(student_feature_normalized, teacher_feature_normalized, dim=2)
+            cosine_loss = cosine_loss.mean()
+
+            similarity = (
+                torch.matmul(student_feature_normalized, teacher_feature_normalized.transpose(1, 2)) / self.temperature
+            )
+            similarity = torch.exp(similarity)
+            similarity_sum = similarity.sum(dim=2, keepdim=True)
+            similarity = similarity / (similarity_sum + 1e-8)
+            diag_sum = torch.diagonal(similarity, dim1=1, dim2=2)
+
+            # unsupervised and only normal (or abnormal)
+            if mask is None:
+                contrastive_loss = -torch.log(diag_sum + 1e-8).mean()
+                margin_loss_n = F.relu(margin - diag_sum).mean()
+
+            # supervised
+            else:
+                # gt label
+                if len(mask.size()) < 3:
+                    normal_mask = mask == 0
+                    abnormal_mask = mask == 1
+                # gt mask
+                else:
+                    mask_ = F.interpolate(mask, size=(h, w), mode="nearest").squeeze(1)
+                    mask_flat = mask_.view(mask_.size(0), -1)
+
+                    normal_mask = mask_flat == 0
+                    abnormal_mask = mask_flat == 1
+
+                if normal_mask.sum() > 0:
+                    diag_sim_normal = diag_sum[normal_mask]
+                    contrastive_loss = -torch.log(diag_sim_normal + 1e-8).mean()
+                    margin_loss_n = F.relu(margin - diag_sim_normal).mean()
+
+                if abnormal_mask.sum() > 0:
+                    diag_sim_abnormal = diag_sum[abnormal_mask]
+                    margin_loss_a = F.relu(diag_sim_abnormal - margin).mean()
+
+            margin_loss = margin_loss_n + margin_loss_a
+
+            loss += cosine_loss * self.lambda_weight + contrastive_loss * (1 - self.lambda_weight) + margin_loss
+
+        return loss
diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py
new file mode 100644
index 0000000000..95bd839184
--- /dev/null
+++ b/src/anomalib/models/image/uninet/torch_model.py
@@ -0,0 +1,108 @@
+"""PyTorch model for UniNet.
+
+See Also:
+    :class:`anomalib.models.image.uninet.lightning_model.UniNet`:
+        UniNet Lightning model.
+"""
+
+# Original Code
+# Copyright (c) 2025 Shun Wei
+# https://github.com/pangdatangtt/UniNet
+# SPDX-License-Identifier: MIT
+#
+# Modified
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+from torch import nn
+
+
+class UniNetModel(nn.Module):
+    """UniNet PyTorch model.
+
+    It consists of teachers, student, and bottleneck modules.
+
+    Args:
+        student (nn.Module): Student model.
+        bottleneck (nn.Module): Bottleneck model.
+        source_teacher (nn.Module): Source teacher model.
+        target_teacher (nn.Module | None): Target teacher model.
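+
+    Example:
+        >>> # Hypothetical wiring; ``student``, ``bottleneck``, and the teacher
+        >>> # modules are stand-ins for the networks built by the Lightning wrapper.
+        >>> model = UniNetModel(student, bottleneck, source_teacher, target_teacher)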
+ """ + + def __init__( + self, + student: nn.Module, + bottleneck: nn.Module, + source_teacher: nn.Module, + target_teacher: nn.Module | None = None, + ) -> None: + super().__init__() + self.num_teachers = 1 if target_teacher is None else 2 + self.teachers = Teachers(source_teacher, target_teacher) + self.student = Student(student) + self.bottleneck = Bottleneck(bottleneck) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + source_target_features, bottleneck_inputs = self.teachers(images) + bottleneck_outputs = self.bottleneck(bottleneck_inputs) + + student_features, predictions = self.student(bottleneck_outputs) + student_features = [d.chunk(dim=0, chunks=2) for d in student_features] + student_features = [ + student_features[0][0], + student_features[1][0], + student_features[2][0], + student_features[0][1], + student_features[1][1], + student_features[2][1], + ] + + predictions = predictions.chunk(dim=0, chunks=2) + + return source_target_features, student_features, predictions + + +class Teachers(nn.Module): + def __init__(self, source_teacher: nn.Module, target_teacher: nn.Module | None = None) -> None: + super().__init__() + self.source_teacher = source_teacher + self.target_teacher = target_teacher + + def forward(self, images: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: + # TODO: revisit this method and clean up multiple return statements + with torch.no_grad(): + source_features = self.source_teacher(images) + + if self.target_teacher is None: + return source_features + + with torch.no_grad(): + target_features = self.target_teacher(images) + bottleneck_inputs = [ + torch.cat([a, b], dim=0) for a, b in zip(target_features, source_features, strict=True) + ] # 512, 1024, 2048 + + return source_features + target_features, bottleneck_inputs + + +class Bottleneck(nn.Module): + """Bottleneck module for UniNet.""" + + def __init__(self, bottleneck: nn.Module) -> None: + super().__init__() + self.bottleneck = bottleneck + + def forward(self, bottleneck_inputs: list[torch.Tensor]) -> torch.Tensor: + return self.bottleneck(bottleneck_inputs) + + +class Student(nn.Module): + """Student model for UniNet.""" + + def __init__(self, backbone: nn.Module) -> None: + super().__init__() + self.backbone = backbone + + def forward(self, bottleneck_outputs: torch.Tensor, skips=None) -> torch.Tensor: + return self.backbone(bottleneck_outputs, skips) diff --git a/third-party-programs.txt b/third-party-programs.txt index 5eeaca8ea9..bc11567182 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -46,3 +46,7 @@ terms are listed below. 8. AUPIMO metric implementation is based on the original code Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo SPDX-License-Identifier: MIT + +9. 
UninetModel + Copyright (c) 2025 @pangdatangtt, https://github.com/pangdatangtt/UniNet/ + SPDX-License-Identifier: MIT \ No newline at end of file From 4fe02e46a444891c3a697a736e60a72cf140b184 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 26 Jun 2025 15:04:26 +0200 Subject: [PATCH 02/12] Add initial code Signed-off-by: Ashwin Vaidya --- .../models/components/backbone/__init__.py | 17 + .../backbone}/de_resnet.py | 10 +- .../components/__init__.py | 6 +- .../image/reverse_distillation/torch_model.py | 3 +- .../models/image/uninet/anomaly_map.py | 21 +- .../image/uninet/components/__init__.py | 14 + .../uninet/components/attention_bottleneck.py | 343 ++++++++++++++++++ .../image/uninet/{ => components}/dfs.py | 12 +- .../models/image/uninet/lightning_model.py | 75 +++- .../models/image/uninet/torch_model.py | 66 ++-- 10 files changed, 498 insertions(+), 69 deletions(-) create mode 100644 src/anomalib/models/components/backbone/__init__.py rename src/anomalib/models/{image/reverse_distillation/components => components/backbone}/de_resnet.py (98%) create mode 100644 src/anomalib/models/image/uninet/components/__init__.py create mode 100644 src/anomalib/models/image/uninet/components/attention_bottleneck.py rename src/anomalib/models/image/uninet/{ => components}/dfs.py (93%) diff --git a/src/anomalib/models/components/backbone/__init__.py b/src/anomalib/models/components/backbone/__init__.py new file mode 100644 index 0000000000..29778f9ff6 --- /dev/null +++ b/src/anomalib/models/components/backbone/__init__.py @@ -0,0 +1,17 @@ +"""Common backbone models. + +Example: + >>> from anomalib.models.components.backbone import get_decoder + >>> decoder = get_decoder() + +See Also: + - :func:`anomalib.models.components.backbone.de_resnet`: + Decoder network implementation +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .de_resnet import get_decoder + +__all__ = ["get_decoder"] diff --git a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py b/src/anomalib/models/components/backbone/de_resnet.py similarity index 98% rename from src/anomalib/models/image/reverse_distillation/components/de_resnet.py rename to src/anomalib/models/components/backbone/de_resnet.py index 44dcf4c0d6..698d881556 100644 --- a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py +++ b/src/anomalib/models/components/backbone/de_resnet.py @@ -1,6 +1,6 @@ -"""PyTorch model defining the decoder network for Reverse Distillation. +"""PyTorch model defining the decoder network for Reverse Distillation and UniNet. -This module implements the decoder network used in the Reverse Distillation model +This module implements the decoder network used in the Reverse Distillation and UniNet models architecture. The decoder reconstructs features from the bottleneck representation back to the original feature space. @@ -10,7 +10,7 @@ - Full decoder network architecture Example: - >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + >>> from anomalib.models.components.backbone.de_resnet import ( ... get_decoder ... 
) >>> decoder = get_decoder() @@ -18,9 +18,7 @@ >>> reconstructed = decoder(features) See Also: - - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`: - Main model implementation using this decoder - - :class:`anomalib.models.image.reverse_distillation.components.DecoderBasicBlock`: + - :class:`anomalib.models.components.backbone.de_resnet.DecoderBasicBlock`: Basic building block for the decoder network """ diff --git a/src/anomalib/models/image/reverse_distillation/components/__init__.py b/src/anomalib/models/image/reverse_distillation/components/__init__.py index 62e54fccb3..b484eaeb24 100644 --- a/src/anomalib/models/image/reverse_distillation/components/__init__.py +++ b/src/anomalib/models/image/reverse_distillation/components/__init__.py @@ -7,7 +7,6 @@ through distillation and reconstruction: - Bottleneck layer: Compresses features into a lower dimensional space -- Decoder network: Reconstructs features from the bottleneck representation Example: >>> from anomalib.models.image.reverse_distillation.components import ( @@ -20,14 +19,11 @@ See Also: - :func:`anomalib.models.image.reverse_distillation.components.bottleneck`: Bottleneck layer implementation - - :func:`anomalib.models.image.reverse_distillation.components.de_resnet`: - Decoder network implementation """ # Copyright (C) 2022-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from .bottleneck import get_bottleneck_layer -from .de_resnet import get_decoder -__all__ = ["get_bottleneck_layer", "get_decoder"] +__all__ = ["get_bottleneck_layer"] diff --git a/src/anomalib/models/image/reverse_distillation/torch_model.py b/src/anomalib/models/image/reverse_distillation/torch_model.py index 2d4ed55af4..54f1650108 100644 --- a/src/anomalib/models/image/reverse_distillation/torch_model.py +++ b/src/anomalib/models/image/reverse_distillation/torch_model.py @@ -39,9 +39,10 @@ from anomalib.data import InferenceBatch from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.backbone import get_decoder from .anomaly_map import AnomalyMapGenerationMode, AnomalyMapGenerator -from .components import get_bottleneck_layer, get_decoder +from .components import get_bottleneck_layer if TYPE_CHECKING: from anomalib.data.utils.tiler import Tiler diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/anomaly_map.py index 64bc59b859..02e0e95bdb 100644 --- a/src/anomalib/models/image/uninet/anomaly_map.py +++ b/src/anomalib/models/image/uninet/anomaly_map.py @@ -11,8 +11,8 @@ import numpy as np import torch -import torch.nn.functional as F from scipy.ndimage import gaussian_filter +from torch.nn import functional def weighted_decision_mechanism( @@ -37,22 +37,27 @@ def weighted_decision_mechanism( total_weights_list = [] for i in range(batch_size): low_similarity_list = [torch.max(output_list[j][i]) for j in range(len(output_list))] - probs = F.softmax(torch.tensor(low_similarity_list), dim=0) + probs = functional.softmax(torch.tensor(low_similarity_list), dim=0) weight_list = [] # set P consists of L high probability values, where L ranges from n-1 to n+1 for idx, prob in enumerate(probs): weight_list.append(low_similarity_list[idx].cpu().numpy()) if prob > torch.mean(probs) else None - # TODO: see if all the numpy operations can be replaced with torch operations + # TODO(ashwinvaidya17): see if all the numpy operations can be replaced with torch operations weight = np.max([np.mean(weight_list) * alpha, beta]) 
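+        # the weight above is the alpha-scaled mean of the retained maxima, floored
+        # at beta so that every sample keeps a non-trivial pixel budget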
total_weights_list.append(weight) assert len(total_weights_list) == batch_size, "The number of weights do not match the number of samples" - anomaly_map_lists = [[] for _ in output_list] + anomaly_map_lists: list[list[torch.Tensor]] = [[] for _ in output_list] for idx, output in enumerate(output_list): cat_output = torch.cat(output, dim=0) anomaly_map = torch.unsqueeze(cat_output, dim=1) # Bx1xhxw # Bx256x256 - anomaly_map_lists[idx] = F.interpolate(anomaly_map, out_size, mode="bilinear", align_corners=True)[:, 0, :, :] + anomaly_map_lists[idx] = functional.interpolate(anomaly_map, out_size, mode="bilinear", align_corners=True)[ + :, + 0, + :, + :, + ] anomaly_map = sum(anomaly_map_lists) @@ -63,9 +68,9 @@ def weighted_decision_mechanism( single_anomaly_score_exp = anomaly_map[idx] single_anomaly_score_exp = torch.tensor(gaussian_filter(single_anomaly_score_exp.cpu().numpy(), sigma=4)) - assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size, ( - "something wrong with the last dimension of reshaped map!" - ) + assert ( + single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size + ), "something wrong with the last dimension of reshaped map!" single_map = single_anomaly_score_exp.reshape(1, -1) single_anomaly_score = np.sort(single_map.topk(top_k, dim=-1)[0].detach().cpu().numpy(), axis=1) anomaly_score.append(single_anomaly_score) diff --git a/src/anomalib/models/image/uninet/components/__init__.py b/src/anomalib/models/image/uninet/components/__init__.py new file mode 100644 index 0000000000..ea51daeda4 --- /dev/null +++ b/src/anomalib/models/image/uninet/components/__init__.py @@ -0,0 +1,14 @@ +"""PyTorch modules for the UniNet model implementation.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .attention_bottleneck import AttentionBottleneck, BottleneckLayer +from .dfs import DomainRelatedFeatureSelection, domain_related_feature_selection + +__all__ = [ + "domain_related_feature_selection", + "DomainRelatedFeatureSelection", + "AttentionBottleneck", + "BottleneckLayer", +] diff --git a/src/anomalib/models/image/uninet/components/attention_bottleneck.py b/src/anomalib/models/image/uninet/components/attention_bottleneck.py new file mode 100644 index 0000000000..b36deedc64 --- /dev/null +++ b/src/anomalib/models/image/uninet/components/attention_bottleneck.py @@ -0,0 +1,343 @@ +"""Attention Bottleneck for UniNet.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Callable + +import torch +from torch import nn + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding. + + Args: + in_planes (int): Number of input planes. + out_planes (int): Number of output planes. + stride (int): Stride of the convolution. + groups (int): Number of groups. + dilation (int): Dilation rate. + + Returns: + nn.Conv2d: 3x3 convolution with padding. + """ + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution. + + Args: + in_planes (int): Number of input planes. + out_planes (int): Number of output planes. 
+ stride (int): Stride of the convolution. + """ + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def fuse_bn(conv: nn.Module, bn: nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + """Fuse convolution and batch normalization layers. + + Args: + conv (nn.Module): Convolution layer. + bn (nn.Module): Batch normalization layer. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Fused convolution and batch normalization layers. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class AttentionBottleneck(nn.Module): + """Attention Bottleneck for UniNet.""" + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: nn.Module | None = None, + groups: int = 1, + base_width: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + attention: bool = True, + halve: int = 1, + ) -> None: + super().__init__() + self.attention = attention + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups # 512 + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.halve = halve + k = 7 + p = 3 + + self.bn2 = norm_layer(width // halve) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + + self.downsample = downsample + self.stride = stride + + self.bn4 = norm_layer(width // 2) + self.bn5 = norm_layer(width // 2) + self.bn6 = norm_layer(width // 2) + self.bn7 = norm_layer(width) + self.conv3x3 = nn.Conv2d(inplanes // 2, width // 2, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv3x3_ = nn.Conv2d(width // 2, width // 2, 3, 1, 1, bias=False) + self.conv7x7 = nn.Conv2d(inplanes // 2, width // 2, kernel_size=k, stride=stride, padding=p, bias=False) + self.conv7x7_ = nn.Conv2d(width // 2, width // 2, k, 1, p, bias=False) + + def get_same_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get same kernel and bias of the bottleneck.""" + k1, b1 = fuse_bn(self.conv3x3, self.bn2) + k2, b2 = fuse_bn(self.conv3x3_, self.bn6) + + return k1, b1, k2, b2 + + def merge_kernel(self) -> None: + """Merge kernel of the bottleneck.""" + k1, b1, k2, b2 = self.get_same_kernel_bias() + self.conv7x7 = nn.Conv2d( + self.conv3x3.in_channels, + self.conv3x3.out_channels, + self.conv3x3.kernel_size, + self.conv3x3.stride, + self.conv3x3.padding, + self.conv3x3.dilation, + self.conv3x3.groups, + ) + self.conv7x7_ = nn.Conv2d( + self.conv3x3_.in_channels, + self.conv3x3_.out_channels, + self.conv3x3_.kernel_size, + self.conv3x3_.stride, + self.conv3x3_.padding, + self.conv3x3_.dilation, + self.conv3x3_.groups, + ) + self.conv7x7.weight.data = k1 + self.conv7x7.bias.data = b1 + self.conv7x7_.weight.data = k2 + self.conv7x7_.bias.data = b2 + + @staticmethod + def _process_branch( + branch: torch.Tensor, + conv1: nn.Module, + bn1: nn.Module, + conv2: nn.Module, + bn2: nn.Module, + relu: nn.Module, + ) -> torch.Tensor: + """Process a branch of the bottleneck. + + Args: + branch (torch.Tensor): Input tensor. + conv1 (nn.Module): First convolution layer. + bn1 (nn.Module): First batch normalization layer. + conv2 (nn.Module): Second convolution layer. 
+ bn2 (nn.Module): Second batch normalization layer. + relu (nn.Module): ReLU activation layer. + + Returns: + torch.Tensor: Output tensor. + """ + out = conv1(branch) + out = bn1(out) + out = relu(out) + out = conv2(out) + return bn2(out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + identity = x + + if self.halve == 1: + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + else: + num_channels = x.shape[1] + x_split = torch.split(x, [num_channels // 2, num_channels // 2], dim=1) + + out1 = self._process_branch(x_split[0], self.conv3x3, self.bn2, self.conv3x3_, self.bn5, self.relu) + out2 = self._process_branch(x_split[-1], self.conv7x7, self.bn4, self.conv7x7_, self.bn6, self.relu) + + out = torch.cat([out1, out2], dim=1) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = out + identity + return self.relu(out) + + +class BottleneckLayer(nn.Module): + """Batch Normalization layer for UniNet. + + Args: + block (Type[AttentionBottleneck]): Attention bottleneck. + layers (int): Number of layers. + groups (int): Number of groups. + width_per_group (int): Width per group. + norm_layer (Callable[..., nn.Module] | None): Normalization layer. + halve (int): Number of halved channels. + """ + + def __init__( + self, + block: type[AttentionBottleneck], + layers: int, + groups: int = 1, + width_per_group: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + halve: int = 2, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.groups = groups + self.base_width = width_per_group + self.inplanes = 256 * block.expansion + self.halve = halve + self.bn_layer = nn.Sequential(self._make_layer(block, 512, layers, stride=2)) + self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2) + self.bn1 = norm_layer(128 * block.expansion) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn2 = norm_layer(256 * block.expansion) + self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn3 = norm_layer(256 * block.expansion) + + self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1) + self.bn4 = norm_layer(512 * block.expansion) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif self.halve == 2: + m.merge_kernel() + + def _make_layer( + self, + block: type[AttentionBottleneck], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + """Make a layer of the bottleneck. + + Args: + block (AttentionBottleneck): Attention bottleneck. + planes (int): Number of planes. + blocks (int): Number of blocks. + stride (int): Stride of the convolution. + dilate (bool): Whether to dilate the convolution. + + Returns: + nn.Sequential: Sequential layer. 
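+
+        Note:
+            The first block takes ``inplanes * 3`` input channels because the three
+            teacher feature scales are concatenated channel-wise in ``forward``
+            before entering this layer.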
+ """ + norm_layer = self._norm_layer + downsample = None + if dilate: + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes * 3, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes * 3, + planes, + stride, + downsample, + self.groups, + self.base_width, + norm_layer, + halve=self.halve, + ), + ) + self.inplanes = planes * block.expansion + layers.extend( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + norm_layer=norm_layer, + halve=self.halve, + ) + for _ in range(1, blocks) + ) + + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + l1 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x[0])))))) + l2 = self.relu(self.bn3(self.conv3(x[1]))) + feature = torch.cat([l1, l2, x[2]], 1) + output = self.bn_layer(feature) # 16*2048*8*8 + return output.contiguous() diff --git a/src/anomalib/models/image/uninet/dfs.py b/src/anomalib/models/image/uninet/components/dfs.py similarity index 93% rename from src/anomalib/models/image/uninet/dfs.py rename to src/anomalib/models/image/uninet/components/dfs.py index 73a0ec3c85..cfe78c684e 100644 --- a/src/anomalib/models/image/uninet/dfs.py +++ b/src/anomalib/models/image/uninet/components/dfs.py @@ -10,8 +10,8 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import torch.nn.functional as F from torch import nn +from torch.nn import functional class DomainRelatedFeatureSelection(nn.Module): @@ -65,7 +65,7 @@ def forward( Returns: list[torch.Tensor]: Domain related features. """ - # TODO: maybe move learnable to constructor + # TODO(ashwinvaidya17): maybe move learnable to constructor features = [] for idx, (source_feature, target_feature) in enumerate(zip(source_features, target_features, strict=True)): theta = 1 @@ -82,7 +82,7 @@ def forward( if maximize: prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0] prior_flat = prior_flat - prior_flat_ - weights = F.softmax(prior_flat, dim=-1) + weights = functional.softmax(prior_flat, dim=-1) weights = weights.view(b, c, h, w) global_inf = target_feature.mean(dim=(-2, -1), keepdim=True) @@ -99,7 +99,9 @@ def forward( def domain_related_feature_selection( - source_features: list[torch.Tensor], target_features: list[torch.Tensor], maximize: bool = True + source_features: list[torch.Tensor], + target_features: list[torch.Tensor], + maximize: bool = True, ) -> list[torch.Tensor]: """Domain related feature selection. 
@@ -117,7 +119,7 @@ def domain_related_feature_selection( if maximize: prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0] prior_flat = prior_flat - prior_flat_ - weights = F.softmax(prior_flat, dim=-1) + weights = functional.softmax(prior_flat, dim=-1) weights = weights.view(b, c, h, w) global_inf = target_feature.mean(dim=(-2, -1), keepdim=True) diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index a95440be12..a1a0cd2971 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -3,33 +3,54 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import copy + import torch -import torch.nn.functional as F +import torchvision from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn +from torch.nn import functional +from torchvision.models.feature_extraction import create_feature_extractor from anomalib.data import Batch, InferenceBatch from anomalib.models.components import AnomalibModule +from anomalib.models.components.backbone import get_decoder from .anomaly_map import weighted_decision_mechanism -from .dfs import DomainRelatedFeatureSelection, domain_related_feature_selection +from .components import ( + AttentionBottleneck, + BottleneckLayer, + DomainRelatedFeatureSelection, + domain_related_feature_selection, +) from .loss import UniNetLoss from .torch_model import UniNetModel -from torchvision.models.resnet class UniNet(AnomalibModule): """UniNet model for anomaly detection. Args: - dfs (DomainRelatedFeatureSelection | None): Domain related feature selection module. + student_backbone (str): The backbone model to use for the student. + teacher_backbone (str): The backbone model to use for the teacher. """ - def __init__(self, dfs: DomainRelatedFeatureSelection | None = None) -> None: - # TODO: maybe DFS shouldn't be passed as argument as not sure when it will be the case it needs to be passed - self.dfs = dfs + def __init__( + self, + student_backbone: str = "wide_resent50_2", + teacher_backbone: str = "wide_resent50_2", + ) -> None: + self.dfs = DomainRelatedFeatureSelection() self.loss = UniNetLoss(temperature=2.0) self.bce_loss = torch.nn.BCEWithLogitsLoss() - self.model = UniNetModel(student=, bottleneck=, source_teacher=, target_teacher=) + source_teacher = self._get_teacher(teacher_backbone) + target_teacher = copy.deepcopy(source_teacher) + self.model = UniNetModel( + student=get_decoder(student_backbone), + bottleneck=BottleneckLayer(block=AttentionBottleneck, layers=3), + source_teacher=source_teacher, + target_teacher=target_teacher, + ) def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a training step of UniNet.""" @@ -43,12 +64,16 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a validation step of UniNet.""" del args, kwargs # These variables are not used. 
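+        # Each scale yields a cosine-distance map between teacher and student
+        # features; the weighted decision mechanism below fuses these maps into
+        # pixel-level anomaly maps and image-level scores.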
+ if batch.image is None or not isinstance(batch.image, torch.Tensor): + msg = "Expected batch.image to be a tensor, but got None or non-tensor type" + raise ValueError(msg) + source_target_features, student_features, _ = self.model(batch.image) - output_list = [[] for _ in range(self.model.num_teachers * 3)] + output_list: list[list[torch.Tensor]] = [[] for _ in range(self.model.num_teachers * 3)] for idx, (target_feature, student_feature) in enumerate( zip(source_target_features, student_features, strict=False), ): - output = 1 - F.cosine_similarity(target_feature, student_feature) + output = 1 - functional.cosine_similarity(target_feature, student_feature) output_list[idx].append(output) anomaly_score, anomaly_map = weighted_decision_mechanism( @@ -69,8 +94,8 @@ def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optim Returns: tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Optimizers for student and target teacher. """ - # TODO: refactor this - return [ + # TODO(ashwinvaidya17): refactor this + return ( torch.optim.AdamW( self.model.student.parameters() + self.model.bottleneck.parameters() + self.dfs.parameters(), lr=5e-3, @@ -78,7 +103,23 @@ def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optim weight_decay=1e-5, ), torch.optim.AdamW(self.target_teacher.parameters(), lr=1e-6, betas=(0.9, 0.999), weight_decay=1e-5), - ] + ) + + @staticmethod + def _get_teacher(backbone: str) -> nn.Module: + """Get the teacher model. + + In the original code, the teacher resnet model is used to extract features from the input image. + We can just use the feature extractor from torchvision to extract the features. + + Args: + backbone (str): The backbone model to use. + + Returns: + nn.Module: The teacher model. + """ + model = getattr(torchvision.models, backbone)(pretrained=True) + return create_feature_extractor(model, return_nodes=["layer1", "layer2", "layer3"]) def _feature_selection( self, @@ -105,7 +146,7 @@ def _compute_loss( student_features: list[torch.Tensor], teacher_features: list[torch.Tensor], predictions: torch.Tensor | None = None, - label: int | None = None, + label: float | None = None, mask: torch.Tensor | None = None, stop_gradient: bool = False, ) -> torch.Tensor: @@ -115,11 +156,11 @@ def _compute_loss( student_features (list[torch.Tensor]): Student features. teacher_features (list[torch.Tensor]): Teacher features. predictions (torch.Tensor): Predictions. - label (int | None): Label for the prediction. + label (float | None): Label for the prediction. mask (torch.Tensor | None): Mask for the prediction. stop_gradient (bool): Whether to stop the gradient into teacher features. 
""" - # TODO: figure out the dimension of predictions and update the docstring + # TODO(ashwinvaidya17): figure out the dimension of predictions and update the docstring if mask is not None: mask_ = mask else: @@ -128,6 +169,6 @@ def _compute_loss( loss = self.loss(student_features, teacher_features, mask=mask_, stop_gradient=stop_gradient) if predictions is not None and label is not None: - loss += self.bce_loss(predictions[0], label.float()) + self.bce_loss(predictions[1], label.float()) + loss += self.bce_loss(predictions[0], label) + self.bce_loss(predictions[1], label) return loss diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py index 95bd839184..837fbfe7a6 100644 --- a/src/anomalib/models/image/uninet/torch_model.py +++ b/src/anomalib/models/image/uninet/torch_model.py @@ -40,14 +40,33 @@ def __init__( super().__init__() self.num_teachers = 1 if target_teacher is None else 2 self.teachers = Teachers(source_teacher, target_teacher) - self.student = Student(student) - self.bottleneck = Bottleneck(bottleneck) + self.student = student + self.bottleneck = bottleneck + + # Used to post-process the student features from the de_resnet model to get the predictions + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1) def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass of the UniNet model. + + Args: + images (torch.Tensor): Input images. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Source target features, student features, and predictions. + """ source_target_features, bottleneck_inputs = self.teachers(images) bottleneck_outputs = self.bottleneck(bottleneck_inputs) - student_features, predictions = self.student(bottleneck_outputs) + student_features = self.student(bottleneck_outputs) + + # These predictions are part of the de_resnet model of the original code. + # since we are using the de_resnet model from anomalib, we need to compute predictions here + predictions = self.avgpool(student_features[2]) + predictions = torch.flatten(predictions, 1) + predictions = self.fc(predictions).squeeze() + student_features = [d.chunk(dim=0, chunks=2) for d in student_features] student_features = [ student_features[0][0], @@ -64,13 +83,28 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tor class Teachers(nn.Module): + """Teachers module for UniNet. + + Args: + source_teacher (nn.Module): Source teacher model. + target_teacher (nn.Module | None): Target teacher model. + """ + def __init__(self, source_teacher: nn.Module, target_teacher: nn.Module | None = None) -> None: super().__init__() self.source_teacher = source_teacher self.target_teacher = target_teacher - def forward(self, images: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - # TODO: revisit this method and clean up multiple return statements + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Forward pass of the teachers. + + Args: + images (torch.Tensor): Input images. + + Returns: + torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: Source features or source and target features. 
+ """ + # TODO(ashwinvaidya17): revisit this method and clean up multiple return statements with torch.no_grad(): source_features = self.source_teacher(images) @@ -84,25 +118,3 @@ def forward(self, images: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, li ] # 512, 1024, 2048 return source_features + target_features, bottleneck_inputs - - -class Bottleneck(nn.Module): - """Bottleneck module for UniNet.""" - - def __init__(self, bottleneck: nn.Module) -> None: - super().__init__() - self.bottleneck = bottleneck - - def forward(self, bottleneck_inputs: list[torch.Tensor]) -> torch.Tensor: - return self.bottleneck(bottleneck_inputs) - - -class Student(nn.Module): - """Student model for UniNet.""" - - def __init__(self, backbone: nn.Module) -> None: - super().__init__() - self.backbone = backbone - - def forward(self, bottleneck_outputs: torch.Tensor, skips=None) -> torch.Tensor: - return self.backbone(bottleneck_outputs, skips) From 51fa496184599a691edbc5b41582fa56baa79fa7 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Mon, 30 Jun 2025 15:22:54 +0200 Subject: [PATCH 03/12] Working model Signed-off-by: Ashwin Vaidya --- src/anomalib/models/__init__.py | 2 + src/anomalib/models/image/__init__.py | 2 + .../models/image/uninet/anomaly_map.py | 28 +++--- .../uninet/components/attention_bottleneck.py | 16 ++-- .../models/image/uninet/components/dfs.py | 6 +- .../models/image/uninet/lightning_model.py | 92 +++++++++++-------- src/anomalib/models/image/uninet/loss.py | 26 +++--- .../models/image/uninet/torch_model.py | 52 +++++++---- 8 files changed, 135 insertions(+), 89 deletions(-) diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 553f3c7b7c..ca99daf96e 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -72,6 +72,7 @@ Stfpm, Supersimplenet, Uflow, + UniNet, VlmAd, WinClip, ) @@ -108,6 +109,7 @@ class UnknownModelError(ModuleNotFoundError): "Stfpm", "Supersimplenet", "Uflow", + "UniNet", "VlmAd", "WinClip", "AiVad", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 2717290f3a..ab90b47420 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -61,6 +61,7 @@ from .stfpm import Stfpm from .supersimplenet import Supersimplenet from .uflow import Uflow +from .uninet import UniNet from .vlm_ad import VlmAd from .winclip import WinClip @@ -82,6 +83,7 @@ "Stfpm", "Supersimplenet", "Uflow", + "UniNet", "VlmAd", "WinClip", ] diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/anomaly_map.py index 02e0e95bdb..9504beec8a 100644 --- a/src/anomalib/models/image/uninet/anomaly_map.py +++ b/src/anomalib/models/image/uninet/anomaly_map.py @@ -9,9 +9,10 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import einops import numpy as np import torch -from scipy.ndimage import gaussian_filter +from kornia.filters import gaussian_blur2d from torch.nn import functional @@ -49,17 +50,16 @@ def weighted_decision_mechanism( anomaly_map_lists: list[list[torch.Tensor]] = [[] for _ in output_list] for idx, output in enumerate(output_list): - cat_output = torch.cat(output, dim=0) - anomaly_map = torch.unsqueeze(cat_output, dim=1) # Bx1xhxw + anomaly_map_ = torch.unsqueeze(output, dim=1) # Bx1xhxw # Bx256x256 - anomaly_map_lists[idx] = functional.interpolate(anomaly_map, out_size, mode="bilinear", align_corners=True)[ + anomaly_map_lists[idx] = functional.interpolate(anomaly_map_, out_size, 
mode="bilinear", align_corners=True)[ :, 0, :, :, ] - anomaly_map = sum(anomaly_map_lists) + anomaly_map: torch.Tensor = sum(anomaly_map_lists) anomaly_score = [] for idx in range(batch_size): @@ -67,11 +67,15 @@ def weighted_decision_mechanism( assert top_k >= 1 / (out_size * out_size), "weight can not be smaller than 1 / (H * W)!" single_anomaly_score_exp = anomaly_map[idx] - single_anomaly_score_exp = torch.tensor(gaussian_filter(single_anomaly_score_exp.cpu().numpy(), sigma=4)) - assert ( - single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size - ), "something wrong with the last dimension of reshaped map!" + single_anomaly_score_exp = gaussian_blur2d( + einops.rearrange(single_anomaly_score_exp, "h w -> 1 1 h w"), # kornia expects 4D tensor + kernel_size=(5, 5), + sigma=(4, 4), + ).squeeze() + assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size, ( + "something wrong with the last dimension of reshaped map!" + ) single_map = single_anomaly_score_exp.reshape(1, -1) - single_anomaly_score = np.sort(single_map.topk(top_k, dim=-1)[0].detach().cpu().numpy(), axis=1) - anomaly_score.append(single_anomaly_score) - return anomaly_score, anomaly_map.detach().cpu().numpy() + single_anomaly_score = single_map.topk(top_k).values[0][0] + anomaly_score.append(single_anomaly_score.detach()) + return torch.vstack(anomaly_score), anomaly_map.detach() diff --git a/src/anomalib/models/image/uninet/components/attention_bottleneck.py b/src/anomalib/models/image/uninet/components/attention_bottleneck.py index b36deedc64..9f8eed814a 100644 --- a/src/anomalib/models/image/uninet/components/attention_bottleneck.py +++ b/src/anomalib/models/image/uninet/components/attention_bottleneck.py @@ -260,14 +260,14 @@ def __init__( self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1) self.bn4 = norm_layer(512 * block.expansion) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif self.halve == 2: - m.merge_kernel() + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.BatchNorm2d | nn.GroupNorm): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + elif hasattr(module, "merge_kernel") and self.halve == 2: + module.merge_kernel() def _make_layer( self, diff --git a/src/anomalib/models/image/uninet/components/dfs.py b/src/anomalib/models/image/uninet/components/dfs.py index cfe78c684e..11d7a1fdbc 100644 --- a/src/anomalib/models/image/uninet/components/dfs.py +++ b/src/anomalib/models/image/uninet/components/dfs.py @@ -32,11 +32,11 @@ def __init__(self, num_channels: int = 256) -> None: def _get_theta(self, idx: int) -> torch.Tensor: match idx: - case 0: - return self.theta1 case 1: - return self.theta2 + return self.theta1 case 2: + return self.theta2 + case 3: return self.theta3 case _: msg = f"Invalid index: {idx}" diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index a1a0cd2971..a27f802041 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -3,15 +3,13 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import copy +from typing import Any import 
torch -import torchvision from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch import nn from torch.nn import functional -from torchvision.models.feature_extraction import create_feature_extractor +from anomalib import LearningType from anomalib.data import Batch, InferenceBatch from anomalib.models.components import AnomalibModule from anomalib.models.components.backbone import get_decoder @@ -37,27 +35,42 @@ class UniNet(AnomalibModule): def __init__( self, - student_backbone: str = "wide_resent50_2", - teacher_backbone: str = "wide_resent50_2", + student_backbone: str = "wide_resnet50_2", + teacher_backbone: str = "wide_resnet50_2", ) -> None: + super().__init__() self.dfs = DomainRelatedFeatureSelection() self.loss = UniNetLoss(temperature=2.0) self.bce_loss = torch.nn.BCEWithLogitsLoss() - source_teacher = self._get_teacher(teacher_backbone) - target_teacher = copy.deepcopy(source_teacher) self.model = UniNetModel( student=get_decoder(student_backbone), bottleneck=BottleneckLayer(block=AttentionBottleneck, layers=3), - source_teacher=source_teacher, - target_teacher=target_teacher, + teacher_backbone=teacher_backbone, ) + self.automatic_optimization = False def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a training step of UniNet.""" del args, kwargs # These variables are not used. + assert batch.gt_label is not None, "Ground truth label is required for training" + source_target_features, student_features, predictions = self.model(batch.image) student_features = self._feature_selection(source_target_features, student_features) - loss = self._compute_loss(student_features, source_target_features, predictions, batch.gt_label, batch.gt_mask) + loss = self._compute_loss( + student_features, + source_target_features, + predictions, + batch.gt_label.float(), + batch.gt_mask, + ) + + optimizer_1, optimizer_2 = self.optimizers() + optimizer_1.zero_grad() + optimizer_2.zero_grad() + self.manual_backward(loss) + optimizer_1.step() + optimizer_2.step() + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) return {"loss": loss} @@ -68,13 +81,11 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: msg = "Expected batch.image to be a tensor, but got None or non-tensor type" raise ValueError(msg) - source_target_features, student_features, _ = self.model(batch.image) - output_list: list[list[torch.Tensor]] = [[] for _ in range(self.model.num_teachers * 3)] - for idx, (target_feature, student_feature) in enumerate( - zip(source_target_features, student_features, strict=False), - ): - output = 1 - functional.cosine_similarity(target_feature, student_feature) - output_list[idx].append(output) + target_features, student_features, _ = self.model(batch.image) + output_list: list[torch.Tensor] = [] + for target_feature, student_feature in zip(target_features, student_features, strict=True): + output = 1 - functional.cosine_similarity(target_feature, student_feature) # B*64*64 + output_list.append(output) anomaly_score, anomaly_map = weighted_decision_mechanism( batch_size=batch.image.shape[0], @@ -97,29 +108,33 @@ def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optim # TODO(ashwinvaidya17): refactor this return ( torch.optim.AdamW( - self.model.student.parameters() + self.model.bottleneck.parameters() + self.dfs.parameters(), + list(self.model.student.parameters()) + + list(self.model.bottleneck.parameters()) + + list(self.dfs.parameters()), lr=5e-3, betas=(0.9, 0.999), 
weight_decay=1e-5, ), - torch.optim.AdamW(self.target_teacher.parameters(), lr=1e-6, betas=(0.9, 0.999), weight_decay=1e-5), + torch.optim.AdamW( + self.model.teachers.target_teacher.parameters(), + lr=1e-6, + betas=(0.9, 0.999), + weight_decay=1e-5, + ), ) - @staticmethod - def _get_teacher(backbone: str) -> nn.Module: - """Get the teacher model. - - In the original code, the teacher resnet model is used to extract features from the input image. - We can just use the feature extractor from torchvision to extract the features. + @property + def trainer_arguments(self) -> dict[str, Any]: + """Does not require any trainer arguments.""" + return {} - Args: - backbone (str): The backbone model to use. + @property + def learning_type(self) -> LearningType: + """The model uses one-class learning. - Returns: - nn.Module: The teacher model. + Though technically it supports multi-class and few_shot as well. """ - model = getattr(torchvision.models, backbone)(pretrained=True) - return create_feature_extractor(model, return_nodes=["layer1", "layer2", "layer3"]) + return LearningType.ONE_CLASS def _feature_selection( self, @@ -138,6 +153,7 @@ def _feature_selection( if self.dfs is not None: selected_features = self.dfs(source_features, target_features, maximize=maximize) else: + # TODO(ashwinvaidya17): self.dfs will never be None with current implementation selected_features = domain_related_feature_selection(source_features, target_features, maximize=maximize) return selected_features @@ -145,7 +161,7 @@ def _compute_loss( self, student_features: list[torch.Tensor], teacher_features: list[torch.Tensor], - predictions: torch.Tensor | None = None, + predictions: tuple[torch.Tensor, torch.Tensor] | None = None, label: float | None = None, mask: torch.Tensor | None = None, stop_gradient: bool = False, @@ -155,14 +171,16 @@ def _compute_loss( Args: student_features (list[torch.Tensor]): Student features. teacher_features (list[torch.Tensor]): Teacher features. - predictions (torch.Tensor): Predictions. + predictions (tuple[torch.Tensor, torch.Tensor] | None): Tuple of two prediction tensors, each of shape (B,). label (float | None): Label for the prediction. - mask (torch.Tensor | None): Mask for the prediction. + mask (torch.Tensor | None): Mask for the prediction. Mask is of shape BxHxW stop_gradient (bool): Whether to stop the gradient into teacher features.
""" loss = 0.0 @@ -55,10 +55,14 @@ def forward( student_feature = student_feature.view(n, c, -1).transpose(1, 2) # (N, H+W, C) teacher_feature = teacher_feature.view(n, c, -1).transpose(1, 2) # (N, H+W, C) - student_feature_normalized = F.normalize(student_feature, p=2, dim=2) - teacher_feature_normalized = F.normalize(teacher_feature, p=2, dim=2) + student_feature_normalized = functional.normalize(student_feature, p=2, dim=2) + teacher_feature_normalized = functional.normalize(teacher_feature, p=2, dim=2) - cosine_loss = 1 - F.cosine_similarity(student_feature_normalized, teacher_feature_normalized, dim=2) + cosine_loss = 1 - functional.cosine_similarity( + student_feature_normalized, + teacher_feature_normalized, + dim=2, + ) cosine_loss = cosine_loss.mean() similarity = ( @@ -72,18 +76,18 @@ def forward( # unsupervised and only normal (or abnormal) if mask is None: contrastive_loss = -torch.log(diag_sum + 1e-8).mean() - margin_loss_n = F.relu(margin - diag_sum).mean() + margin_loss_n = functional.relu(margin - diag_sum).mean() # supervised else: # gt label - if len(mask.size()) < 3: + if len(mask.shape) < 3: normal_mask = mask == 0 abnormal_mask = mask == 1 # gt mask else: - mask_ = F.interpolate(mask, size=(h, w), mode="nearest").squeeze(1) - mask_flat = mask.view(mask_.size(0), -1) + mask_ = functional.interpolate(mask, size=(h, w), mode="nearest").squeeze(1) + mask_flat = mask_.view(mask_.size(0), -1) normal_mask = mask_flat == 0 abnormal_mask = mask_flat == 1 @@ -91,11 +95,11 @@ def forward( if normal_mask.sum() > 0: diag_sim_normal = diag_sum[normal_mask] contrastive_loss = -torch.log(diag_sim_normal + 1e-8).mean() - margin_loss_n = F.relu(margin - diag_sim_normal).mean() + margin_loss_n = functional.relu(margin - diag_sim_normal).mean() if abnormal_mask.sum() > 0: diag_sim_abnormal = diag_sum[abnormal_mask] - margin_loss_a = F.relu(diag_sim_abnormal - margin).mean() + margin_loss_a = functional.relu(diag_sim_abnormal - margin).mean() margin_loss = margin_loss_n + margin_loss_a diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py index 837fbfe7a6..b37d838406 100644 --- a/src/anomalib/models/image/uninet/torch_model.py +++ b/src/anomalib/models/image/uninet/torch_model.py @@ -15,7 +15,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torchvision from torch import nn +from torch.fx import GraphModule +from torchvision.models.feature_extraction import create_feature_extractor class UniNetModel(nn.Module): @@ -34,12 +37,11 @@ def __init__( self, student: nn.Module, bottleneck: nn.Module, - source_teacher: nn.Module, - target_teacher: nn.Module | None = None, + teacher_backbone: str, ) -> None: super().__init__() - self.num_teachers = 1 if target_teacher is None else 2 - self.teachers = Teachers(source_teacher, target_teacher) + self.num_teachers = 2 # in the original code, there is an option to have only one teacher + self.teachers = Teachers(teacher_backbone) self.student = student self.bottleneck = bottleneck @@ -47,14 +49,15 @@ def __init__( self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(256, 1) - def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """Forward pass of the UniNet model. Args: images (torch.Tensor): Input images. 
Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Source target features, student features, and predictions. + tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: Source target features, + student features, and predictions. """ source_target_features, bottleneck_inputs = self.teachers(images) bottleneck_outputs = self.bottleneck(bottleneck_inputs) @@ -63,7 +66,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tor # These predictions are part of the de_resnet model of the original code. # since we are using the de_resnet model from anomalib, we need to compute predictions here - predictions = self.avgpool(student_features[2]) + predictions = self.avgpool(student_features[0]) predictions = torch.flatten(predictions, 1) predictions = self.fc(predictions).squeeze() @@ -90,10 +93,26 @@ class Teachers(nn.Module): target_teacher (nn.Module | None): Target teacher model. """ - def __init__(self, source_teacher: nn.Module, target_teacher: nn.Module | None = None) -> None: + def __init__(self, teacher_backbone: str) -> None: super().__init__() - self.source_teacher = source_teacher - self.target_teacher = target_teacher + self.source_teacher = self._get_teacher(teacher_backbone).eval() + self.target_teacher = self._get_teacher(teacher_backbone) + + @staticmethod + def _get_teacher(backbone: str) -> GraphModule: + """Get the teacher model. + + In the original code, the teacher resnet model is used to extract features from the input image. + We can just use the feature extractor from torchvision to extract the features. + + Args: + backbone (str): The backbone model to use. + + Returns: + GraphModule: The teacher model. + """ + model = getattr(torchvision.models, backbone)(pretrained=True) + return create_feature_extractor(model, return_nodes=["layer3", "layer2", "layer1"]) def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: """Forward pass of the teachers. 
@@ -108,13 +127,10 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor with torch.no_grad(): source_features = self.source_teacher(images) - if self.target_teacher is None: - return source_features + target_features = self.target_teacher(images) - with torch.no_grad(): - target_features = self.target_teacher(images) - bottleneck_inputs = [ - torch.cat([a, b], dim=0) for a, b in zip(target_features, source_features, strict=True) - ] # 512, 1024, 2048 + bottleneck_inputs = [ + torch.cat([a, b], dim=0) for a, b in zip(target_features.values(), source_features.values(), strict=True) + ] # 512, 1024, 2048 - return source_features + target_features, bottleneck_inputs + return list(source_features.values()) + list(target_features.values()), bottleneck_inputs From dc7174a65bdc5d43357dd37148ccdca4a89ba724 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 9 Jul 2025 16:08:01 +0200 Subject: [PATCH 04/12] Working model Signed-off-by: Ashwin Vaidya --- examples/configs/model/uninet.yaml | 15 ++ .../models/image/uninet/anomaly_map.py | 22 ++- .../image/uninet/components/__init__.py | 3 +- .../models/image/uninet/components/dfs.py | 41 +---- .../models/image/uninet/lightning_model.py | 166 ++++++------------ src/anomalib/models/image/uninet/loss.py | 2 +- .../models/image/uninet/torch_model.py | 101 ++++++++++- 7 files changed, 175 insertions(+), 175 deletions(-) create mode 100644 examples/configs/model/uninet.yaml diff --git a/examples/configs/model/uninet.yaml b/examples/configs/model/uninet.yaml new file mode 100644 index 0000000000..fb625a06c4 --- /dev/null +++ b/examples/configs/model/uninet.yaml @@ -0,0 +1,15 @@ +model: + class_path: anomalib.models.UniNet + init_args: + student_backbone: wide_resnet50_2 + teacher_backbone: wide_resnet50_2 + temperature: 0.1 + +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 20 + monitor: image_AUROC + mode: max diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/anomaly_map.py index 9504beec8a..c48da7f8fb 100644 --- a/src/anomalib/models/image/uninet/anomaly_map.py +++ b/src/anomalib/models/image/uninet/anomaly_map.py @@ -10,10 +10,9 @@ # SPDX-License-Identifier: Apache-2.0 import einops -import numpy as np import torch from kornia.filters import gaussian_blur2d -from torch.nn import functional +from torch.nn import functional as F # noqa: N812 def weighted_decision_mechanism( @@ -21,7 +20,7 @@ def weighted_decision_mechanism( output_list: list[list[torch.Tensor]], alpha: float, beta: float, - out_size: int = 25, + output_size: tuple[int, int] = (256, 256), ) -> tuple[torch.Tensor, torch.Tensor]: """Compute anomaly maps using weighted decision mechanism. @@ -30,7 +29,7 @@ def weighted_decision_mechanism( output_list (list[list[torch.Tensor]]): List of output tensors. alpha (float): Alpha parameter. Used for controlling the upper limit beta (float): Beta parameter. Used for controlling the lower limit - out_size (int): Output size. + output_size (tuple[int, int]): Output size. Returns: tuple[torch.Tensor, torch.Tensor]: Anomaly score and anomaly map. 
@@ -38,12 +37,11 @@ def weighted_decision_mechanism( total_weights_list = [] for i in range(batch_size): low_similarity_list = [torch.max(output_list[j][i]) for j in range(len(output_list))] - probs = functional.softmax(torch.tensor(low_similarity_list), dim=0) + probs = F.softmax(torch.tensor(low_similarity_list), dim=0) weight_list = [] # set P consists of L high probability values, where L ranges from n-1 to n+1 for idx, prob in enumerate(probs): - weight_list.append(low_similarity_list[idx].cpu().numpy()) if prob > torch.mean(probs) else None - # TODO(ashwinvaidya17): see if all the numpy operations can be replaced with torch operations - weight = np.max([np.mean(weight_list) * alpha, beta]) + weight_list.append(low_similarity_list[idx]) if prob > torch.mean(probs) else None + weight = torch.max(torch.tensor([torch.mean(torch.tensor(weight_list)) * alpha, beta])) total_weights_list.append(weight) assert len(total_weights_list) == batch_size, "The number of weights do not match the number of samples" @@ -52,7 +50,7 @@ def weighted_decision_mechanism( for idx, output in enumerate(output_list): anomaly_map_ = torch.unsqueeze(output, dim=1) # Bx1xhxw # Bx256x256 - anomaly_map_lists[idx] = functional.interpolate(anomaly_map_, out_size, mode="bilinear", align_corners=True)[ + anomaly_map_lists[idx] = F.interpolate(anomaly_map_, output_size, mode="bilinear", align_corners=True)[ :, 0, :, @@ -63,8 +61,8 @@ def weighted_decision_mechanism( anomaly_score = [] for idx in range(batch_size): - top_k = int(out_size * out_size * total_weights_list[idx]) - assert top_k >= 1 / (out_size * out_size), "weight can not be smaller than 1 / (H * W)!" + top_k = int(output_size[0] * output_size[1] * total_weights_list[idx]) + assert top_k >= 1 / (output_size[0] * output_size[1]), "weight can not be smaller than 1 / (H * W)!" single_anomaly_score_exp = anomaly_map[idx] single_anomaly_score_exp = gaussian_blur2d( @@ -72,7 +70,7 @@ def weighted_decision_mechanism( kernel_size=(5, 5), sigma=(4, 4), ).squeeze() - assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == out_size * out_size, ( + assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == output_size[0] * output_size[1], ( "something wrong with the last dimension of reshaped map!" ) single_map = single_anomaly_score_exp.reshape(1, -1) diff --git a/src/anomalib/models/image/uninet/components/__init__.py b/src/anomalib/models/image/uninet/components/__init__.py index ea51daeda4..cb8bb50694 100644 --- a/src/anomalib/models/image/uninet/components/__init__.py +++ b/src/anomalib/models/image/uninet/components/__init__.py @@ -4,10 +4,9 @@ # SPDX-License-Identifier: Apache-2.0 from .attention_bottleneck import AttentionBottleneck, BottleneckLayer -from .dfs import DomainRelatedFeatureSelection, domain_related_feature_selection +from .dfs import DomainRelatedFeatureSelection __all__ = [ - "domain_related_feature_selection", "DomainRelatedFeatureSelection", "AttentionBottleneck", "BottleneckLayer", diff --git a/src/anomalib/models/image/uninet/components/dfs.py b/src/anomalib/models/image/uninet/components/dfs.py index 11d7a1fdbc..1b3417cb16 100644 --- a/src/anomalib/models/image/uninet/components/dfs.py +++ b/src/anomalib/models/image/uninet/components/dfs.py @@ -21,14 +21,16 @@ class DomainRelatedFeatureSelection(nn.Module): Args: num_channels (int): Number of channels in the features. Defaults to 256. + learnable (bool): Whether to use learnable theta. Theta controls the domain-related feature selection. 
""" - def __init__(self, num_channels: int = 256) -> None: + def __init__(self, num_channels: int = 256, learnable: bool = True) -> None: super().__init__() self.num_channels = num_channels self.theta1 = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) self.theta2 = nn.Parameter(torch.zeros(1, num_channels * 2, 1, 1)) self.theta3 = nn.Parameter(torch.zeros(1, num_channels * 4, 1, 1)) + self.learnable = learnable def _get_theta(self, idx: int) -> torch.Tensor: match idx: @@ -46,7 +48,6 @@ def forward( self, source_features: list[torch.Tensor], target_features: list[torch.Tensor], - learnable: bool = True, conv: bool = False, maximize: bool = True, ) -> list[torch.Tensor]: @@ -55,8 +56,6 @@ def forward( Args: source_features (list[torch.Tensor]): Source features. target_features (list[torch.Tensor]): Target features. - learnable (bool): Whether to use learnable theta. Theta controls the domain-related feature selection. - Defaults to True. conv (bool): Whether to use convolutional domain-related feature selection. Defaults to False. maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the @@ -65,11 +64,10 @@ def forward( Returns: list[torch.Tensor]: Domain related features. """ - # TODO(ashwinvaidya17): maybe move learnable to constructor features = [] for idx, (source_feature, target_feature) in enumerate(zip(source_features, target_features, strict=True)): theta = 1 - if learnable: + if self.learnable: # to avoid losing local weight, theta should be as non-zero value as possible if idx < 3: theta = torch.clamp(torch.sigmoid(self._get_theta(idx + 1)) * 1.0 + 0.5, max=1) @@ -96,34 +94,3 @@ def forward( pass return features - - -def domain_related_feature_selection( - source_features: list[torch.Tensor], - target_features: list[torch.Tensor], - maximize: bool = True, -) -> list[torch.Tensor]: - """Domain related feature selection. - - Args: - source_features (list[torch.Tensor]): Source features. - target_features (list[torch.Tensor]): Target features. - maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the - max value from the target feature. Defaults to True. 
- """ - features_list = [] - theta = 1 - for source_feature, target_feature in zip(source_features, target_features, strict=True): - b, c, h, w = source_feature.shape - prior_flat = target_feature.view(b, c, -1) - if maximize: - prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0] - prior_flat = prior_flat - prior_flat_ - weights = functional.softmax(prior_flat, dim=-1) - weights = weights.view(b, c, h, w) - - global_inf = target_feature.mean(dim=(-2, -1), keepdim=True) - inter_weights = weights * (theta + global_inf) - source_feature_weighted = source_feature * inter_weights - features_list.append(source_feature_weighted) - return features_list diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index a27f802041..57e7ff262f 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -7,19 +7,20 @@ import torch from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch.nn import functional +from torch.optim.lr_scheduler import MultiStepLR from anomalib import LearningType -from anomalib.data import Batch, InferenceBatch +from anomalib.data import Batch +from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule from anomalib.models.components.backbone import get_decoder +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer -from .anomaly_map import weighted_decision_mechanism from .components import ( AttentionBottleneck, BottleneckLayer, - DomainRelatedFeatureSelection, - domain_related_feature_selection, ) from .loss import UniNetLoss from .torch_model import UniNetModel @@ -29,47 +30,49 @@ class UniNet(AnomalibModule): """UniNet model for anomaly detection. Args: - student_backbone (str): The backbone model to use for the student. - teacher_backbone (str): The backbone model to use for the teacher. + student_backbone (str): The backbone model to use for the student network. Defaults to "wide_resnet50_2". + teacher_backbone (str): The backbone model to use for the teacher network. Defaults to "wide_resnet50_2". + temperature (float): Temperature parameter used for contrastive loss. Controls the temperature of the student + and teacher similarity computation. Defaults to 0.1. + pre_processor (PreProcessor | bool, optional): Preprocessor instance or bool flag. Defaults to True. + post_processor (PostProcessor | bool, optional): Postprocessor instance or bool flag. Defaults to True. + evaluator (Evaluator | bool, optional): Evaluator instance or bool flag. Defaults to True. + visualizer (Visualizer | bool, optional): Visualizer instance or bool flag. Defaults to True. 
""" def __init__( self, student_backbone: str = "wide_resnet50_2", teacher_backbone: str = "wide_resnet50_2", + temperature: float = 0.1, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, ) -> None: - super().__init__() - self.dfs = DomainRelatedFeatureSelection() - self.loss = UniNetLoss(temperature=2.0) - self.bce_loss = torch.nn.BCEWithLogitsLoss() + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + self.loss = UniNetLoss(temperature=temperature) # base class expects self.loss in lightning module self.model = UniNetModel( student=get_decoder(student_backbone), bottleneck=BottleneckLayer(block=AttentionBottleneck, layers=3), teacher_backbone=teacher_backbone, + loss=self.loss, ) - self.automatic_optimization = False def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a training step of UniNet.""" del args, kwargs # These variables are not used. - assert batch.gt_label is not None, "Ground truth label is required for training" - - source_target_features, student_features, predictions = self.model(batch.image) - student_features = self._feature_selection(source_target_features, student_features) - loss = self._compute_loss( - student_features, - source_target_features, - predictions, - batch.gt_label.float(), - batch.gt_mask, - ) - optimizer_1, optimizer_2 = self.optimizers() - optimizer_1.zero_grad() - optimizer_2.zero_grad() - self.manual_backward(loss) - optimizer_1.step() - optimizer_2.step() + loss = self.model( + images=batch.image, + masks=batch.gt_mask, + labels=batch.gt_label, + ) self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) return {"loss": loss} @@ -81,22 +84,7 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: msg = "Expected batch.image to be a tensor, but got None or non-tensor type" raise ValueError(msg) - target_features, student_features, _ = self.model(batch.image) - output_list: list[torch.Tensor] = [] - for target_feature, student_feature in zip(target_features, student_features, strict=True): - output = 1 - functional.cosine_similarity(target_feature, student_feature) # B*64*64 - output_list.append(output) - - anomaly_score, anomaly_map = weighted_decision_mechanism( - batch_size=batch.image.shape[0], - output_list=output_list, - alpha=0.01, - beta=0.00003, - ) - predictions = InferenceBatch( - pred_score=anomaly_score, - anomaly_map=anomaly_map, - ) + predictions = self.model(batch.image) return batch.update(**predictions._asdict()) def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]: @@ -105,23 +93,24 @@ def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optim Returns: tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Optimizers for student and target teacher. 
""" - # TODO(ashwinvaidya17): refactor this - return ( - torch.optim.AdamW( - list(self.model.student.parameters()) - + list(self.model.bottleneck.parameters()) - + list(self.dfs.parameters()), - lr=5e-3, - betas=(0.9, 0.999), - weight_decay=1e-5, - ), - torch.optim.AdamW( - self.model.teachers.target_teacher.parameters(), - lr=1e-6, - betas=(0.9, 0.999), - weight_decay=1e-5, - ), + optimizer = torch.optim.AdamW( + [ + {"params": self.model.student.parameters()}, + {"params": self.model.bottleneck.parameters()}, + {"params": self.model.dfs.parameters()}, + {"params": self.model.teachers.target_teacher.parameters(), "lr": 1e-6}, + ], + lr=5e-3, + betas=(0.9, 0.999), + weight_decay=1e-5, + eps=1e-10, + amsgrad=True, ) + milestones = [ + int(self.trainer.max_steps * 0.8) if self.trainer.max_steps != -1 else self.trainer.max_epochs * 0.8, + ] + scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2) + return [optimizer], [scheduler] @property def trainer_arguments(self) -> dict[str, Any]: @@ -135,58 +124,3 @@ def learning_type(self) -> LearningType: Though technicaly it suppports multi-class and few_shot as well. """ return LearningType.ONE_CLASS - - def _feature_selection( - self, - source_features: list[torch.Tensor], - target_features: list[torch.Tensor], - maximize: bool = True, - ) -> list[torch.Tensor]: - """Feature selection. - - Args: - source_features (list[torch.Tensor]): Source features. - target_features (list[torch.Tensor]): Target features. - maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the - max value from the target feature. Defaults to True. - """ - if self.dfs is not None: - selected_features = self.dfs(source_features, target_features, maximize=maximize) - else: - # TODO(ashwinvaidya17): self.dfs will never be none with current implementation - selected_features = domain_related_feature_selection(source_features, target_features, maximize=maximize) - return selected_features - - def _compute_loss( - self, - student_features: list[torch.Tensor], - teacher_features: list[torch.Tensor], - predictions: tuple[torch.Tensor, torch.Tensor] | None = None, - label: float | None = None, - mask: torch.Tensor | None = None, - stop_gradient: bool = False, - ) -> torch.Tensor: - """Compute the loss. - - Args: - student_features (list[torch.Tensor]): Student features. - teacher_features (list[torch.Tensor]): Teacher features. - predictions (tuple[torch.Tensor, torch.Tensor] | None): Predictions is (B, B) - label (float | None): Label for the prediction. - mask (torch.Tensor | None): Mask for the prediction. Mask is of shape BxHxW - stop_gradient (bool): Whether to stop the gradient into teacher features. - - Returns: - torch.Tensor: Loss. 
- """ - if mask is not None: - mask_ = mask.unsqueeze(1) # Bx1xHxW - else: - assert label is not None, "Label is required when mask is not provided" - mask_ = label - - loss = self.loss(student_features, teacher_features, mask=mask_, stop_gradient=stop_gradient) - if predictions is not None and label is not None: - loss += self.bce_loss(predictions[0], label) + self.bce_loss(predictions[1], label) - - return loss diff --git a/src/anomalib/models/image/uninet/loss.py b/src/anomalib/models/image/uninet/loss.py index 248bab3d14..d2ddaa9d33 100644 --- a/src/anomalib/models/image/uninet/loss.py +++ b/src/anomalib/models/image/uninet/loss.py @@ -99,7 +99,7 @@ def forward( if abnormal_mask.sum() > 0: diag_sim_abnormal = diag_sum[abnormal_mask] - margin_loss_a = functional.relu(diag_sim_abnormal - margin).mean() + margin_loss_a = functional.relu(diag_sim_abnormal - margin / 2).mean() margin_loss = margin_loss_n + margin_loss_a diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py index b37d838406..b6c3e654aa 100644 --- a/src/anomalib/models/image/uninet/torch_model.py +++ b/src/anomalib/models/image/uninet/torch_model.py @@ -18,8 +18,14 @@ import torchvision from torch import nn from torch.fx import GraphModule +from torch.nn import functional as F # noqa: N812 from torchvision.models.feature_extraction import create_feature_extractor +from anomalib.data import InferenceBatch + +from .anomaly_map import weighted_decision_mechanism +from .components import DomainRelatedFeatureSelection + class UniNetModel(nn.Module): """UniNet PyTorch model. @@ -38,26 +44,35 @@ def __init__( student: nn.Module, bottleneck: nn.Module, teacher_backbone: str, + loss: nn.Module, ) -> None: super().__init__() - self.num_teachers = 2 # in the original code, there is an option to have only one teacher self.teachers = Teachers(teacher_backbone) self.student = student self.bottleneck = bottleneck + self.dfs = DomainRelatedFeatureSelection() + self.loss = loss + self.bce_loss = torch.nn.BCEWithLogitsLoss() # Used to post-process the student features from the de_resnet model to get the predictions self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(256, 1) - def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + def forward( + self, + images: torch.Tensor, + masks: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> torch.Tensor | InferenceBatch: """Forward pass of the UniNet model. Args: images (torch.Tensor): Input images. + masks (torch.Tensor | None): Ground truth masks. + labels (torch.Tensor | None): Ground truth labels. Returns: - tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: Source target features, - student features, and predictions. + torch.Tensor | InferenceBatch: Loss or InferenceBatch. 
""" source_target_features, bottleneck_inputs = self.teachers(images) bottleneck_outputs = self.bottleneck(bottleneck_inputs) @@ -69,6 +84,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tup predictions = self.avgpool(student_features[0]) predictions = torch.flatten(predictions, 1) predictions = self.fc(predictions).squeeze() + predictions = predictions.chunk(dim=0, chunks=2) student_features = [d.chunk(dim=0, chunks=2) for d in student_features] student_features = [ @@ -79,10 +95,82 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tup student_features[1][1], student_features[2][1], ] + if self.training: + student_features = self._feature_selection(source_target_features, student_features) + return self._compute_loss( + student_features, + source_target_features, + predictions, + labels, + masks, + ) + + output_list: list[torch.Tensor] = [] + for target_feature, student_feature in zip(source_target_features, student_features, strict=True): + output = 1 - F.cosine_similarity(target_feature, student_feature) # B*64*64 + output_list.append(output) + + anomaly_score, anomaly_map = weighted_decision_mechanism( + batch_size=images.shape[0], + output_list=output_list, + alpha=0.01, + beta=3e-05, + output_size=images.shape[-2:], + ) + return InferenceBatch( + pred_score=anomaly_score, + anomaly_map=anomaly_map, + ) + + def _compute_loss( + self, + student_features: list[torch.Tensor], + teacher_features: list[torch.Tensor], + predictions: tuple[torch.Tensor, torch.Tensor] | None = None, + label: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + stop_gradient: bool = False, + ) -> torch.Tensor: + """Compute the loss. - predictions = predictions.chunk(dim=0, chunks=2) + Args: + student_features (list[torch.Tensor]): Student features. + teacher_features (list[torch.Tensor]): Teacher features. + predictions (tuple[torch.Tensor, torch.Tensor] | None): Predictions is (B, B) + label (torch.Tensor | None): Label for the prediction. + mask (torch.Tensor | None): Mask for the prediction. Mask is of shape BxHxW + stop_gradient (bool): Whether to stop the gradient into teacher features. + + Returns: + torch.Tensor: Loss. + """ + if mask is not None: + mask_ = mask.float().unsqueeze(1) # Bx1xHxW + else: + assert label is not None, "Label is required when mask is not provided" + mask_ = label.float() + + loss = self.loss(student_features, teacher_features, mask=mask_, stop_gradient=stop_gradient) + if predictions is not None and label is not None: + loss += self.bce_loss(predictions[0], label.float()) + self.bce_loss(predictions[1], label.float()) + + return loss - return source_target_features, student_features, predictions + def _feature_selection( + self, + target_features: list[torch.Tensor], + source_features: list[torch.Tensor], + maximize: bool = True, + ) -> list[torch.Tensor]: + """Feature selection. + + Args: + source_features (list[torch.Tensor]): Source features. + target_features (list[torch.Tensor]): Target features. + maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the + max value from the target feature. Defaults to True. + """ + return self.dfs(source_features, target_features, maximize=maximize) class Teachers(nn.Module): @@ -123,7 +211,6 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor Returns: torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: Source features or source and target features. 
""" - # TODO(ashwinvaidya17): revisit this method and clean up multiple return statements with torch.no_grad(): source_features = self.source_teacher(images) From aee930d34b7a16a85902404d6ac89136c5ee6196 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 9 Jul 2025 16:13:11 +0200 Subject: [PATCH 05/12] Fix pre-commit Signed-off-by: Ashwin Vaidya --- src/anomalib/models/image/uninet/LICENSE | 2 +- src/anomalib/models/image/uninet/__init__.py | 3 ++- third-party-programs.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/anomalib/models/image/uninet/LICENSE b/src/anomalib/models/image/uninet/LICENSE index 26a7ae014a..5b0371a8cd 100644 --- a/src/anomalib/models/image/uninet/LICENSE +++ b/src/anomalib/models/image/uninet/LICENSE @@ -26,4 +26,4 @@ Original license: AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. \ No newline at end of file + SOFTWARE. diff --git a/src/anomalib/models/image/uninet/__init__.py b/src/anomalib/models/image/uninet/__init__.py index ba86f076d3..7d03336ad2 100644 --- a/src/anomalib/models/image/uninet/__init__.py +++ b/src/anomalib/models/image/uninet/__init__.py @@ -1,6 +1,7 @@ """UniNet Model for anomaly detection. -This module implements anomaly detection using the UniNet model. It is a model designed for diverse domains and is suited for both supervised and unsupervised anomaly detection. It also focuses on multi-class anomaly detection. +This module implements anomaly detection using the UniNet model. It is a model designed for diverse domains and is +suited for both supervised and unsupervised anomaly detection. It also focuses on multi-class anomaly detection. Example: >>> from anomalib.models import UniNet diff --git a/third-party-programs.txt b/third-party-programs.txt index bc11567182..4f71dbc247 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -49,4 +49,4 @@ terms are listed below. 9. 
UninetModel Copyright (c) 2025 @pangdatangtt, https://github.com/pangdatangtt/UniNet/ - SPDX-License-Identifier: MIT \ No newline at end of file + SPDX-License-Identifier: MIT From 6f40e3157396adec372f6267538bb6c675995133 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Mon, 14 Jul 2025 17:08:40 +0200 Subject: [PATCH 06/12] Minor refactor Signed-off-by: Ashwin Vaidya --- src/anomalib/models/image/uninet/anomaly_map.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/anomaly_map.py index c48da7f8fb..82ee783127 100644 --- a/src/anomalib/models/image/uninet/anomaly_map.py +++ b/src/anomalib/models/image/uninet/anomaly_map.py @@ -46,25 +46,20 @@ def weighted_decision_mechanism( assert len(total_weights_list) == batch_size, "The number of weights do not match the number of samples" - anomaly_map_lists: list[list[torch.Tensor]] = [[] for _ in output_list] + anomaly_maps: list[list[torch.Tensor]] = [[] for _ in output_list] for idx, output in enumerate(output_list): anomaly_map_ = torch.unsqueeze(output, dim=1) # Bx1xhxw # Bx256x256 - anomaly_map_lists[idx] = F.interpolate(anomaly_map_, output_size, mode="bilinear", align_corners=True)[ - :, - 0, - :, - :, - ] + anomaly_maps[idx] = F.interpolate(anomaly_map_, output_size, mode="bilinear", align_corners=True)[:, 0, :, :] - anomaly_map: torch.Tensor = sum(anomaly_map_lists) + processed_anomaly_maps: torch.Tensor = sum(anomaly_maps) anomaly_score = [] for idx in range(batch_size): top_k = int(output_size[0] * output_size[1] * total_weights_list[idx]) assert top_k >= 1 / (output_size[0] * output_size[1]), "weight can not be smaller than 1 / (H * W)!" - single_anomaly_score_exp = anomaly_map[idx] + single_anomaly_score_exp = processed_anomaly_maps[idx] single_anomaly_score_exp = gaussian_blur2d( einops.rearrange(single_anomaly_score_exp, "h w -> 1 1 h w"), # kornia expects 4D tensor kernel_size=(5, 5), @@ -76,4 +71,4 @@ def weighted_decision_mechanism( single_map = single_anomaly_score_exp.reshape(1, -1) single_anomaly_score = single_map.topk(top_k).values[0][0] anomaly_score.append(single_anomaly_score.detach()) - return torch.vstack(anomaly_score), anomaly_map.detach() + return torch.vstack(anomaly_score), processed_anomaly_maps.detach() From 55c6bad1c47f2439f06de95905ccefa7a3f58515 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 16 Jul 2025 13:17:58 +0200 Subject: [PATCH 07/12] Refactor components Signed-off-by: Ashwin Vaidya --- src/anomalib/models/image/uninet/components/__init__.py | 4 ++++ .../models/image/uninet/{ => components}/anomaly_map.py | 0 src/anomalib/models/image/uninet/{ => components}/loss.py | 0 src/anomalib/models/image/uninet/lightning_model.py | 2 +- src/anomalib/models/image/uninet/torch_model.py | 3 +-- 5 files changed, 6 insertions(+), 3 deletions(-) rename src/anomalib/models/image/uninet/{ => components}/anomaly_map.py (100%) rename src/anomalib/models/image/uninet/{ => components}/loss.py (100%) diff --git a/src/anomalib/models/image/uninet/components/__init__.py b/src/anomalib/models/image/uninet/components/__init__.py index cb8bb50694..d1a754efad 100644 --- a/src/anomalib/models/image/uninet/components/__init__.py +++ b/src/anomalib/models/image/uninet/components/__init__.py @@ -3,11 +3,15 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from .anomaly_map import weighted_decision_mechanism from .attention_bottleneck import AttentionBottleneck, BottleneckLayer from 
.dfs import DomainRelatedFeatureSelection +from .loss import UniNetLoss __all__ = [ + "UniNetLoss", "DomainRelatedFeatureSelection", "AttentionBottleneck", "BottleneckLayer", + "weighted_decision_mechanism", ] diff --git a/src/anomalib/models/image/uninet/anomaly_map.py b/src/anomalib/models/image/uninet/components/anomaly_map.py similarity index 100% rename from src/anomalib/models/image/uninet/anomaly_map.py rename to src/anomalib/models/image/uninet/components/anomaly_map.py diff --git a/src/anomalib/models/image/uninet/loss.py b/src/anomalib/models/image/uninet/components/loss.py similarity index 100% rename from src/anomalib/models/image/uninet/loss.py rename to src/anomalib/models/image/uninet/components/loss.py diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index 57e7ff262f..2147aa3d97 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -21,8 +21,8 @@ from .components import ( AttentionBottleneck, BottleneckLayer, + UniNetLoss, ) -from .loss import UniNetLoss from .torch_model import UniNetModel diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py index b6c3e654aa..61ffc76a70 100644 --- a/src/anomalib/models/image/uninet/torch_model.py +++ b/src/anomalib/models/image/uninet/torch_model.py @@ -23,8 +23,7 @@ from anomalib.data import InferenceBatch -from .anomaly_map import weighted_decision_mechanism -from .components import DomainRelatedFeatureSelection +from .components import DomainRelatedFeatureSelection, weighted_decision_mechanism class UniNetModel(nn.Module): From 38041b4dfd19419c0f5124593a004dba7517da56 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 16 Jul 2025 13:20:11 +0200 Subject: [PATCH 08/12] single line Signed-off-by: Ashwin Vaidya --- src/anomalib/models/image/uninet/lightning_model.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index 2147aa3d97..591175a1fb 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -68,11 +68,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: """Perform a training step of UniNet.""" del args, kwargs # These variables are not used. 
- loss = self.model( - images=batch.image, - masks=batch.gt_mask, - labels=batch.gt_label, - ) + loss = self.model(images=batch.image, masks=batch.gt_mask, labels=batch.gt_label) self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) return {"loss": loss} From 7ff16ffac8b382a1eefd0be9cf944034cdfd4ae4 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 16 Jul 2025 15:54:07 +0200 Subject: [PATCH 09/12] de_resnet->resnet_decoder Signed-off-by: Ashwin Vaidya --- .../{de_resnet.py => resnet_decoder.py} | 42 +++++++++---------- .../models/image/uninet/lightning_model.py | 5 +-- 2 files changed, 22 insertions(+), 25 deletions(-) rename src/anomalib/models/components/backbone/{de_resnet.py => resnet_decoder.py} (94%) diff --git a/src/anomalib/models/components/backbone/de_resnet.py b/src/anomalib/models/components/backbone/resnet_decoder.py similarity index 94% rename from src/anomalib/models/components/backbone/de_resnet.py rename to src/anomalib/models/components/backbone/resnet_decoder.py index d56943c2c8..878f4f8327 100644 --- a/src/anomalib/models/components/backbone/de_resnet.py +++ b/src/anomalib/models/components/backbone/resnet_decoder.py @@ -19,7 +19,7 @@ - Full decoder network architecture Example: - >>> from anomalib.models.components.backbone.de_resnet import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... get_decoder ... ) >>> decoder = get_decoder() @@ -27,7 +27,7 @@ >>> reconstructed = decoder(features) See Also: - - :class:`anomalib.models.components.backbone.de_resnet.DecoderBasicBlock`: + - :class:`anomalib.models.components.backbone.resnet_decoder.DecoderBasicBlock`: Basic building block for the decoder network """ @@ -155,8 +155,8 @@ class DecoderBottleneck(nn.Module): """Bottleneck block for the decoder network. This module implements a bottleneck block used in the decoder part of the Reverse - Distillation model. It performs upsampling and feature reconstruction through a series of - convolutional layers. + Distillation and UniNet models. It performs upsampling and feature reconstruction + through a series of convolutional layers. The block consists of three convolution layers: 1. 1x1 conv to adjust channels @@ -184,7 +184,7 @@ class DecoderBottleneck(nn.Module): Example: >>> import torch - >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... DecoderBottleneck ... ) >>> layer = DecoderBottleneck(256, 64) @@ -267,7 +267,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: return self.relu(out) -class ResNet(nn.Module): +class ResNetDecoder(nn.Module): """Decoder ResNet model for feature reconstruction. This module implements a decoder version of the ResNet architecture, which @@ -295,11 +295,11 @@ class ResNet(nn.Module): layer to use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``. Example: - >>> from anomalib.models.image.reverse_distillation.components import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... DecoderBasicBlock, - ... ResNet + ... ResNetDecoder ... ) - >>> model = ResNet( + >>> model = ResNetDecoder( ... block=DecoderBasicBlock, ... layers=[2, 2, 2, 2] ... 
) @@ -435,56 +435,56 @@ def forward(self, batch: torch.Tensor) -> list[torch.Tensor]: return [feature_c, feature_b, feature_a] -def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNet: - return ResNet(block, layers, **kwargs) +def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNetDecoder: + return ResNetDecoder(block, layers, **kwargs) -def de_resnet18() -> ResNet: +def de_resnet18() -> ResNetDecoder: """ResNet-18 model.""" return _resnet(DecoderBasicBlock, [2, 2, 2, 2]) -def de_resnet34() -> ResNet: +def de_resnet34() -> ResNetDecoder: """ResNet-34 model.""" return _resnet(DecoderBasicBlock, [3, 4, 6, 3]) -def de_resnet50() -> ResNet: +def de_resnet50() -> ResNetDecoder: """ResNet-50 model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3]) -def de_resnet101() -> ResNet: +def de_resnet101() -> ResNetDecoder: """ResNet-101 model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3]) -def de_resnet152() -> ResNet: +def de_resnet152() -> ResNetDecoder: """ResNet-152 model.""" return _resnet(DecoderBottleneck, [3, 8, 36, 3]) -def de_resnext50_32x4d() -> ResNet: +def de_resnext50_32x4d() -> ResNetDecoder: """ResNeXt-50 32x4d model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3], groups=32, width_per_group=4) -def de_resnext101_32x8d() -> ResNet: +def de_resnext101_32x8d() -> ResNetDecoder: """ResNeXt-101 32x8d model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3], groups=32, width_per_group=8) -def de_wide_resnet50_2() -> ResNet: +def de_wide_resnet50_2() -> ResNetDecoder: """Wide ResNet-50-2 model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3], width_per_group=128) -def de_wide_resnet101_2() -> ResNet: +def de_wide_resnet101_2() -> ResNetDecoder: """Wide ResNet-101-2 model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3], width_per_group=128) -def get_decoder(name: str) -> ResNet: +def get_decoder(name: str) -> ResNetDecoder: """Get decoder model based on the name of the backbone. Args: diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index 591175a1fb..8317e16ced 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -115,8 +115,5 @@ def trainer_arguments(self) -> dict[str, Any]: @property def learning_type(self) -> LearningType: - """The model uses one-class learning. - - Though technically it supports multi-class and few_shot as well.
- """ + """The model uses one-class learning.""" return LearningType.ONE_CLASS From 9e481f8732e2cae244516c24eebe62c58259dd65 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 16 Jul 2025 16:44:00 +0200 Subject: [PATCH 10/12] bug fixes Signed-off-by: Ashwin Vaidya --- src/anomalib/models/components/backbone/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/components/backbone/__init__.py b/src/anomalib/models/components/backbone/__init__.py index 29778f9ff6..e5afe4e026 100644 --- a/src/anomalib/models/components/backbone/__init__.py +++ b/src/anomalib/models/components/backbone/__init__.py @@ -12,6 +12,6 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from .de_resnet import get_decoder +from .resnet_decoder import get_decoder __all__ = ["get_decoder"] From 27458611c3333af9e5b5394937620664ee7bf0a6 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 17 Jul 2025 14:35:32 +0200 Subject: [PATCH 11/12] openvino export + pr comments Signed-off-by: Ashwin Vaidya --- .../components/backbone/resnet_decoder.py | 2 +- .../image/uninet/components/anomaly_map.py | 95 ++++++++++++------- .../models/image/uninet/lightning_model.py | 2 +- 3 files changed, 64 insertions(+), 35 deletions(-) diff --git a/src/anomalib/models/components/backbone/resnet_decoder.py b/src/anomalib/models/components/backbone/resnet_decoder.py index 878f4f8327..83ead7a859 100644 --- a/src/anomalib/models/components/backbone/resnet_decoder.py +++ b/src/anomalib/models/components/backbone/resnet_decoder.py @@ -491,7 +491,7 @@ def get_decoder(name: str) -> ResNetDecoder: name (str): Name of the backbone. Returns: - ResNet: Decoder ResNet architecture. + ResNetDecoder: Decoder ResNet architecture. """ decoder_map = { "resnet18": de_resnet18, diff --git a/src/anomalib/models/image/uninet/components/anomaly_map.py b/src/anomalib/models/image/uninet/components/anomaly_map.py index 82ee783127..a0ee1bbfaf 100644 --- a/src/anomalib/models/image/uninet/components/anomaly_map.py +++ b/src/anomalib/models/image/uninet/components/anomaly_map.py @@ -11,13 +11,14 @@ import einops import torch -from kornia.filters import gaussian_blur2d from torch.nn import functional as F # noqa: N812 +from anomalib.models.components.filters import GaussianBlur2d + def weighted_decision_mechanism( batch_size: int, - output_list: list[list[torch.Tensor]], + output_list: list[torch.Tensor], alpha: float, beta: float, output_size: tuple[int, int] = (256, 256), @@ -26,7 +27,7 @@ def weighted_decision_mechanism( Args: batch_size (int): Batch size. - output_list (list[list[torch.Tensor]]): List of output tensors. + output_list (list[torch.Tensor]): List of output tensors, each with shape [batch_size, H, W]. alpha (float): Alpha parameter. Used for controlling the upper limit beta (float): Beta parameter. Used for controlling the lower limit output_size (tuple[int, int]): Output size. @@ -34,41 +35,69 @@ def weighted_decision_mechanism( Returns: tuple[torch.Tensor, torch.Tensor]: Anomaly score and anomaly map. 
""" - total_weights_list = [] + # Convert to tensor operations to avoid SequenceConstruct/ConcatFromSequence + device = output_list[0].device + num_outputs = len(output_list) + + # Pre-allocate tensors instead of using lists + total_weights = torch.zeros(batch_size, device=device) + gaussian_blur = GaussianBlur2d(sigma=4.0, kernel_size=(5, 5), channels=1).to(device) + + # Process each batch item individually for i in range(batch_size): - low_similarity_list = [torch.max(output_list[j][i]) for j in range(len(output_list))] - probs = F.softmax(torch.tensor(low_similarity_list), dim=0) - weight_list = [] # set P consists of L high probability values, where L ranges from n-1 to n+1 - for idx, prob in enumerate(probs): - weight_list.append(low_similarity_list[idx]) if prob > torch.mean(probs) else None - weight = torch.max(torch.tensor([torch.mean(torch.tensor(weight_list)) * alpha, beta])) - total_weights_list.append(weight) + # Get max value from each output for this batch item + # Create tensor directly from max values to avoid list operations + max_values = torch.zeros(num_outputs, device=device) + for j, output_tensor in enumerate(output_list): + max_values[j] = torch.max(output_tensor[i]) + + probs = F.softmax(max_values, dim=0) + + # Use tensor operations instead of list filtering + prob_mean = torch.mean(probs) + mask = probs > prob_mean - assert len(total_weights_list) == batch_size, "The number of weights do not match the number of samples" + if mask.any(): + weight_tensor = max_values[mask] + weight = torch.max(torch.stack([torch.mean(weight_tensor) * alpha, torch.tensor(beta, device=device)])) + else: + weight = torch.tensor(beta, device=device) - anomaly_maps: list[list[torch.Tensor]] = [[] for _ in output_list] - for idx, output in enumerate(output_list): - anomaly_map_ = torch.unsqueeze(output, dim=1) # Bx1xhxw - # Bx256x256 - anomaly_maps[idx] = F.interpolate(anomaly_map_, output_size, mode="bilinear", align_corners=True)[:, 0, :, :] + total_weights[i] = weight - processed_anomaly_maps: torch.Tensor = sum(anomaly_maps) + # Process anomaly maps using tensor operations + # Pre-allocate the processed anomaly maps tensor + processed_anomaly_maps = torch.zeros(batch_size, *output_size, device=device) + + # Process each output tensor separately due to different spatial dimensions + for output_tensor in output_list: + # Interpolate current output to target size + # Add channel dimension for interpolation: [batch_size, H, W] -> [batch_size, 1, H, W] + output_resized = F.interpolate( + output_tensor.unsqueeze(1), + output_size, + mode="bilinear", + align_corners=True, + ).squeeze(1) # [batch_size, H_out, W_out] + + # Add to accumulated anomaly maps + processed_anomaly_maps += output_resized + + # Pre-allocate anomaly scores tensor instead of using list + anomaly_scores = torch.zeros(batch_size, device=device) - anomaly_score = [] for idx in range(batch_size): - top_k = int(output_size[0] * output_size[1] * total_weights_list[idx]) - assert top_k >= 1 / (output_size[0] * output_size[1]), "weight can not be smaller than 1 / (H * W)!" 
+ top_k = int(output_size[0] * output_size[1] * total_weights[idx]) + top_k = max(top_k, 1) # Ensure at least 1 element single_anomaly_score_exp = processed_anomaly_maps[idx] - single_anomaly_score_exp = gaussian_blur2d( - einops.rearrange(single_anomaly_score_exp, "h w -> 1 1 h w"), # kornia expects 4D tensor - kernel_size=(5, 5), - sigma=(4, 4), - ).squeeze() - assert single_anomaly_score_exp.reshape(1, -1).shape[-1] == output_size[0] * output_size[1], ( - "something wrong with the last dimension of reshaped map!" - ) - single_map = single_anomaly_score_exp.reshape(1, -1) - single_anomaly_score = single_map.topk(top_k).values[0][0] - anomaly_score.append(single_anomaly_score.detach()) - return torch.vstack(anomaly_score), processed_anomaly_maps.detach() + single_anomaly_score_exp = gaussian_blur(einops.rearrange(single_anomaly_score_exp, "h w -> 1 1 h w")) + single_anomaly_score_exp = single_anomaly_score_exp.squeeze() + + # Flatten and get top-k values + single_map_flat = single_anomaly_score_exp.view(-1) + top_k_values = torch.topk(single_map_flat, top_k).values + single_anomaly_score = top_k_values[0] if len(top_k_values) > 0 else torch.tensor(0.0, device=device) + anomaly_scores[idx] = single_anomaly_score.detach() + + return anomaly_scores.unsqueeze(1), processed_anomaly_maps.detach() diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py index 8317e16ced..3cb7f0579c 100644 --- a/src/anomalib/models/image/uninet/lightning_model.py +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -103,7 +103,7 @@ def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optim amsgrad=True, ) milestones = [ - int(self.trainer.max_steps * 0.8) if self.trainer.max_steps != -1 else self.trainer.max_epochs * 0.8, + int(self.trainer.max_steps * 0.8) if self.trainer.max_steps != -1 else (self.trainer.max_epochs * 0.8), ] scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2) return [optimizer], [scheduler] From 348b7eee09fccdfa939f8c934930981cdf72dc22 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 17 Jul 2025 14:51:32 +0200 Subject: [PATCH 12/12] Add documentation Signed-off-by: Ashwin Vaidya --- .../guides/reference/models/image/uninet.md | 37 +++++++++++ src/anomalib/models/image/uninet/README.md | 65 +++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 docs/source/markdown/guides/reference/models/image/uninet.md create mode 100644 src/anomalib/models/image/uninet/README.md diff --git a/docs/source/markdown/guides/reference/models/image/uninet.md b/docs/source/markdown/guides/reference/models/image/uninet.md new file mode 100644 index 0000000000..c46bf0a665 --- /dev/null +++ b/docs/source/markdown/guides/reference/models/image/uninet.md @@ -0,0 +1,37 @@ +# UniNet + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.lightning_model + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.torch_model + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.loss + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.anomaly_map + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.attention_bottleneck + :members: + :show-inheritance: +``` + +```{eval-rst} +.. 
+.. automodule:: anomalib.models.image.uninet.components.dfs
+    :members:
+    :show-inheritance:
+```

diff --git a/src/anomalib/models/image/uninet/README.md b/src/anomalib/models/image/uninet/README.md
new file mode 100644
index 0000000000..dbf11256c9
--- /dev/null
+++ b/src/anomalib/models/image/uninet/README.md
@@ -0,0 +1,65 @@
+# UniNet: Unified Contrastive Learning Framework for Anomaly Detection
+
+This is the implementation of the UniNet model, a unified contrastive learning framework for anomaly detection presented at CVPR 2025.
+
+Model Type: Classification and Segmentation
+
+## Description
+
+UniNet is a contrastive learning-based anomaly detection model that uses a teacher-student architecture with attention bottleneck mechanisms. The model is designed for diverse domains and supports both supervised and unsupervised anomaly detection scenarios. It focuses on multi-class anomaly detection and leverages domain-related feature selection to improve performance across different categories.
+
+The model consists of:
+
+- **Teacher Networks**: Pre-trained backbone networks that provide reference features for normal samples
+- **Student Network**: A decoder network that learns to reconstruct normal patterns
+- **Attention Bottleneck**: Mechanisms that help focus on relevant features
+- **Domain-Related Feature Selection**: Adaptive feature selection for different domains
+- **Contrastive Loss**: Temperature-controlled similarity computation between student and teacher features
+
+During training, the student network learns to match teacher features for normal samples while being trained to distinguish anomalous patterns through contrastive learning. During inference, the model uses a weighted decision mechanism to combine multi-scale features for final anomaly scoring; a sketch of that mechanism follows.
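+
+The snippet below is an illustrative, self-contained sketch of that scoring rule for a single sample (the helper name `image_score` and the tensor shapes are assumptions for the example; the actual implementation lives in `components/anomaly_map.py`). The per-scale maximum similarities are softmaxed, the scales above the mean probability are averaged and scaled by `alpha`, floored at `beta`, and the resulting fraction selects how many of the top pixels of the fused anomaly map define the image score.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def image_score(scale_maps: list[torch.Tensor], fused_map: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
+    """Illustrative weighted-decision score for one sample (not the exact API)."""
+    # Highest response per scale, softmaxed over the scales
+    max_values = torch.stack([m.max() for m in scale_maps])
+    probs = F.softmax(max_values, dim=0)
+
+    # Average the high-probability scales, scale by alpha, floor at beta
+    mask = probs > probs.mean()
+    weight = max(float(max_values[mask].mean()) * alpha, beta) if mask.any() else beta
+
+    # The weight sets the fraction of top pixels considered; the real
+    # implementation also Gaussian-blurs the fused map before ranking
+    top_k = max(int(fused_map.numel() * weight), 1)
+    return torch.topk(fused_map.view(-1), top_k).values[0]
+```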
+
+## Architecture
+
+The UniNet architecture leverages contrastive learning with teacher-student networks, incorporating attention bottlenecks and domain-specific feature selection for robust anomaly detection across diverse domains.
+
+## Usage
+
+`anomalib train --model UniNet --data MVTecAD --data.category <category>`
+
+## Benchmark
+
+All results gathered with seed `42`.
+
+## [MVTecAD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)
+
+### Image-Level AUC
+
+|  Avg  | Bottle | Cable | Capsule | Carpet | Grid  | Hazelnut | Leather | Metal Nut | Pill  | Screw | Tile  | Toothbrush | Transistor | Wood  | Zipper |
+| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
+| 0.956 | 0.999  | 0.982 |  0.939  | 0.896  | 0.996 |  0.999   |  1.000  |   1.000   | 0.816 | 0.919 | 0.970 |   1.000    |   0.984    | 0.993 | 0.945  |
+
+### Pixel-Level AUC
+
+|  Avg  | Bottle | Cable | Capsule | Carpet | Grid  | Hazelnut | Leather | Metal Nut | Pill  | Screw | Tile  | Toothbrush | Transistor | Wood  | Zipper |
+| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
+| 0.976 | 0.989  | 0.983 |  0.985  | 0.973  | 0.992 |  0.987   |  0.993  |   0.984   | 0.964 | 0.992 | 0.965 |   0.992    |   0.923    | 0.961 | 0.984  |
+
+### Image F1 Score
+
+|  Avg  | Bottle | Cable | Capsule | Carpet | Grid  | Hazelnut | Leather | Metal Nut | Pill  | Screw | Tile  | Toothbrush | Transistor | Wood  | Zipper |
+| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
+| 0.957 | 0.984  | 0.944 |  0.964  | 0.883  | 0.973 |  0.986   |  0.995  |   0.989   | 0.921 | 0.905 | 0.946 |   0.983    |   0.961    | 0.974 | 0.959  |
+
+## Model Features
+
+- **Contrastive Learning**: Uses a temperature-controlled contrastive loss for effective anomaly detection
+- **Multi-Domain Support**: Designed to work across diverse domains with domain-related feature selection
+- **Flexible Training**: Supports both supervised and unsupervised training modes
+- **Attention Mechanisms**: Incorporates attention bottlenecks for focused feature learning
+- **Multi-Scale Features**: Leverages multi-scale feature matching for robust detection
+
+## Parameters
+
+- `student_backbone`: Backbone model for the student network (default: `"wide_resnet50_2"`)
+- `teacher_backbone`: Backbone model for the teacher network (default: `"wide_resnet50_2"`)
+- `temperature`: Temperature parameter for the contrastive loss (default: `0.1`)
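+
+As a quick reference, these parameters map onto the Python API roughly as follows. This is a minimal sketch: the keyword names are taken from the list above, so verify them against the class signature before relying on it.
+
+```python
+from anomalib.models import UniNet
+
+# Defaults written out explicitly; the three keywords mirror the
+# parameter list above (treat the exact names as an assumption).
+model = UniNet(
+    student_backbone="wide_resnet50_2",
+    teacher_backbone="wide_resnet50_2",
+    temperature=0.1,
+)
+```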