πŸš€ feat(model): add UniNet #2797

Open · wants to merge 15 commits into `main`

37 changes: 37 additions & 0 deletions docs/source/markdown/guides/reference/models/image/uninet.md
@@ -0,0 +1,37 @@
# UniNet

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.lightning_model
:members:
:show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.torch_model
:members:
:show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.loss
:members:
:show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.anomaly_map
:members:
:show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.attention_bottleneck
:members:
:show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.dfs
:members:
:show-inheritance:
```
15 changes: 15 additions & 0 deletions examples/configs/model/uninet.yaml
@@ -0,0 +1,15 @@
model:
  class_path: anomalib.models.UniNet
  init_args:
    student_backbone: wide_resnet50_2
    teacher_backbone: wide_resnet50_2
    temperature: 0.1

trainer:
  max_epochs: 100
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        patience: 20
        monitor: image_AUROC
        mode: max
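
The same settings can also be passed directly to the model constructor. This is a sketch, assuming the `init_args` above map one-to-one onto keyword arguments of `UniNet`, which is how anomalib resolves `class_path`/`init_args` configs:

```python
from anomalib.models import UniNet

# Mirrors examples/configs/model/uninet.yaml; argument names are taken from the
# config's init_args and assumed to match the UniNet constructor signature.
model = UniNet(
    student_backbone="wide_resnet50_2",
    teacher_backbone="wide_resnet50_2",
    temperature=0.1,
)
```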
2 changes: 2 additions & 0 deletions src/anomalib/models/__init__.py
@@ -72,6 +72,7 @@
Stfpm,
Supersimplenet,
Uflow,
UniNet,
VlmAd,
WinClip,
)
@@ -108,6 +109,7 @@ class UnknownModelError(ModuleNotFoundError):
"Stfpm",
"Supersimplenet",
"Uflow",
"UniNet",
"VlmAd",
"WinClip",
"AiVad",
17 changes: 17 additions & 0 deletions src/anomalib/models/components/backbone/__init__.py
@@ -0,0 +1,17 @@
"""Common backbone models.

Example:
>>> from anomalib.models.components.backbone import get_decoder
>>> decoder = get_decoder()

See Also:
- :func:`anomalib.models.components.backbone.de_resnet`:
Decoder network implementation
"""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .resnet_decoder import get_decoder

__all__ = ["get_decoder"]
src/anomalib/models/components/backbone/resnet_decoder.py
@@ -7,9 +7,9 @@
# Copyright (C) 2022-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""PyTorch model defining the decoder network for Reverse Distillation.
"""PyTorch model defining the decoder network for Reverse Distillation and UniNet.

This module implements the decoder network used in the Reverse Distillation model
This module implements the decoder network used in the Reverse Distillation and UniNet models
architecture. The decoder reconstructs features from the bottleneck representation
back to the original feature space.

@@ -19,17 +19,15 @@
- Full decoder network architecture

Example:
>>> from anomalib.models.image.reverse_distillation.components.de_resnet import (
>>> from anomalib.models.components.backbone.resnet_decoder import (
... get_decoder
... )
>>> decoder = get_decoder("resnet18")
>>> features = torch.randn(32, 512, 28, 28)
>>> reconstructed = decoder(features)

See Also:
- :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`:
Main model implementation using this decoder
- :class:`anomalib.models.image.reverse_distillation.components.DecoderBasicBlock`:
- :class:`anomalib.models.components.backbone.resnet_decoder.DecoderBasicBlock`:
Basic building block for the decoder network
"""

@@ -157,8 +155,8 @@ class DecoderBottleneck(nn.Module):
"""Bottleneck block for the decoder network.

This module implements a bottleneck block used in the decoder part of the Reverse
Distillation model. It performs upsampling and feature reconstruction through a series of
convolutional layers.
Distillation and UniNet models. It performs upsampling and feature reconstruction
through a series of convolutional layers.

The block consists of three convolution layers:
1. 1x1 conv to adjust channels
@@ -186,7 +184,7 @@ class DecoderBottleneck(nn.Module):

Example:
>>> import torch
>>> from anomalib.models.image.reverse_distillation.components.de_resnet import (
>>> from anomalib.models.components.backbone.resnet_decoder import (
... DecoderBottleneck
... )
>>> layer = DecoderBottleneck(256, 64)
@@ -269,7 +267,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
return self.relu(out)


class ResNet(nn.Module):
class ResNetDecoder(nn.Module):
"""Decoder ResNet model for feature reconstruction.

This module implements a decoder version of the ResNet architecture, which
@@ -297,11 +295,11 @@ class ResNet(nn.Module):
layer to use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``.

Example:
>>> from anomalib.models.image.reverse_distillation.components import (
>>> from anomalib.models.components.backbone.resnet_decoder import (
... DecoderBasicBlock,
... ResNet
... ResNetDecoder
... )
>>> model = ResNet(
>>> model = ResNetDecoder(
... block=DecoderBasicBlock,
... layers=[2, 2, 2, 2]
... )
@@ -437,63 +435,63 @@ def forward(self, batch: torch.Tensor) -> list[torch.Tensor]:
return [feature_c, feature_b, feature_a]


def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNet:
return ResNet(block, layers, **kwargs)
def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNetDecoder:
return ResNetDecoder(block, layers, **kwargs)


def de_resnet18() -> ResNet:
def de_resnet18() -> ResNetDecoder:
"""ResNet-18 model."""
return _resnet(DecoderBasicBlock, [2, 2, 2, 2])


def de_resnet34() -> ResNet:
def de_resnet34() -> ResNetDecoder:
"""ResNet-34 model."""
return _resnet(DecoderBasicBlock, [3, 4, 6, 3])


def de_resnet50() -> ResNet:
def de_resnet50() -> ResNetDecoder:
"""ResNet-50 model."""
return _resnet(DecoderBottleneck, [3, 4, 6, 3])


def de_resnet101() -> ResNet:
def de_resnet101() -> ResNetDecoder:
"""ResNet-101 model."""
return _resnet(DecoderBottleneck, [3, 4, 23, 3])


def de_resnet152() -> ResNet:
def de_resnet152() -> ResNetDecoder:
"""ResNet-152 model."""
return _resnet(DecoderBottleneck, [3, 8, 36, 3])


def de_resnext50_32x4d() -> ResNet:
def de_resnext50_32x4d() -> ResNetDecoder:
"""ResNeXt-50 32x4d model."""
return _resnet(DecoderBottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)


def de_resnext101_32x8d() -> ResNet:
def de_resnext101_32x8d() -> ResNetDecoder:
"""ResNeXt-101 32x8d model."""
return _resnet(DecoderBottleneck, [3, 4, 23, 3], groups=32, width_per_group=8)


def de_wide_resnet50_2() -> ResNet:
def de_wide_resnet50_2() -> ResNetDecoder:
"""Wide ResNet-50-2 model."""
return _resnet(DecoderBottleneck, [3, 4, 6, 3], width_per_group=128)


def de_wide_resnet101_2() -> ResNet:
def de_wide_resnet101_2() -> ResNetDecoder:
"""Wide ResNet-101-2 model."""
return _resnet(DecoderBottleneck, [3, 4, 23, 3], width_per_group=128)


def get_decoder(name: str) -> ResNet:
def get_decoder(name: str) -> ResNetDecoder:
**Review comment — Copilot AI (Jul 17, 2025):** The docstring's `Returns` section says `ResNet`, but it should say `ResNetDecoder` to match the function's return type annotation.

"""Get decoder model based on the name of the backbone.

Args:
name (str): Name of the backbone.

Returns:
ResNet: Decoder ResNet architecture.
ResNetDecoder: Decoder ResNet architecture.
"""
decoder_map = {
"resnet18": de_resnet18,
2 changes: 2 additions & 0 deletions src/anomalib/models/image/__init__.py
@@ -61,6 +61,7 @@
from .stfpm import Stfpm
from .supersimplenet import Supersimplenet
from .uflow import Uflow
from .uninet import UniNet
from .vlm_ad import VlmAd
from .winclip import WinClip

@@ -82,6 +83,7 @@
"Stfpm",
"Supersimplenet",
"Uflow",
"UniNet",
"VlmAd",
"WinClip",
]
src/anomalib/models/image/reverse_distillation/components/__init__.py
@@ -10,7 +10,6 @@
through distillation and reconstruction:

- Bottleneck layer: Compresses features into a lower dimensional space
- Decoder network: Reconstructs features from the bottleneck representation

Example:
>>> from anomalib.models.image.reverse_distillation.components import (
@@ -23,11 +22,8 @@
See Also:
- :func:`anomalib.models.image.reverse_distillation.components.bottleneck`:
Bottleneck layer implementation
- :func:`anomalib.models.image.reverse_distillation.components.de_resnet`:
Decoder network implementation
"""

from .bottleneck import get_bottleneck_layer
from .de_resnet import get_decoder

__all__ = ["get_bottleneck_layer", "get_decoder"]
__all__ = ["get_bottleneck_layer"]
src/anomalib/models/image/reverse_distillation/torch_model.py
@@ -39,9 +39,10 @@

from anomalib.data import InferenceBatch
from anomalib.models.components import TimmFeatureExtractor
from anomalib.models.components.backbone import get_decoder

from .anomaly_map import AnomalyMapGenerationMode, AnomalyMapGenerator
from .components import get_bottleneck_layer, get_decoder
from .components import get_bottleneck_layer

if TYPE_CHECKING:
from anomalib.data.utils.tiler import Tiler
29 changes: 29 additions & 0 deletions src/anomalib/models/image/uninet/LICENSE
@@ -0,0 +1,29 @@
Copyright (c) 2025 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Some files in this folder are based on the original UniNet implementation by Shun Wei, Jielin Jiang, and Xiaolong Xu.

Original license:
-----------------

MIT License

Copyright (c) 2025 Shun Wei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
65 changes: 65 additions & 0 deletions src/anomalib/models/image/uninet/README.md
@@ -0,0 +1,65 @@
# UniNet: Unified Contrastive Learning Framework for Anomaly Detection

This is the anomalib implementation of UniNet, a unified contrastive learning framework for anomaly detection presented at CVPR 2025.

Model Type: Classification and Segmentation

## Description

UniNet is a contrastive learning-based anomaly detection model that uses a teacher-student architecture with attention bottleneck mechanisms. The model is designed for diverse domains and supports both supervised and unsupervised anomaly detection scenarios. It focuses on multi-class anomaly detection and leverages domain-related feature selection to improve performance across different categories.

The model consists of:

- **Teacher Networks**: Pre-trained backbone networks that provide reference features for normal samples
- **Student Network**: A decoder network that learns to reconstruct normal patterns
- **Attention Bottleneck**: Mechanisms that help focus on relevant features
- **Domain-Related Feature Selection**: Adaptive feature selection for different domains
- **Contrastive Loss**: Temperature-controlled similarity computation between student and teacher features

During training, the student network learns to match teacher features for normal samples while being trained to distinguish anomalous patterns through contrastive learning. The model uses a weighted decision mechanism during inference to combine multi-scale features for final anomaly scoring.
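
The temperature-controlled similarity between student and teacher features can be illustrated with a minimal sketch. This is not the repository's actual loss or anomaly-map code; the function name, scaling, and fusion by averaging are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def similarity_anomaly_map(
    teacher_feats: list[torch.Tensor],
    student_feats: list[torch.Tensor],
    temperature: float = 0.1,
) -> torch.Tensor:
    """Turn per-location teacher/student similarity into an anomaly map.

    Both inputs are lists of (B, C, H, W) feature maps at multiple scales.
    """
    maps = []
    for teacher, student in zip(teacher_feats, student_feats):
        # Cosine similarity per spatial location; low similarity means anomalous.
        similarity = F.cosine_similarity(teacher, student, dim=1)  # (B, H, W)
        anomaly = (1.0 - similarity) / temperature  # temperature sharpens the response
        maps.append(anomaly.unsqueeze(1))  # (B, 1, H, W)

    # Resize every scale to a common resolution and average them.
    size = maps[0].shape[-2:]
    maps = [F.interpolate(m, size=size, mode="bilinear", align_corners=False) for m in maps]
    return torch.stack(maps, dim=0).mean(dim=0)  # (B, 1, H, W)
```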

## Architecture

The UniNet architecture leverages contrastive learning with teacher-student networks, incorporating attention bottlenecks and domain-specific feature selection for robust anomaly detection across diverse domains.

## Usage

`anomalib train --model UniNet --data MVTecAD --data.category <category>`
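
The CLI call above corresponds roughly to the following Python API usage; this is a sketch assuming anomalib's standard `Engine`, `MVTecAD`, and `UniNet` interfaces, and the category is illustrative.

```python
from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.models import UniNet

datamodule = MVTecAD(category="bottle")  # illustrative category
model = UniNet()

engine = Engine()
engine.fit(model=model, datamodule=datamodule)
engine.test(model=model, datamodule=datamodule)
```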

## Benchmark

All results gathered with seed `42`.

## [MVTecAD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

### Image-Level AUC

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.956 | 0.999 | 0.982 | 0.939 | 0.896 | 0.996 | 0.999 | 1.000 | 1.000 | 0.816 | 0.919 | 0.970 | 1.000 | 0.984 | 0.993 | 0.945 |

### Pixel-Level AUC

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.976 | 0.989 | 0.983 | 0.985 | 0.973 | 0.992 | 0.987 | 0.993 | 0.984 | 0.964 | 0.992 | 0.965 | 0.992 | 0.923 | 0.961 | 0.984 |

### Image F1 Score

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.957 | 0.984 | 0.944 | 0.964 | 0.883 | 0.973 | 0.986 | 0.995 | 0.989 | 0.921 | 0.905 | 0.946 | 0.983 | 0.961 | 0.974 | 0.959 |

## Model Features

- **Contrastive Learning**: Uses temperature-controlled contrastive loss for effective anomaly detection
- **Multi-Domain Support**: Designed to work across diverse domains with domain-related feature selection
- **Flexible Training**: Supports both supervised and unsupervised training modes
- **Attention Mechanisms**: Incorporates attention bottlenecks for focused feature learning
- **Multi-Scale Features**: Leverages multi-scale feature matching for robust detection

## Parameters

- `student_backbone`: Backbone model for student network (default: "wide_resnet50_2")
- `teacher_backbone`: Backbone model for teacher network (default: "wide_resnet50_2")
- `temperature`: Temperature parameter for contrastive loss (default: 0.1)