diff --git a/docs/source/markdown/guides/reference/models/image/uninet.md b/docs/source/markdown/guides/reference/models/image/uninet.md new file mode 100644 index 0000000000..c46bf0a665 --- /dev/null +++ b/docs/source/markdown/guides/reference/models/image/uninet.md @@ -0,0 +1,37 @@ +# UniNet + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.lightning_model + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.torch_model + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.loss + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.anomaly_map + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.attention_bottleneck + :members: + :show-inheritance: +``` + +```{eval-rst} +.. automodule:: anomalib.models.image.uninet.components.dfs + :members: + :show-inheritance: +``` diff --git a/examples/configs/model/uninet.yaml b/examples/configs/model/uninet.yaml new file mode 100644 index 0000000000..fb625a06c4 --- /dev/null +++ b/examples/configs/model/uninet.yaml @@ -0,0 +1,15 @@ +model: + class_path: anomalib.models.UniNet + init_args: + student_backbone: wide_resnet50_2 + teacher_backbone: wide_resnet50_2 + temperature: 0.1 + +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 20 + monitor: image_AUROC + mode: max diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 43509e3ef0..3443285d7b 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -72,6 +72,7 @@ Stfpm, Supersimplenet, Uflow, + UniNet, VlmAd, WinClip, ) @@ -108,6 +109,7 @@ class UnknownModelError(ModuleNotFoundError): "Stfpm", "Supersimplenet", "Uflow", + "UniNet", "VlmAd", "WinClip", "AiVad", diff --git a/src/anomalib/models/components/backbone/__init__.py b/src/anomalib/models/components/backbone/__init__.py new file mode 100644 index 0000000000..e5afe4e026 --- /dev/null +++ b/src/anomalib/models/components/backbone/__init__.py @@ -0,0 +1,17 @@ +"""Common backbone models. + +Example: + >>> from anomalib.models.components.backbone import get_decoder + >>> decoder = get_decoder() + +See Also: + - :func:`anomalib.models.components.backbone.de_resnet`: + Decoder network implementation +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .resnet_decoder import get_decoder + +__all__ = ["get_decoder"] diff --git a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py b/src/anomalib/models/components/backbone/resnet_decoder.py similarity index 93% rename from src/anomalib/models/image/reverse_distillation/components/de_resnet.py rename to src/anomalib/models/components/backbone/resnet_decoder.py index 127dee710c..83ead7a859 100644 --- a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py +++ b/src/anomalib/models/components/backbone/resnet_decoder.py @@ -7,9 +7,9 @@ # Copyright (C) 2022-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""PyTorch model defining the decoder network for Reverse Distillation. +"""PyTorch model defining the decoder network for Reverse Distillation and UniNet. -This module implements the decoder network used in the Reverse Distillation model +This module implements the decoder network used in the Reverse Distillation and UniNet models architecture. 
The decoder reconstructs features from the bottleneck representation back to the original feature space. @@ -19,7 +19,7 @@ - Full decoder network architecture Example: - >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... get_decoder ... ) >>> decoder = get_decoder() @@ -27,9 +27,7 @@ >>> reconstructed = decoder(features) See Also: - - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`: - Main model implementation using this decoder - - :class:`anomalib.models.image.reverse_distillation.components.DecoderBasicBlock`: + - :class:`anomalib.models.components.backbone.resnet_decoder.DecoderBasicBlock`: Basic building block for the decoder network """ @@ -157,8 +155,8 @@ class DecoderBottleneck(nn.Module): """Bottleneck block for the decoder network. This module implements a bottleneck block used in the decoder part of the Reverse - Distillation model. It performs upsampling and feature reconstruction through a series of - convolutional layers. + Distillation and UniNet models. It performs upsampling and feature reconstruction + through a series of convolutional layers. The block consists of three convolution layers: 1. 1x1 conv to adjust channels @@ -186,7 +184,7 @@ class DecoderBottleneck(nn.Module): Example: >>> import torch - >>> from anomalib.models.image.reverse_distillation.components.de_resnet import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... DecoderBottleneck ... ) >>> layer = DecoderBottleneck(256, 64) @@ -269,7 +267,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: return self.relu(out) -class ResNet(nn.Module): +class ResNetDecoder(nn.Module): """Decoder ResNet model for feature reconstruction. This module implements a decoder version of the ResNet architecture, which @@ -297,11 +295,11 @@ class ResNet(nn.Module): layer to use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``. Example: - >>> from anomalib.models.image.reverse_distillation.components import ( + >>> from anomalib.models.components.backbone.resnet_decoder import ( ... DecoderBasicBlock, - ... ResNet + ... ResNetDecoder ... ) - >>> model = ResNet( + >>> model = ResNetDecoder( ... block=DecoderBasicBlock, ... layers=[2, 2, 2, 2] ... 
) @@ -437,63 +435,63 @@ def forward(self, batch: torch.Tensor) -> list[torch.Tensor]: return [feature_c, feature_b, feature_a] -def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNet: - return ResNet(block, layers, **kwargs) +def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNetDecoder: + return ResNetDecoder(block, layers, **kwargs) -def de_resnet18() -> ResNet: +def de_resnet18() -> ResNetDecoder: """ResNet-18 model.""" return _resnet(DecoderBasicBlock, [2, 2, 2, 2]) -def de_resnet34() -> ResNet: +def de_resnet34() -> ResNetDecoder: """ResNet-34 model.""" return _resnet(DecoderBasicBlock, [3, 4, 6, 3]) -def de_resnet50() -> ResNet: +def de_resnet50() -> ResNetDecoder: """ResNet-50 model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3]) -def de_resnet101() -> ResNet: +def de_resnet101() -> ResNetDecoder: """ResNet-101 model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3]) -def de_resnet152() -> ResNet: +def de_resnet152() -> ResNetDecoder: """ResNet-152 model.""" return _resnet(DecoderBottleneck, [3, 8, 36, 3]) -def de_resnext50_32x4d() -> ResNet: +def de_resnext50_32x4d() -> ResNetDecoder: """ResNeXt-50 32x4d model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3], groups=32, width_per_group=4) -def de_resnext101_32x8d() -> ResNet: +def de_resnext101_32x8d() -> ResNetDecoder: """ResNeXt-101 32x8d model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3], groups=32, width_per_group=8) -def de_wide_resnet50_2() -> ResNet: +def de_wide_resnet50_2() -> ResNetDecoder: """Wide ResNet-50-2 model.""" return _resnet(DecoderBottleneck, [3, 4, 6, 3], width_per_group=128) -def de_wide_resnet101_2() -> ResNet: +def de_wide_resnet101_2() -> ResNetDecoder: """Wide ResNet-101-2 model.""" return _resnet(DecoderBottleneck, [3, 4, 23, 3], width_per_group=128) -def get_decoder(name: str) -> ResNet: +def get_decoder(name: str) -> ResNetDecoder: """Get decoder model based on the name of the backbone. Args: name (str): Name of the backbone. Returns: - ResNet: Decoder ResNet architecture. + ResNetDecoder: Decoder ResNet architecture. 
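+
+    Example:
+        A minimal sketch of selecting a decoder by backbone name, assuming the
+        name is one of the keys in ``decoder_map`` below (``wide_resnet50_2``
+        is the backbone used by the UniNet defaults):
+
+        >>> from anomalib.models.components.backbone import get_decoder
+        >>> decoder = get_decoder("wide_resnet50_2")
+        >>> isinstance(decoder, ResNetDecoder)
+        True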
""" decoder_map = { "resnet18": de_resnet18, diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 9dec63dd09..09460f2c78 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -61,6 +61,7 @@ from .stfpm import Stfpm from .supersimplenet import Supersimplenet from .uflow import Uflow +from .uninet import UniNet from .vlm_ad import VlmAd from .winclip import WinClip @@ -82,6 +83,7 @@ "Stfpm", "Supersimplenet", "Uflow", + "UniNet", "VlmAd", "WinClip", ] diff --git a/src/anomalib/models/image/reverse_distillation/components/__init__.py b/src/anomalib/models/image/reverse_distillation/components/__init__.py index 8ef34a687a..e24fd734b9 100644 --- a/src/anomalib/models/image/reverse_distillation/components/__init__.py +++ b/src/anomalib/models/image/reverse_distillation/components/__init__.py @@ -10,7 +10,6 @@ through distillation and reconstruction: - Bottleneck layer: Compresses features into a lower dimensional space -- Decoder network: Reconstructs features from the bottleneck representation Example: >>> from anomalib.models.image.reverse_distillation.components import ( @@ -23,11 +22,8 @@ See Also: - :func:`anomalib.models.image.reverse_distillation.components.bottleneck`: Bottleneck layer implementation - - :func:`anomalib.models.image.reverse_distillation.components.de_resnet`: - Decoder network implementation """ from .bottleneck import get_bottleneck_layer -from .de_resnet import get_decoder -__all__ = ["get_bottleneck_layer", "get_decoder"] +__all__ = ["get_bottleneck_layer"] diff --git a/src/anomalib/models/image/reverse_distillation/torch_model.py b/src/anomalib/models/image/reverse_distillation/torch_model.py index 264ae4b93d..31cf3fdcd8 100644 --- a/src/anomalib/models/image/reverse_distillation/torch_model.py +++ b/src/anomalib/models/image/reverse_distillation/torch_model.py @@ -39,9 +39,10 @@ from anomalib.data import InferenceBatch from anomalib.models.components import TimmFeatureExtractor +from anomalib.models.components.backbone import get_decoder from .anomaly_map import AnomalyMapGenerationMode, AnomalyMapGenerator -from .components import get_bottleneck_layer, get_decoder +from .components import get_bottleneck_layer if TYPE_CHECKING: from anomalib.data.utils.tiler import Tiler diff --git a/src/anomalib/models/image/uninet/LICENSE b/src/anomalib/models/image/uninet/LICENSE new file mode 100644 index 0000000000..5b0371a8cd --- /dev/null +++ b/src/anomalib/models/image/uninet/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2025 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original UniNet implementation by Shun Wei, Jielin Jiang, and Xiaolong Xu. + +Original license: +----------------- + + MIT License + + Copyright (c) 2025 Shun Wei + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/src/anomalib/models/image/uninet/README.md b/src/anomalib/models/image/uninet/README.md new file mode 100644 index 0000000000..dbf11256c9 --- /dev/null +++ b/src/anomalib/models/image/uninet/README.md @@ -0,0 +1,65 @@ +# UniNet: Unified Contrastive Learning Framework for Anomaly Detection + +This is the implementation of the UniNet model, a unified contrastive learning framework for anomaly detection presented in CVPR 2025. + +Model Type: Classification and Segmentation + +## Description + +UniNet is a contrastive learning-based anomaly detection model that uses a teacher-student architecture with attention bottleneck mechanisms. The model is designed for diverse domains and supports both supervised and unsupervised anomaly detection scenarios. It focuses on multi-class anomaly detection and leverages domain-related feature selection to improve performance across different categories. + +The model consists of: + +- **Teacher Networks**: Pre-trained backbone networks that provide reference features for normal samples +- **Student Network**: A decoder network that learns to reconstruct normal patterns +- **Attention Bottleneck**: Mechanisms that help focus on relevant features +- **Domain-Related Feature Selection**: Adaptive feature selection for different domains +- **Contrastive Loss**: Temperature-controlled similarity computation between student and teacher features + +During training, the student network learns to match teacher features for normal samples while being trained to distinguish anomalous patterns through contrastive learning. The model uses a weighted decision mechanism during inference to combine multi-scale features for final anomaly scoring. + +## Architecture + +The UniNet architecture leverages contrastive learning with teacher-student networks, incorporating attention bottlenecks and domain-specific feature selection for robust anomaly detection across diverse domains. + +## Usage + +`anomalib train --model UniNet --data MVTecAD --data.category ` + +## Benchmark + +All results gathered with seed `42`. 
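+
+Each column below corresponds to a single MVTecAD category. The following is a
+minimal sketch of one such run through the Python API, equivalent to the CLI
+command above (`bottle` is used only as an example category):
+
+```python
+from lightning.pytorch import seed_everything
+
+from anomalib.data import MVTecAD
+from anomalib.engine import Engine
+from anomalib.models import UniNet
+
+seed_everything(42)  # the tables below were gathered with seed 42
+
+datamodule = MVTecAD(category="bottle")  # one run per category
+model = UniNet()
+
+engine = Engine()
+engine.fit(model=model, datamodule=datamodule)
+engine.test(model=model, datamodule=datamodule)
+```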
+ +## [MVTecAD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper | +| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: | +| 0.956 | 0.999 | 0.982 | 0.939 | 0.896 | 0.996 | 0.999 | 1.000 | 1.000 | 0.816 | 0.919 | 0.970 | 1.000 | 0.984 | 0.993 | 0.945 | + +### Pixel-Level AUC + +| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper | +| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: | +| 0.976 | 0.989 | 0.983 | 0.985 | 0.973 | 0.992 | 0.987 | 0.993 | 0.984 | 0.964 | 0.992 | 0.965 | 0.992 | 0.923 | 0.961 | 0.984 | + +### Image F1 Score + +| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper | +| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: | +| 0.957 | 0.984 | 0.944 | 0.964 | 0.883 | 0.973 | 0.986 | 0.995 | 0.989 | 0.921 | 0.905 | 0.946 | 0.983 | 0.961 | 0.974 | 0.959 | + +## Model Features + +- **Contrastive Learning**: Uses temperature-controlled contrastive loss for effective anomaly detection +- **Multi-Domain Support**: Designed to work across diverse domains with domain-related feature selection +- **Flexible Training**: Supports both supervised and unsupervised training modes +- **Attention Mechanisms**: Incorporates attention bottlenecks for focused feature learning +- **Multi-Scale Features**: Leverages multi-scale feature matching for robust detection + +## Parameters + +- `student_backbone`: Backbone model for student network (default: "wide_resnet50_2") +- `teacher_backbone`: Backbone model for teacher network (default: "wide_resnet50_2") +- `temperature`: Temperature parameter for contrastive loss (default: 0.1) diff --git a/src/anomalib/models/image/uninet/__init__.py b/src/anomalib/models/image/uninet/__init__.py new file mode 100644 index 0000000000..7d03336ad2 --- /dev/null +++ b/src/anomalib/models/image/uninet/__init__.py @@ -0,0 +1,33 @@ +"""UniNet Model for anomaly detection. + +This module implements anomaly detection using the UniNet model. It is a model designed for diverse domains and is +suited for both supervised and unsupervised anomaly detection. It also focuses on multi-class anomaly detection. 
+ +Example: + >>> from anomalib.models import UniNet + >>> from anomalib.data import MVTecAD + >>> from anomalib.engine import Engine + + >>> # Initialize model and data + >>> datamodule = MVTecAD() + >>> model = UniNet() + + >>> # Train the model + >>> engine = Engine() + >>> engine.train(model=model, datamodule=datamodule) + + >>> # Get predictions + >>> engine.predict(model=model, datamodule=datamodule) + +See Also: + - :class:`UniNet`: Main model class for UniNet-based anomaly detection + - :class:`UniNetModel`: PyTorch implementation of the UniNet model +""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import UniNet +from .torch_model import UniNetModel + +__all__ = ["UniNet", "UniNetModel"] diff --git a/src/anomalib/models/image/uninet/components/__init__.py b/src/anomalib/models/image/uninet/components/__init__.py new file mode 100644 index 0000000000..d1a754efad --- /dev/null +++ b/src/anomalib/models/image/uninet/components/__init__.py @@ -0,0 +1,17 @@ +"""PyTorch modules for the UniNet model implementation.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .anomaly_map import weighted_decision_mechanism +from .attention_bottleneck import AttentionBottleneck, BottleneckLayer +from .dfs import DomainRelatedFeatureSelection +from .loss import UniNetLoss + +__all__ = [ + "UniNetLoss", + "DomainRelatedFeatureSelection", + "AttentionBottleneck", + "BottleneckLayer", + "weighted_decision_mechanism", +] diff --git a/src/anomalib/models/image/uninet/components/anomaly_map.py b/src/anomalib/models/image/uninet/components/anomaly_map.py new file mode 100644 index 0000000000..a0ee1bbfaf --- /dev/null +++ b/src/anomalib/models/image/uninet/components/anomaly_map.py @@ -0,0 +1,103 @@ +"""Utilities for computing anomaly maps.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import einops +import torch +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components.filters import GaussianBlur2d + + +def weighted_decision_mechanism( + batch_size: int, + output_list: list[torch.Tensor], + alpha: float, + beta: float, + output_size: tuple[int, int] = (256, 256), +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute anomaly maps using weighted decision mechanism. + + Args: + batch_size (int): Batch size. + output_list (list[torch.Tensor]): List of output tensors, each with shape [batch_size, H, W]. + alpha (float): Alpha parameter. Used for controlling the upper limit + beta (float): Beta parameter. Used for controlling the lower limit + output_size (tuple[int, int]): Output size. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Anomaly score and anomaly map. 
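+
+    Example:
+        A shape-only sketch with two multi-scale similarity maps; the values
+        are random, so the scores themselves are not meaningful. ``alpha`` and
+        ``beta`` follow the values used by :class:`UniNetModel`:
+
+        >>> import torch
+        >>> from anomalib.models.image.uninet.components import weighted_decision_mechanism
+        >>> outputs = [torch.rand(2, 64, 64), torch.rand(2, 32, 32)]
+        >>> scores, maps = weighted_decision_mechanism(
+        ...     batch_size=2,
+        ...     output_list=outputs,
+        ...     alpha=0.01,
+        ...     beta=3e-05,
+        ... )
+        >>> scores.shape, maps.shape
+        (torch.Size([2, 1]), torch.Size([2, 256, 256]))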
+ """ + # Convert to tensor operations to avoid SequenceConstruct/ConcatFromSequence + device = output_list[0].device + num_outputs = len(output_list) + + # Pre-allocate tensors instead of using lists + total_weights = torch.zeros(batch_size, device=device) + gaussian_blur = GaussianBlur2d(sigma=4.0, kernel_size=(5, 5), channels=1).to(device) + + # Process each batch item individually + for i in range(batch_size): + # Get max value from each output for this batch item + # Create tensor directly from max values to avoid list operations + max_values = torch.zeros(num_outputs, device=device) + for j, output_tensor in enumerate(output_list): + max_values[j] = torch.max(output_tensor[i]) + + probs = F.softmax(max_values, dim=0) + + # Use tensor operations instead of list filtering + prob_mean = torch.mean(probs) + mask = probs > prob_mean + + if mask.any(): + weight_tensor = max_values[mask] + weight = torch.max(torch.stack([torch.mean(weight_tensor) * alpha, torch.tensor(beta, device=device)])) + else: + weight = torch.tensor(beta, device=device) + + total_weights[i] = weight + + # Process anomaly maps using tensor operations + # Pre-allocate the processed anomaly maps tensor + processed_anomaly_maps = torch.zeros(batch_size, *output_size, device=device) + + # Process each output tensor separately due to different spatial dimensions + for output_tensor in output_list: + # Interpolate current output to target size + # Add channel dimension for interpolation: [batch_size, H, W] -> [batch_size, 1, H, W] + output_resized = F.interpolate( + output_tensor.unsqueeze(1), + output_size, + mode="bilinear", + align_corners=True, + ).squeeze(1) # [batch_size, H_out, W_out] + + # Add to accumulated anomaly maps + processed_anomaly_maps += output_resized + + # Pre-allocate anomaly scores tensor instead of using list + anomaly_scores = torch.zeros(batch_size, device=device) + + for idx in range(batch_size): + top_k = int(output_size[0] * output_size[1] * total_weights[idx]) + top_k = max(top_k, 1) # Ensure at least 1 element + + single_anomaly_score_exp = processed_anomaly_maps[idx] + single_anomaly_score_exp = gaussian_blur(einops.rearrange(single_anomaly_score_exp, "h w -> 1 1 h w")) + single_anomaly_score_exp = single_anomaly_score_exp.squeeze() + + # Flatten and get top-k values + single_map_flat = single_anomaly_score_exp.view(-1) + top_k_values = torch.topk(single_map_flat, top_k).values + single_anomaly_score = top_k_values[0] if len(top_k_values) > 0 else torch.tensor(0.0, device=device) + anomaly_scores[idx] = single_anomaly_score.detach() + + return anomaly_scores.unsqueeze(1), processed_anomaly_maps.detach() diff --git a/src/anomalib/models/image/uninet/components/attention_bottleneck.py b/src/anomalib/models/image/uninet/components/attention_bottleneck.py new file mode 100644 index 0000000000..9f8eed814a --- /dev/null +++ b/src/anomalib/models/image/uninet/components/attention_bottleneck.py @@ -0,0 +1,343 @@ +"""Attention Bottleneck for UniNet.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Callable + +import torch +from torch import nn + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding. + + Args: + in_planes (int): Number of input planes. + out_planes (int): Number of output planes. 
+ stride (int): Stride of the convolution. + groups (int): Number of groups. + dilation (int): Dilation rate. + + Returns: + nn.Conv2d: 3x3 convolution with padding. + """ + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution. + + Args: + in_planes (int): Number of input planes. + out_planes (int): Number of output planes. + stride (int): Stride of the convolution. + """ + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +def fuse_bn(conv: nn.Module, bn: nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + """Fuse convolution and batch normalization layers. + + Args: + conv (nn.Module): Convolution layer. + bn (nn.Module): Batch normalization layer. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Fused convolution and batch normalization layers. + """ + kernel = conv.weight + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class AttentionBottleneck(nn.Module): + """Attention Bottleneck for UniNet.""" + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: nn.Module | None = None, + groups: int = 1, + base_width: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + attention: bool = True, + halve: int = 1, + ) -> None: + super().__init__() + self.attention = attention + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups # 512 + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.halve = halve + k = 7 + p = 3 + + self.bn2 = norm_layer(width // halve) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + + self.downsample = downsample + self.stride = stride + + self.bn4 = norm_layer(width // 2) + self.bn5 = norm_layer(width // 2) + self.bn6 = norm_layer(width // 2) + self.bn7 = norm_layer(width) + self.conv3x3 = nn.Conv2d(inplanes // 2, width // 2, kernel_size=3, stride=stride, padding=1, bias=False) + self.conv3x3_ = nn.Conv2d(width // 2, width // 2, 3, 1, 1, bias=False) + self.conv7x7 = nn.Conv2d(inplanes // 2, width // 2, kernel_size=k, stride=stride, padding=p, bias=False) + self.conv7x7_ = nn.Conv2d(width // 2, width // 2, k, 1, p, bias=False) + + def get_same_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get same kernel and bias of the bottleneck.""" + k1, b1 = fuse_bn(self.conv3x3, self.bn2) + k2, b2 = fuse_bn(self.conv3x3_, self.bn6) + + return k1, b1, k2, b2 + + def merge_kernel(self) -> None: + """Merge kernel of the bottleneck.""" + k1, b1, k2, b2 = self.get_same_kernel_bias() + self.conv7x7 = nn.Conv2d( + self.conv3x3.in_channels, + self.conv3x3.out_channels, + self.conv3x3.kernel_size, + self.conv3x3.stride, + self.conv3x3.padding, + self.conv3x3.dilation, + self.conv3x3.groups, + ) + self.conv7x7_ = nn.Conv2d( + self.conv3x3_.in_channels, + self.conv3x3_.out_channels, + self.conv3x3_.kernel_size, + self.conv3x3_.stride, + self.conv3x3_.padding, + self.conv3x3_.dilation, + self.conv3x3_.groups, + ) + self.conv7x7.weight.data = k1 + self.conv7x7.bias.data 
= b1 + self.conv7x7_.weight.data = k2 + self.conv7x7_.bias.data = b2 + + @staticmethod + def _process_branch( + branch: torch.Tensor, + conv1: nn.Module, + bn1: nn.Module, + conv2: nn.Module, + bn2: nn.Module, + relu: nn.Module, + ) -> torch.Tensor: + """Process a branch of the bottleneck. + + Args: + branch (torch.Tensor): Input tensor. + conv1 (nn.Module): First convolution layer. + bn1 (nn.Module): First batch normalization layer. + conv2 (nn.Module): Second convolution layer. + bn2 (nn.Module): Second batch normalization layer. + relu (nn.Module): ReLU activation layer. + + Returns: + torch.Tensor: Output tensor. + """ + out = conv1(branch) + out = bn1(out) + out = relu(out) + out = conv2(out) + return bn2(out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + identity = x + + if self.halve == 1: + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + else: + num_channels = x.shape[1] + x_split = torch.split(x, [num_channels // 2, num_channels // 2], dim=1) + + out1 = self._process_branch(x_split[0], self.conv3x3, self.bn2, self.conv3x3_, self.bn5, self.relu) + out2 = self._process_branch(x_split[-1], self.conv7x7, self.bn4, self.conv7x7_, self.bn6, self.relu) + + out = torch.cat([out1, out2], dim=1) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out = out + identity + return self.relu(out) + + +class BottleneckLayer(nn.Module): + """Batch Normalization layer for UniNet. + + Args: + block (Type[AttentionBottleneck]): Attention bottleneck. + layers (int): Number of layers. + groups (int): Number of groups. + width_per_group (int): Width per group. + norm_layer (Callable[..., nn.Module] | None): Normalization layer. + halve (int): Number of halved channels. 
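+
+    Example:
+        A shape-only sketch mirroring how the UniNet Lightning model builds the
+        bottleneck. The forward pass indexes three multi-scale feature maps, so
+        a list of tensors is passed; the shapes below are illustrative (teacher
+        features of a 256x256 image):
+
+        >>> import torch
+        >>> bottleneck = BottleneckLayer(block=AttentionBottleneck, layers=3)
+        >>> features = [
+        ...     torch.randn(2, 256, 64, 64),
+        ...     torch.randn(2, 512, 32, 32),
+        ...     torch.randn(2, 1024, 16, 16),
+        ... ]
+        >>> bottleneck(features).shape
+        torch.Size([2, 2048, 8, 8])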
+ """ + + def __init__( + self, + block: type[AttentionBottleneck], + layers: int, + groups: int = 1, + width_per_group: int = 64, + norm_layer: Callable[..., nn.Module] | None = None, + halve: int = 2, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.groups = groups + self.base_width = width_per_group + self.inplanes = 256 * block.expansion + self.halve = halve + self.bn_layer = nn.Sequential(self._make_layer(block, 512, layers, stride=2)) + self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2) + self.bn1 = norm_layer(128 * block.expansion) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn2 = norm_layer(256 * block.expansion) + self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2) + self.bn3 = norm_layer(256 * block.expansion) + + self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1) + self.bn4 = norm_layer(512 * block.expansion) + + for module in self.modules(): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, nn.BatchNorm2d | nn.GroupNorm): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + elif hasattr(module, "merge_kernel") and self.halve == 2: + module.merge_kernel() + + def _make_layer( + self, + block: type[AttentionBottleneck], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + """Make a layer of the bottleneck. + + Args: + block (AttentionBottleneck): Attention bottleneck. + planes (int): Number of planes. + blocks (int): Number of blocks. + stride (int): Stride of the convolution. + dilate (bool): Whether to dilate the convolution. + + Returns: + nn.Sequential: Sequential layer. + """ + norm_layer = self._norm_layer + downsample = None + if dilate: + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes * 3, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes * 3, + planes, + stride, + downsample, + self.groups, + self.base_width, + norm_layer, + halve=self.halve, + ), + ) + self.inplanes = planes * block.expansion + layers.extend( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + norm_layer=norm_layer, + halve=self.halve, + ) + for _ in range(1, blocks) + ) + + return nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. 
+ """ + l1 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x[0])))))) + l2 = self.relu(self.bn3(self.conv3(x[1]))) + feature = torch.cat([l1, l2, x[2]], 1) + output = self.bn_layer(feature) # 16*2048*8*8 + return output.contiguous() diff --git a/src/anomalib/models/image/uninet/components/dfs.py b/src/anomalib/models/image/uninet/components/dfs.py new file mode 100644 index 0000000000..1b3417cb16 --- /dev/null +++ b/src/anomalib/models/image/uninet/components/dfs.py @@ -0,0 +1,96 @@ +"""Domain Related Feature Selection.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torch.nn import functional + + +class DomainRelatedFeatureSelection(nn.Module): + """Domain Related Feature Selection. + + It is used to select the domain-related features from the source and target features. + + Args: + num_channels (int): Number of channels in the features. Defaults to 256. + learnable (bool): Whether to use learnable theta. Theta controls the domain-related feature selection. + """ + + def __init__(self, num_channels: int = 256, learnable: bool = True) -> None: + super().__init__() + self.num_channels = num_channels + self.theta1 = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) + self.theta2 = nn.Parameter(torch.zeros(1, num_channels * 2, 1, 1)) + self.theta3 = nn.Parameter(torch.zeros(1, num_channels * 4, 1, 1)) + self.learnable = learnable + + def _get_theta(self, idx: int) -> torch.Tensor: + match idx: + case 1: + return self.theta1 + case 2: + return self.theta2 + case 3: + return self.theta3 + case _: + msg = f"Invalid index: {idx}" + raise ValueError(msg) + + def forward( + self, + source_features: list[torch.Tensor], + target_features: list[torch.Tensor], + conv: bool = False, + maximize: bool = True, + ) -> list[torch.Tensor]: + """Domain related feature selection. + + Args: + source_features (list[torch.Tensor]): Source features. + target_features (list[torch.Tensor]): Target features. + conv (bool): Whether to use convolutional domain-related feature selection. + Defaults to False. + maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the + max value from the target feature. Defaults to True. + + Returns: + list[torch.Tensor]: Domain related features. 
+ """ + features = [] + for idx, (source_feature, target_feature) in enumerate(zip(source_features, target_features, strict=True)): + theta = 1 + if self.learnable: + # to avoid losing local weight, theta should be as non-zero value as possible + if idx < 3: + theta = torch.clamp(torch.sigmoid(self._get_theta(idx + 1)) * 1.0 + 0.5, max=1) + else: + theta = torch.clamp(torch.sigmoid(self._get_theta(idx - 2)) * 1.0 + 0.5, max=1) + + b, c, h, w = source_feature.shape + if not conv: + prior_flat = target_feature.view(b, c, -1) + if maximize: + prior_flat_ = prior_flat.max(dim=-1, keepdim=True)[0] + prior_flat = prior_flat - prior_flat_ + weights = functional.softmax(prior_flat, dim=-1) + weights = weights.view(b, c, h, w) + + global_inf = target_feature.mean(dim=(-2, -1), keepdim=True) + + inter_weights = weights * (theta + global_inf) + + x_ = source_feature * inter_weights + features.append(x_) + else: + # generated in a convolutional way + pass + + return features diff --git a/src/anomalib/models/image/uninet/components/loss.py b/src/anomalib/models/image/uninet/components/loss.py new file mode 100644 index 0000000000..d2ddaa9d33 --- /dev/null +++ b/src/anomalib/models/image/uninet/components/loss.py @@ -0,0 +1,108 @@ +"""Loss functions for UniNet.""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torch.nn import functional + + +class UniNetLoss(nn.Module): + """Loss function for UniNet. + + Args: + lambda_weight (float): Hyperparameter for balancing the loss. Defaults to 0.7. + temperature (float): Temperature for contrastive learning. Defaults to 2.0. + """ + + def __init__(self, lambda_weight: float = 0.7, temperature: float = 2.0) -> None: + super().__init__() + self.lambda_weight = lambda_weight + self.temperature = temperature + + def forward( + self, + student_features: list[torch.Tensor], + teacher_features: list[torch.Tensor], + margin: int = 1, + mask: torch.Tensor | None = None, + stop_gradient: bool = False, + ) -> torch.Tensor: + """Compute the loss. + + Args: + student_features (list[torch.Tensor]): Student features. + teacher_features (list[torch.Tensor]): Teacher features. + margin (int): Hyperparameter for controlling the boundary. + mask (torch.Tensor | None): Mask for the prediction. Mask is of shape Bx1xHxW + stop_gradient (bool): Whether to stop the gradient into teacher features. 
+ """ + loss = 0.0 + margin_loss_a = 0.0 + + for idx in range(len(student_features)): + student_feature = student_features[idx] + teacher_feature = teacher_features[idx].detach() if stop_gradient else teacher_features[idx] + + n, c, h, w = student_feature.shape + student_feature = student_feature.view(n, c, -1).transpose(1, 2) # (N, H+W, C) + teacher_feature = teacher_feature.view(n, c, -1).transpose(1, 2) # (N, H+W, C) + + student_feature_normalized = functional.normalize(student_feature, p=2, dim=2) + teacher_feature_normalized = functional.normalize(teacher_feature, p=2, dim=2) + + cosine_loss = 1 - functional.cosine_similarity( + student_feature_normalized, + teacher_feature_normalized, + dim=2, + ) + cosine_loss = cosine_loss.mean() + + similarity = ( + torch.matmul(student_feature_normalized, teacher_feature_normalized.transpose(1, 2)) / self.temperature + ) + similarity = torch.exp(similarity) + similarity_sum = similarity.sum(dim=2, keepdim=True) + similarity = similarity / (similarity_sum + 1e-8) + diag_sum = torch.diagonal(similarity, dim1=1, dim2=2) + + # unsupervised and only normal (or abnormal) + if mask is None: + contrastive_loss = -torch.log(diag_sum + 1e-8).mean() + margin_loss_n = functional.relu(margin - diag_sum).mean() + + # supervised + else: + # gt label + if len(mask.shape) < 3: + normal_mask = mask == 0 + abnormal_mask = mask == 1 + # gt mask + else: + mask_ = functional.interpolate(mask, size=(h, w), mode="nearest").squeeze(1) + mask_flat = mask_.view(mask_.size(0), -1) + + normal_mask = mask_flat == 0 + abnormal_mask = mask_flat == 1 + + if normal_mask.sum() > 0: + diag_sim_normal = diag_sum[normal_mask] + contrastive_loss = -torch.log(diag_sim_normal + 1e-8).mean() + margin_loss_n = functional.relu(margin - diag_sim_normal).mean() + + if abnormal_mask.sum() > 0: + diag_sim_abnormal = diag_sum[abnormal_mask] + margin_loss_a = functional.relu(diag_sim_abnormal - margin / 2).mean() + + margin_loss = margin_loss_n + margin_loss_a + + loss += cosine_loss * self.lambda_weight + contrastive_loss * (1 - self.lambda_weight) + margin_loss + + return loss diff --git a/src/anomalib/models/image/uninet/lightning_model.py b/src/anomalib/models/image/uninet/lightning_model.py new file mode 100644 index 0000000000..3cb7f0579c --- /dev/null +++ b/src/anomalib/models/image/uninet/lightning_model.py @@ -0,0 +1,119 @@ +"""Lightning model for UniNet.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch.optim.lr_scheduler import MultiStepLR + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule +from anomalib.models.components.backbone import get_decoder +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +from .components import ( + AttentionBottleneck, + BottleneckLayer, + UniNetLoss, +) +from .torch_model import UniNetModel + + +class UniNet(AnomalibModule): + """UniNet model for anomaly detection. + + Args: + student_backbone (str): The backbone model to use for the student network. Defaults to "wide_resnet50_2". + teacher_backbone (str): The backbone model to use for the teacher network. Defaults to "wide_resnet50_2". + temperature (float): Temperature parameter used for contrastive loss. 
Controls the temperature of the student + and teacher similarity computation. Defaults to 0.1. + pre_processor (PreProcessor | bool, optional): Preprocessor instance or bool flag. Defaults to True. + post_processor (PostProcessor | bool, optional): Postprocessor instance or bool flag. Defaults to True. + evaluator (Evaluator | bool, optional): Evaluator instance or bool flag. Defaults to True. + visualizer (Visualizer | bool, optional): Visualizer instance or bool flag. Defaults to True. + """ + + def __init__( + self, + student_backbone: str = "wide_resnet50_2", + teacher_backbone: str = "wide_resnet50_2", + temperature: float = 0.1, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + ) -> None: + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + self.loss = UniNetLoss(temperature=temperature) # base class expects self.loss in lightning module + self.model = UniNetModel( + student=get_decoder(student_backbone), + bottleneck=BottleneckLayer(block=AttentionBottleneck, layers=3), + teacher_backbone=teacher_backbone, + loss=self.loss, + ) + + def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform a training step of UniNet.""" + del args, kwargs # These variables are not used. + + loss = self.model(images=batch.image, masks=batch.gt_mask, labels=batch.gt_label) + + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Perform a validation step of UniNet.""" + del args, kwargs # These variables are not used. + if batch.image is None or not isinstance(batch.image, torch.Tensor): + msg = "Expected batch.image to be a tensor, but got None or non-tensor type" + raise ValueError(msg) + + predictions = self.model(batch.image) + return batch.update(**predictions._asdict()) + + def configure_optimizers(self) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]: + """Configure optimizers for training. + + Returns: + tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Optimizers for student and target teacher. + """ + optimizer = torch.optim.AdamW( + [ + {"params": self.model.student.parameters()}, + {"params": self.model.bottleneck.parameters()}, + {"params": self.model.dfs.parameters()}, + {"params": self.model.teachers.target_teacher.parameters(), "lr": 1e-6}, + ], + lr=5e-3, + betas=(0.9, 0.999), + weight_decay=1e-5, + eps=1e-10, + amsgrad=True, + ) + milestones = [ + int(self.trainer.max_steps * 0.8) if self.trainer.max_steps != -1 else (self.trainer.max_epochs * 0.8), + ] + scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.2) + return [optimizer], [scheduler] + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Does not require any trainer arguments.""" + return {} + + @property + def learning_type(self) -> LearningType: + """The model uses one-class learning.""" + return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/uninet/torch_model.py b/src/anomalib/models/image/uninet/torch_model.py new file mode 100644 index 0000000000..61ffc76a70 --- /dev/null +++ b/src/anomalib/models/image/uninet/torch_model.py @@ -0,0 +1,222 @@ +"""PyTorch model for UniNet. + +See Also: + :class:`anomalib.models.image.uninet.lightning_model.UniNet`: + UniNet Lightning model. 
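+
+Example:
+    A rough sketch of building the torch model the same way the Lightning
+    wrapper does; all components come from this package and the shared
+    backbone helpers:
+
+    >>> from anomalib.models.components.backbone import get_decoder
+    >>> from anomalib.models.image.uninet.components import (
+    ...     AttentionBottleneck,
+    ...     BottleneckLayer,
+    ...     UniNetLoss,
+    ... )
+    >>> from anomalib.models.image.uninet.torch_model import UniNetModel
+    >>> model = UniNetModel(
+    ...     student=get_decoder("wide_resnet50_2"),
+    ...     bottleneck=BottleneckLayer(block=AttentionBottleneck, layers=3),
+    ...     teacher_backbone="wide_resnet50_2",
+    ...     loss=UniNetLoss(temperature=0.1),
+    ... )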
+""" + +# Original Code +# Copyright (c) 2025 Shun Wei +# https://github.com/pangdatangtt/UniNet +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torchvision +from torch import nn +from torch.fx import GraphModule +from torch.nn import functional as F # noqa: N812 +from torchvision.models.feature_extraction import create_feature_extractor + +from anomalib.data import InferenceBatch + +from .components import DomainRelatedFeatureSelection, weighted_decision_mechanism + + +class UniNetModel(nn.Module): + """UniNet PyTorch model. + + It consists of teachers, student, and bottleneck modules. + + Args: + student (nn.Module): Student model. + bottleneck (nn.Module): Bottleneck model. + source_teacher (nn.Module): Source teacher model. + target_teacher (nn.Module | None): Target teacher model. + """ + + def __init__( + self, + student: nn.Module, + bottleneck: nn.Module, + teacher_backbone: str, + loss: nn.Module, + ) -> None: + super().__init__() + self.teachers = Teachers(teacher_backbone) + self.student = student + self.bottleneck = bottleneck + self.dfs = DomainRelatedFeatureSelection() + + self.loss = loss + self.bce_loss = torch.nn.BCEWithLogitsLoss() + # Used to post-process the student features from the de_resnet model to get the predictions + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1) + + def forward( + self, + images: torch.Tensor, + masks: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ) -> torch.Tensor | InferenceBatch: + """Forward pass of the UniNet model. + + Args: + images (torch.Tensor): Input images. + masks (torch.Tensor | None): Ground truth masks. + labels (torch.Tensor | None): Ground truth labels. + + Returns: + torch.Tensor | InferenceBatch: Loss or InferenceBatch. + """ + source_target_features, bottleneck_inputs = self.teachers(images) + bottleneck_outputs = self.bottleneck(bottleneck_inputs) + + student_features = self.student(bottleneck_outputs) + + # These predictions are part of the de_resnet model of the original code. 
+ # since we are using the de_resnet model from anomalib, we need to compute predictions here + predictions = self.avgpool(student_features[0]) + predictions = torch.flatten(predictions, 1) + predictions = self.fc(predictions).squeeze() + predictions = predictions.chunk(dim=0, chunks=2) + + student_features = [d.chunk(dim=0, chunks=2) for d in student_features] + student_features = [ + student_features[0][0], + student_features[1][0], + student_features[2][0], + student_features[0][1], + student_features[1][1], + student_features[2][1], + ] + if self.training: + student_features = self._feature_selection(source_target_features, student_features) + return self._compute_loss( + student_features, + source_target_features, + predictions, + labels, + masks, + ) + + output_list: list[torch.Tensor] = [] + for target_feature, student_feature in zip(source_target_features, student_features, strict=True): + output = 1 - F.cosine_similarity(target_feature, student_feature) # B*64*64 + output_list.append(output) + + anomaly_score, anomaly_map = weighted_decision_mechanism( + batch_size=images.shape[0], + output_list=output_list, + alpha=0.01, + beta=3e-05, + output_size=images.shape[-2:], + ) + return InferenceBatch( + pred_score=anomaly_score, + anomaly_map=anomaly_map, + ) + + def _compute_loss( + self, + student_features: list[torch.Tensor], + teacher_features: list[torch.Tensor], + predictions: tuple[torch.Tensor, torch.Tensor] | None = None, + label: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + stop_gradient: bool = False, + ) -> torch.Tensor: + """Compute the loss. + + Args: + student_features (list[torch.Tensor]): Student features. + teacher_features (list[torch.Tensor]): Teacher features. + predictions (tuple[torch.Tensor, torch.Tensor] | None): Predictions is (B, B) + label (torch.Tensor | None): Label for the prediction. + mask (torch.Tensor | None): Mask for the prediction. Mask is of shape BxHxW + stop_gradient (bool): Whether to stop the gradient into teacher features. + + Returns: + torch.Tensor: Loss. + """ + if mask is not None: + mask_ = mask.float().unsqueeze(1) # Bx1xHxW + else: + assert label is not None, "Label is required when mask is not provided" + mask_ = label.float() + + loss = self.loss(student_features, teacher_features, mask=mask_, stop_gradient=stop_gradient) + if predictions is not None and label is not None: + loss += self.bce_loss(predictions[0], label.float()) + self.bce_loss(predictions[1], label.float()) + + return loss + + def _feature_selection( + self, + target_features: list[torch.Tensor], + source_features: list[torch.Tensor], + maximize: bool = True, + ) -> list[torch.Tensor]: + """Feature selection. + + Args: + source_features (list[torch.Tensor]): Source features. + target_features (list[torch.Tensor]): Target features. + maximize (bool): Used for weights computation. If True, the weights are computed by subtracting the + max value from the target feature. Defaults to True. + """ + return self.dfs(source_features, target_features, maximize=maximize) + + +class Teachers(nn.Module): + """Teachers module for UniNet. + + Args: + source_teacher (nn.Module): Source teacher model. + target_teacher (nn.Module | None): Target teacher model. + """ + + def __init__(self, teacher_backbone: str) -> None: + super().__init__() + self.source_teacher = self._get_teacher(teacher_backbone).eval() + self.target_teacher = self._get_teacher(teacher_backbone) + + @staticmethod + def _get_teacher(backbone: str) -> GraphModule: + """Get the teacher model. 
+ + In the original code, the teacher resnet model is used to extract features from the input image. + We can just use the feature extractor from torchvision to extract the features. + + Args: + backbone (str): The backbone model to use. + + Returns: + GraphModule: The teacher model. + """ + model = getattr(torchvision.models, backbone)(pretrained=True) + return create_feature_extractor(model, return_nodes=["layer3", "layer2", "layer1"]) + + def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Forward pass of the teachers. + + Args: + images (torch.Tensor): Input images. + + Returns: + torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: Source features or source and target features. + """ + with torch.no_grad(): + source_features = self.source_teacher(images) + + target_features = self.target_teacher(images) + + bottleneck_inputs = [ + torch.cat([a, b], dim=0) for a, b in zip(target_features.values(), source_features.values(), strict=True) + ] # 512, 1024, 2048 + + return list(source_features.values()) + list(target_features.values()), bottleneck_inputs diff --git a/third-party-programs.txt b/third-party-programs.txt index 5eeaca8ea9..4f71dbc247 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -46,3 +46,7 @@ terms are listed below. 8. AUPIMO metric implementation is based on the original code Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo SPDX-License-Identifier: MIT + +9. UninetModel + Copyright (c) 2025 @pangdatangtt, https://github.com/pangdatangtt/UniNet/ + SPDX-License-Identifier: MIT