diff --git a/examples/configs/model/dinomaly.yaml b/examples/configs/model/dinomaly.yaml new file mode 100644 index 0000000000..73f044a949 --- /dev/null +++ b/examples/configs/model/dinomaly.yaml @@ -0,0 +1,15 @@ +model: + class_path: anomalib.models.Dinomaly + init_args: + encoder_name: dinov2reg_vit_base_14 + bottleneck_dropout: 0.2 + decoder_depth: 8 + +trainer: + max_steps: 5000 + callbacks: + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 20 + monitor: image_AUROC + mode: max diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index 43509e3ef0..5a27c9b062 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -60,6 +60,7 @@ Csflow, Dfkde, Dfm, + Dinomaly, Draem, Dsr, EfficientAd, @@ -97,6 +98,7 @@ class UnknownModelError(ModuleNotFoundError): "Dfkde", "Dfm", "Draem", + "Dinomaly", "Dsr", "EfficientAd", "Fastflow", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index 9dec63dd09..cc77bb0a45 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -49,6 +49,7 @@ from .csflow import Csflow from .dfkde import Dfkde from .dfm import Dfm +from .dinomaly import Dinomaly from .draem import Draem from .dsr import Dsr from .efficient_ad import EfficientAd @@ -84,4 +85,5 @@ "Uflow", "VlmAd", "WinClip", + "Dinomaly", ] diff --git a/src/anomalib/models/image/dinomaly/README.md b/src/anomalib/models/image/dinomaly/README.md new file mode 100644 index 0000000000..242d7dd4be --- /dev/null +++ b/src/anomalib/models/image/dinomaly/README.md @@ -0,0 +1,81 @@ +# Dinomaly: Vision Transformer-based Anomaly Detection with Feature Reconstruction + +This is the implementation of the Dinomaly model based on the [original implementation](https://github.com/guojiajeremy/Dinomaly). + +Model Type: Segmentation + +## Description + +Dinomaly is a Vision Transformer-based anomaly detection model that uses an encoder-decoder architecture +for feature reconstruction. +The model leverages pre-trained DINOv2 Vision Transformer features and employs a reconstruction-based approach +to detect anomalies by comparing encoder and decoder features. + +### Architecture + +The Dinomaly model consists of three main components: + +1. DINOv2 Encoder: A pre-trained Vision Transformer (ViT) which extracts multi-scale feature maps. +2. Bottleneck MLP: A simple feed-forward network that collects features from the encoder's middle layers + (e.g., 8 out of 12 layers for ViT-Base). +3. Vision Transformer Decoder: Consisting of Transformer layers (typically 8), it learns to reconstruct the + compressed middle-level features by maximising cosine similarity with the encoder's features. + +Only the parameters of the bottleneck MLP and the decoder are trained. + +#### Key Components + +1. Foundation Transformer Models: Dinomaly leverages pre-trained ViTs (like DinoV2) which provide universal and + discriminative features. This use of foundation models enables strong performance across various image patterns. +2. Noisy Bottleneck: This component activates built-in Dropout within the MLP bottleneck. + By randomly discarding neural activations, Dropout acts as a "pseudo feature anomaly," which forces the decoder + to restore only normal features. This helps prevent the decoder from becoming too adept at reconstructing + anomalous patterns it has not been specifically trained on. +3. 
Linear Attention: Instead of traditional Softmax Attention, Linear Attention is used in the decoder. + Linear Attention's inherent inability to heavily focus on local regions, a characteristic sometimes seen as a + "side effect" in supervised tasks, is exploited here. This property encourages attention to spread across + the entire image, reducing the likelihood of the decoder simply forwarding identical information + from unexpected or anomalous patterns. This also contributes to computational efficiency. +4. Loose Reconstruction: + 1. Loose Constraint: Rather than enforcing rigid layer-to-layer reconstruction, Dinomaly groups multiple + encoder layers as a whole for reconstruction (e.g., into low-semantic and high-semantic groups). + This provides the decoder with more degrees of freedom, allowing it to behave more distinctly from the + encoder when encountering unseen patterns. + 2. Loose Loss: The point-by-point reconstruction loss function is loosened by employing a hard-mining + global cosine loss. This approach detaches the gradients of feature points that are already well-reconstructed + during training, preventing the model from becoming overly proficient at reconstructing all features, + including those that might correspond to anomalies. + +### Anomaly Detection + +Anomaly detection is performed by computing cosine similarity between encoder and decoder features at multiple scales. +The model generates anomaly maps by analyzing the reconstruction quality of features, where poor reconstruction +indicates anomalous regions. Both anomaly detection (image-level) and localization (pixel-level) are supported. + +## Usage + +`anomalib train --model Dinomaly --data MVTecAD --data.category <category>` + +## Benchmark + +All results gathered with seed `42`. The `max_steps` parameter is set to `5000` for training.
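For reference, the snippet below is a minimal Python-API sketch of an equivalent run, mirroring the example in the model's docstring rather than the exact benchmark script. The `category` value is illustrative, and passing `max_steps` through `Engine` to the underlying Lightning Trainer is an assumption.

```python
from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.models import Dinomaly

# Train and evaluate Dinomaly on a single MVTec AD category.
datamodule = MVTecAD(category="carpet")
model = Dinomaly()
engine = Engine(max_steps=5000)  # assumed: extra kwargs are forwarded to the Lightning Trainer

engine.fit(model, datamodule=datamodule)
engine.test(model, datamodule=datamodule)
```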
+ +## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) + +### Image-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | +| -------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | +| Dinomaly | 0.995 | 0.998 | 0.999 | 1.000 | 1.000 | 0.993 | 1.000 | 1.000 | 0.988 | 1.000 | 1.000 | 0.993 | 0.985 | 1.000 | 0.997 | + +### Pixel-Level AUC + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | +| -------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | +| Dinomaly | 0.981 | 0.993 | 0.993 | 0.993 | 0.975 | 0.975 | 0.990 | 0.981 | 0.986 | 0.994 | 0.969 | 0.977 | 0.997 | 0.988 | 0.950 | + +### Image F1 Score + +| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | +| -------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | +| Dinomaly | 0.985 | 0.983 | 0.991 | 0.995 | 0.994 | 0.975 | 1.000 | 0.995 | 0.982 | 1.000 | 1.000 | 0.986 | 0.957 | 0.983 | 0.976 | diff --git a/src/anomalib/models/image/dinomaly/__init__.py b/src/anomalib/models/image/dinomaly/__init__.py new file mode 100644 index 0000000000..0cb60f252a --- /dev/null +++ b/src/anomalib/models/image/dinomaly/__init__.py @@ -0,0 +1,38 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Dinomaly: Vision Transformer-based Anomaly Detection with Feature Reconstruction. + +The Dinomaly model implements a Vision Transformer encoder-decoder architecture for +anomaly detection using pre-trained DINOv2 features. The model extracts features from +multiple intermediate layers of a DINOv2 encoder, compresses them through a bottleneck +MLP, and reconstructs them using a Vision Transformer decoder. + +Anomaly detection is performed by computing cosine similarity between encoder and decoder +features at multiple scales. The model is particularly effective for visual anomaly +detection tasks where the goal is to identify regions or images that deviate from +normal patterns learned during training. + +Example: + >>> from anomalib.models.image import Dinomaly + >>> model = Dinomaly() + +The model can be used with any of the supported datasets and task modes in +anomalib. It leverages the powerful feature representations from DINOv2 Vision +Transformers combined with a reconstruction-based approach for robust anomaly detection. + +Notes: + - Uses DINOv2 Vision Transformer as the backbone encoder + - Features are extracted from intermediate layers for multi-scale analysis + - Employs feature reconstruction loss for unsupervised learning + - Supports both anomaly detection and localization tasks + - Requires significant GPU memory due to Vision Transformer architecture + +See Also: + :class:`anomalib.models.image.dinomaly.lightning_model.Dinomaly`: + Lightning implementation of the Dinomaly model. 
+""" + +from anomalib.models.image.dinomaly.lightning_model import Dinomaly + +__all__ = ["Dinomaly"] diff --git a/src/anomalib/models/image/dinomaly/components/__init__.py b/src/anomalib/models/image/dinomaly/components/__init__.py new file mode 100644 index 0000000000..70385c143b --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/__init__.py @@ -0,0 +1,38 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Components module for Dinomaly model. + +This module provides all the necessary components for the Dinomaly Vision Transformer +architecture including layers, model loader, utilities, and vision transformer implementations. +""" + +# Model loader +from .dinov2_loader import DinoV2Loader, load + +# Layer components +from .layers import Block, DinomalyMLP, LinearAttention, MemEffAttention + +# Training-related classes: Loss, Optimizer and scheduler +from .loss import CosineHardMiningLoss +from .optimizer import StableAdamW, WarmCosineScheduler + +# Vision transformer components +from .vision_transformer import DinoVisionTransformer + +__all__ = [ + # Layers + "Block", + "DinomalyMLP", + "LinearAttention", + "MemEffAttention", + # Model loader + "DinoV2Loader", + "load", + # Utils + "StableAdamW", + "WarmCosineScheduler", + "CosineHardMiningLoss", + # Vision transformer + "DinoVisionTransformer", +] diff --git a/src/anomalib/models/image/dinomaly/components/dinov2_loader.py b/src/anomalib/models/image/dinomaly/components/dinov2_loader.py new file mode 100644 index 0000000000..8e85d15749 --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/dinov2_loader.py @@ -0,0 +1,187 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Loader for DINOv2 Vision Transformer models. + +This module provides a simple interface for loading pre-trained DINOv2 Vision Transformer models for the +Dinomaly anomaly detection framework. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import ClassVar +from urllib.request import urlretrieve + +import torch + +from anomalib.data.utils import DownloadInfo +from anomalib.data.utils.download import DownloadProgressBar +from anomalib.models.image.dinomaly.components import vision_transformer as dinov2_models + +logger = logging.getLogger(__name__) + + +class DinoV2Loader: + """Simple loader for DINOv2 Vision Transformer models. + + Supports loading dinov2 and dinov2_reg models with small, base, and large architectures. + """ + + DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + # Model configurations + MODEL_CONFIGS: ClassVar[dict[str, dict[str, int]]] = { + "small": {"embed_dim": 384, "num_heads": 6}, + "base": {"embed_dim": 768, "num_heads": 12}, + "large": {"embed_dim": 1024, "num_heads": 16}, + } + + def __init__(self, cache_dir: str | Path = "./pre_trained/") -> None: + """Initialize model loader. + + Args: + cache_dir: Directory to store downloaded model weights. + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def load(self, model_name: str) -> torch.nn.Module: + """Load a DINOv2 model by name. + + Args: + model_name: Name like 'dinov2_vit_base_14' or 'dinov2reg_vit_small_14'. + + Returns: + Loaded PyTorch model ready for inference. + + Raises: + ValueError: If model name is invalid or unsupported. 
+ """ + # Parse model name + model_type, architecture, patch_size = self._parse_name(model_name) + + # Create model + model = self._create_model(model_type, architecture, patch_size) + + # Load weights + self._load_weights(model, model_type, architecture, patch_size) + + logger.info(f"Loaded model: {model_name}") + return model + + def _parse_name(self, name: str) -> tuple[str, str, int]: + """Parse model name into components.""" + parts = name.split("_") + + if len(parts) < 3: + msg = f"Invalid model name format: {name}. Expected format: 'dinov2_vit__'" + raise ValueError(msg) + + # Determine model type and extract architecture/patch_size + if "dinov2reg" in name or "reg" in name: + model_type = "dinov2_reg" + architecture = parts[-2] + patch_size = int(parts[-1]) + else: + model_type = "dinov2" + architecture = parts[-2] + patch_size = int(parts[-1]) + + if architecture not in self.MODEL_CONFIGS: + valid_archs = list(self.MODEL_CONFIGS.keys()) + msg = f"Invalid architecture '{architecture}' in model name '{name}'. Valid architectures: {valid_archs}" + raise ValueError(msg) + + return model_type, architecture, patch_size + + @staticmethod + def _create_model(model_type: str, architecture: str, patch_size: int) -> torch.nn.Module: + """Create model with appropriate configuration.""" + model_kwargs = { + "patch_size": patch_size, + "img_size": 518, + "block_chunks": 0, + "init_values": 1e-8, + "interpolate_antialias": False, + "interpolate_offset": 0.1, + } + + # Add register tokens for reg models + if model_type == "dinov2_reg": + model_kwargs["num_register_tokens"] = 4 + + # Get model constructor function + model_fn = getattr(dinov2_models, f"vit_{architecture}", None) + if model_fn is None: + msg = f"Model function vit_{architecture} not found in dinov2_models" + raise ValueError(msg) + + return model_fn(**model_kwargs) + + def _load_weights(self, model: torch.nn.Module, model_type: str, architecture: str, patch_size: int) -> None: + """Download and load model weights using standardized Anomalib utilities.""" + weight_path = self._get_weight_path(model_type, architecture, patch_size) + + if not weight_path.exists(): + self._download_weights(model_type, architecture, patch_size) + + state_dict = torch.load(weight_path, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=False) + + def _get_weight_path(self, model_type: str, architecture: str, patch_size: int) -> Path: + """Get local path for model weights.""" + arch_code = architecture[0] # s, b, or l + + if model_type == "dinov2_reg": + filename = f"dinov2_vit{arch_code}{patch_size}_reg4_pretrain.pth" + else: + filename = f"dinov2_vit{arch_code}{patch_size}_pretrain.pth" + + return self.cache_dir / filename + + def _download_weights(self, model_type: str, architecture: str, patch_size: int) -> None: + """Download model weights using standardized Anomalib download utilities.""" + arch_code = architecture[0] + weight_path = self._get_weight_path(model_type, architecture, patch_size) + + # Build download URL + model_dir = f"dinov2_vit{arch_code}{patch_size}" + url = f"{self.DINOV2_BASE_URL}/{model_dir}/{weight_path.name}" + + # Create DownloadInfo for standardized download + download_info = DownloadInfo( + name=f"DINOv2 {model_type} {architecture} weights", + url=url, + hashsum="", # DINOv2 doesn't provide official hashes, but we use empty string for now + filename=weight_path.name, + ) + + logger.info(f"Downloading DINOv2 weights: {weight_path.name} to {self.cache_dir}") + + # Ensure cache directory exists + 
self.cache_dir.mkdir(parents=True, exist_ok=True) + + # Download with progress bar (following Anomalib patterns) + with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=download_info.name) as progress_bar: + urlretrieve( # noqa: S310 # nosec B310 + url=url, + filename=weight_path, + reporthook=progress_bar.update_to, + ) + + +def load(model_name: str) -> torch.nn.Module: + """Convenience function to load a model. + + This can be later extended to be a factory method to load other models. + + Args: + model_name: Name like 'dinov2_vit_base_14' or 'dinov2reg_vit_small_14'. + + Returns: + Loaded PyTorch model. + """ + loader = DinoV2Loader() + return loader.load(model_name) diff --git a/src/anomalib/models/image/dinomaly/components/layers.py b/src/anomalib/models/image/dinomaly/components/layers.py new file mode 100644 index 0000000000..ae23526dd6 --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/layers.py @@ -0,0 +1,331 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Consolidated layer implementations for Dinomaly model. + +This module contains all layer-level components used in the Dinomaly Vision Transformer +architecture, including attention mechanisms, transformer blocks, and MLP layers. + +References: + https://github.com/facebookresearch/dino/blob/master/vision_transformer.py + https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py +""" + +import logging +from collections.abc import Callable +from typing import Any + +from timm.layers.drop import DropPath +from timm.models.vision_transformer import Attention, LayerScale +from torch import Tensor, einsum, nn +from torch.nn import functional as F # noqa: N812 + +logger = logging.getLogger("dinov2") + + +class MemEffAttention(Attention): + """Memory-efficient attention from the dinov2 implementation with a small change. + + Reference: + https://github.com/facebookresearch/dinov2/blob/592541c8d842042bb5ab29a49433f73b544522d5/dinov2/eval/segmentation_m2f/models/backbones/vit.py#L159 + + Instead of using xformers's memory_efficient_attention() method, which requires adding a new dependency to anomalib, + this implementation uses the scaled dot product from torch. + """ + + def forward(self, x: Tensor, attn_bias: Tensor | None = None) -> Tensor: + """Compute memory-efficient attention using PyTorch's scaled dot product attention. + + Args: + x: Input tensor of shape (batch_size, seq_len, embed_dim). + attn_bias: Optional attention bias mask. Default: None. + + Returns: + Output tensor of shape (batch_size, seq_len, embed_dim). + """ + batch_size, seq_len, embed_dim = x.shape + qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, embed_dim // self.num_heads) + + q, k, v = qkv.unbind(2) + + # Use PyTorch's native scaled dot product attention for memory efficiency. + # Replaced xformers's memory_efficient_attention() method with pytorch's scaled + # dot product. + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=attn_bias, + ) + x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + + x = self.proj(x) + return self.proj_drop(x) + + +class LinearAttention(nn.Module): + """LLinear Attention is a Softmax-free Attention that serves as an alternative to vanilla Softmax Attention. + + As per the Dinomaly paper, using a Linear Attention leads to an "incompetence in focusing" on important + regions related to the query, such as foreground and neighbours. 
This property encourages attention to spread across + the entire image. This also contributes to computational efficiency. + Reference : + https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/models/vision_transformer.py#L213 + """ + + def __init__( + self, + input_dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float | None = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """Initialize the linear attention mechanism. + + Args: + input_dim: Input feature dimension. + num_heads: Number of attention heads. Default: 8. + qkv_bias: Whether to add bias to the query, key, value projections. Default: False. + qk_scale: Override default scale factor for attention. If None, uses head_dim**-0.5. Default: None. + attn_drop: Dropout probability for attention weights. Default: 0.0. + proj_drop: Dropout probability for output projection. Default: 0.0. + """ + super().__init__() + self.num_heads = num_heads + head_dim = input_dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(input_dim, input_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Linear(input_dim, input_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass through linear attention with ELU-based feature maps. + + This implements a linear attention mechanism that avoids the quadratic complexity + of standard attention by using ELU activation functions to create positive + feature maps for keys and queries. + + Args: + x: Input tensor of shape (batch_size, seq_len, embed_dim). + + Returns: + A tuple containing: + - Output tensor of shape (batch_size, seq_len, embed_dim). + - Key-value interaction tensor for potential downstream use. + """ + batch_size, seq_len, embed_dim = x.shape + qkv = ( + self.qkv(x) + .reshape(batch_size, seq_len, 3, self.num_heads, embed_dim // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = F.elu(q) + 1.0 + k = F.elu(k) + 1.0 + + kv = einsum("...sd,...se->...de", k, v) + z = 1.0 / einsum("...sd,...d->...s", q, k.sum(dim=-2)) + x = einsum("...de,...sd,...s->...se", kv, q, z) + x = x.transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + + x = self.proj(x) + x = self.proj_drop(x) + return x, kv + + +class DinomalyMLP(nn.Module): + """Unified MLP supporting bottleneck-style behavior, optional input dropout, and bias control. + + This can be used as a simple MLP layer or as the BottleNeck layer in Dinomaly models. + The code is a combined class representation of several MLP implementations in the Dinomaly codebase, + including the BottleNeck and Decoder MLPs. + References : + https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/models/vision_transformer.py#L67 + https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/models/vision_transformer.py#L128 + https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/dinov2/layers/mlp.py#L16 + + Example usage for BottleNeck: + >>> embedding_dim = 768 + >>> mlp = DinomalyMLP( + ... in_features=embedding_dim, + ... hidden_features=embedding_dim * 4, + ... out_features=embedding_dim, + ... drop=0.2, + ... bias=False, + ... apply_input_dropout=True) + + Example usage for a Decoder's MLP: + >>> embedding_dim = 768 + >>> mlp = DinomalyMLP( + ... in_features=embedding_dim, + ... hidden_features=embedding_dim * 4, + ... drop=0.2, + ... 
bias=False, + ... apply_input_dropout=False) + """ + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = False, + apply_input_dropout: bool = False, + ) -> None: + """Initialize the Dinomaly MLP layer. + + Args: + in_features: Number of input features. + hidden_features: Number of hidden features. If None, defaults to in_features. Default: None. + out_features: Number of output features. If None, defaults to in_features. Default: None. + act_layer: Activation layer class. Default: nn.GELU. + drop: Dropout probability. Default: 0.0. + bias: Whether to include bias in linear layers. Default: False. + apply_input_dropout: Whether to apply dropout to input before first linear layer. Default: False. + """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + self.apply_input_dropout = apply_input_dropout + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through the MLP with optional input dropout. + + Applies the following sequence: + 1. Optional input dropout (if apply_input_dropout=True) + 2. First linear transformation + 3. Activation function + 4. Dropout + 5. Second linear transformation + 6. Final dropout + + Args: + x: Input tensor of shape (batch_size, seq_len, feature_dim). + + Returns: + Output tensor of shape (batch_size, seq_len, out_features). + """ + if self.apply_input_dropout: + x = self.drop(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + +class Block(nn.Module): + """Transformer block with attention and MLP. + + The code is similar to the standard transformer block but has an extra fail-safe + in the forward method when using memory-efficient attention. + Reference: https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/dinov2/layers/block.py#L41 + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values: float | None = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = DinomalyMLP, + ) -> None: + """Initialize a transformer block with attention and MLP layers. + + Args: + dim: Input feature dimension. + num_heads: Number of attention heads. + mlp_ratio: Ratio of MLP hidden dimension to input dimension. Default: 4.0. + qkv_bias: Whether to add bias to query, key, value projections. Default: False. + proj_bias: Whether to add bias to attention output projection. Default: True. + ffn_bias: Whether to add bias to feed-forward network layers. Default: True. + drop: Dropout probability for MLP and projection layers. Default: 0.0. + attn_drop: Dropout probability for attention weights. Default: 0.0. + init_values: Initial values for layer scale. If None, layer scale is disabled. Default: None. + drop_path: Drop path probability for stochastic depth. Default: 0.0. + act_layer: Activation layer class for MLP. Default: nn.GELU. 
+ norm_layer: Normalization layer class. Default: nn.LayerNorm. + attn_class: Attention mechanism class. Default: Attention. + ffn_layer: Feed-forward network layer class. Default: DinomalyMLP. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + apply_input_dropout=False, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, return_attention: bool = False) -> Tensor | tuple[Tensor, Any]: + """Forward pass through the transformer block. + + Applies the standard transformer architecture: + 1. Layer normalization followed by attention + 2. Residual connection with optional layer scaling and drop path + 3. Layer normalization followed by MLP + 4. Residual connection with optional layer scaling and drop path + + Args: + x: Input tensor of shape (batch_size, seq_len, embed_dim) + return_attention: Whether to return attention weights along with output. Default: False. + + Returns: + If return_attention is False: + Output tensor of shape (batch_size, seq_len, embed_dim). + If return_attention is True: + Tuple containing output tensor and attention weights. + Note: Attention weights are None for MemEffAttention. + """ + # Always use the MemEffAttention path for consistency + if isinstance(self.attn, MemEffAttention): + y = self.attn(self.norm1(x)) + attn = None + else: + y, attn = self.attn(self.norm1(x)) + + x = x + self.ls1(y) + x = x + self.ls2(self.mlp(self.norm2(x))) + if return_attention: + return x, attn + return x diff --git a/src/anomalib/models/image/dinomaly/components/loss.py b/src/anomalib/models/image/dinomaly/components/loss.py new file mode 100644 index 0000000000..e63943c360 --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/loss.py @@ -0,0 +1,122 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Cosine Hard Mining Loss for training Dinomaly model. + +The code is based on the `global_cosine_hm_percent()` method in the original dinomaly implementation +Reference: https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/utils.py#L70C5-L70C29 +""" + +from functools import partial + +import torch + + +class CosineHardMiningLoss(torch.nn.Module): + """Cosine similarity loss with hard mining for anomaly detection. + + This loss function implements a sophisticated training strategy for the Dinomaly model + that prevents the decoder from becoming too effective at reconstructing anomalous regions. + The key insight is to "loosen the point-by-point reconstruction constraint" by reducing + the gradient contribution of well-reconstructed (easy) feature points during training. + + The algorithm works by: + 1. Computing cosine similarity between encoder and decoder features + 2. Identifying well-reconstructed points (those with high cosine similarity) + 3. 
Reducing the gradient contribution of these easy points by factor + 4. Focusing training on harder-to-reconstruct points + + This prevents the decoder from learning to reconstruct anomalous patterns, which is + crucial for effective anomaly detection during inference. + + Args: + p (float): Percentage of well-reconstructed (easy) points to down-weight. + Higher values (closer to 1.0) down-weight more points, making training + focus on fewer, harder examples. Default is 0.9 (down-weight 90% of easy points). + factor (float): Gradient reduction factor for well-reconstructed points. + Lower values reduce gradient contribution more aggressively. Default is 0.1 + (reduce gradients to 10% of the original value). + + Note: + Despite the name "hard mining", this loss actually down-weights easy examples + rather than up-weighting hard ones. The naming follows the original implementation + for consistency. + """ + + def __init__(self, p: float = 0.9, factor: float = 0.1) -> None: + """Initialize the CosineHardMiningLoss. + + Args: + p (float): Percentage of well-reconstructed points to down-weight (0.0 to 1.0). + Higher values make training focus on fewer, harder examples. Default is 0.9. + factor (float): Gradient reduction factor for well-reconstructed points (0.0 to 1.0). + Lower values reduce gradient contribution more aggressively. Default is 0.1. + """ + super().__init__() + self.p = p + self.factor = factor + + def forward( + self, + encoder_features: list[torch.Tensor], + decoder_features: list[torch.Tensor], + ) -> torch.Tensor: + """Forward pass of the cosine hard mining loss. + + Computes cosine similarity loss between encoder and decoder features while + applying gradient modification to down-weight well-reconstructed points. + + Args: + encoder_features: List of feature tensors from encoder layers. + Each tensor should have a shape (batch_size, num_features, height, width). + decoder_features: List of corresponding feature tensors from decoder layers. + Must have the same length and compatible shapes as encoder_features. + + Returns: + Computed loss value averaged across all feature layers. + + Note: + The encoder features are detached to prevent gradient flow through the encoder, + focusing training only on the decoder parameters. + """ + cos_loss = torch.nn.CosineSimilarity() + loss = torch.tensor(0.0, device=encoder_features[0].device) + for item in range(len(encoder_features)): + en_ = encoder_features[item].detach() + de_ = decoder_features[item] + with torch.no_grad(): + point_dist = 1 - cos_loss(en_, de_).unsqueeze(1) + k = max(1, int(point_dist.numel() * (1 - self.p))) + thresh = torch.topk(point_dist.reshape(-1), k=k)[0][-1] + + loss += torch.mean(1 - cos_loss(en_.reshape(en_.shape[0], -1), de_.reshape(de_.shape[0], -1))) + + partial_func = partial( + self._modify_grad, + indices_to_modify=point_dist < thresh, + gradient_multiply_factor=self.factor, + ) + de_.register_hook(partial_func) + + return loss / len(encoder_features) + + @staticmethod + def _modify_grad( + x: torch.Tensor, + indices_to_modify: torch.Tensor, + gradient_multiply_factor: float = 0.0, + ) -> torch.Tensor: + """Modify gradients based on indices and factor. 
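        This function is registered as a backward hook on the decoder features in ``forward``, so the scaling is applied to the gradients of the selected (well-reconstructed) feature points during backpropagation rather than to their forward values.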
+ + Args: + x: Input tensor + indices_to_modify: Boolean indices indicating which elements to modify + gradient_multiply_factor: Factor to multiply the selected gradients by + + Returns: + Modified tensor + """ + indices_to_modify = indices_to_modify.expand_as(x) + result = x.clone() + result[indices_to_modify] = result[indices_to_modify] * gradient_multiply_factor + return result diff --git a/src/anomalib/models/image/dinomaly/components/optimizer.py b/src/anomalib/models/image/dinomaly/components/optimizer.py new file mode 100644 index 0000000000..cfa3647401 --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/optimizer.py @@ -0,0 +1,212 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Learning rate scheduler and optimizer for the Dinomaly model. + +This module contains the WarmCosineScheduler and StableAdamW classes. +The code is based on the original dinomaly implementation: +https://github.com/guojiajeremy/Dinomaly/ + +""" + +import math +from collections.abc import Callable, Iterable +from typing import Any, TypeAlias + +import numpy as np +import torch +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +ParamsT: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, torch.Tensor]] + + +class WarmCosineScheduler(_LRScheduler): + """Cosine annealing scheduler with warmup. + + Learning rate scheduler that combines warm-up with cosine annealing. + Reference: https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/utils.py#L775 + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_value (float): Initial learning rate after warmup. + final_value (float): Final learning rate after annealing. + total_iters (int): Total number of iterations. + warmup_iters (int, optional): Number of warmup iterations. Default is 0. + start_warmup_value (float, optional): Starting learning rate for warmup. Default is 0. + + """ + + def __init__( + self, + optimizer: Optimizer, + base_value: float, + final_value: float, + total_iters: int, + warmup_iters: int = 0, + start_warmup_value: float = 0, + ) -> None: + self.final_value = final_value + self.total_iters = total_iters + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((warmup_schedule, schedule)) + + super().__init__(optimizer) + + def get_lr(self) -> list[float]: + """Returns the learning rate for the current epoch. + + Returns: + list[float]: List of learning rates for each parameter group. + """ + if self.last_epoch >= self.total_iters: + return [self.final_value for base_lr in self.base_lrs] + return [self.schedule[self.last_epoch] for base_lr in self.base_lrs] + + +class StableAdamW(Optimizer): + """Implements "stable AdamW" algorithm with gradient clipping. + + This was introduced in "Stable and low-precision training for large-scale vision-language models". 
+ Publication Reference : https://arxiv.org/abs/2304.13013 + Code reference (original implementation of the dinomaly model): + https://github.com/guojiajeremy/Dinomaly/blob/861a99b227fd2813b6ad8e8c703a7bea139ab735/optimizers/StableAdamW.py#L10 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + clip_threshold: float = 1.0, + ) -> None: + if not lr >= 0.0: + msg = f"Invalid learning rate: {lr}" + raise ValueError(msg) + if not eps >= 0.0: + msg = f"Invalid epsilon value: {eps}" + raise ValueError(msg) + if not 0.0 <= betas[0] < 1.0: + msg = f"Invalid beta parameter at index 0: {betas[0]}" + raise ValueError(msg) + if not 0.0 <= betas[1] < 1.0: + msg = f"Invalid beta parameter at index 1: {betas[1]}" + raise ValueError(msg) + if not weight_decay >= 0.0: + msg = f"Invalid weight decay value: {weight_decay}" + raise ValueError(msg) + defaults = { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "amsgrad": amsgrad, + "clip_threshold": clip_threshold, + } + super().__init__(params, defaults) + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restores the optimizer state from a given state dictionary. + + Ensures that the `amsgrad` parameter is set for each parameter group, + maintaining compatibility when loading optimizer states from checkpoints. + + Args: + state (dict[str, Any]): State dictionary to restore. + """ + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @staticmethod + def _rms(tensor: torch.Tensor) -> float: + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform step-weight decay + p.data.mul_(1 - group["lr"] * group["weight_decay"]) + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + msg = "StableAdamW does not support sparse gradients, please consider using SparseAdam instead." 
+ raise RuntimeError(msg) + amsgrad = group["amsgrad"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p) # , memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p) # , memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exponential moving averages of squared gradients + state["max_exp_avg_sq"] = torch.zeros_like(p) # , memory_format=torch.preserve_format) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running average till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running average of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) + + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) + + lr_scale = grad / denom + lr_scale = max(1.0, self._rms(lr_scale) / group["clip_threshold"]) + + step_size = group["lr"] / bias_correction1 / (lr_scale) + + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/src/anomalib/models/image/dinomaly/components/vision_transformer.py b/src/anomalib/models/image/dinomaly/components/vision_transformer.py new file mode 100644 index 0000000000..18a652890b --- /dev/null +++ b/src/anomalib/models/image/dinomaly/components/vision_transformer.py @@ -0,0 +1,472 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Vision Transformer implementation for Dinomaly model. + +This module contains the implementation of the Vision Transformer architecture used in the Dinomaly model. +Use the methods `vit_small`, `vit_base`, `vit_large`, and `vit_giant2` to create a DinoVisionTransformer instance +that can be used as en encoder for the Dinomaly model. + + +References: + https://github.com/facebookresearch/dino/blob/main/vision_transformer.py + https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py +""" + +import logging +import math +from collections.abc import Callable +from functools import partial + +import torch +import torch.utils.checkpoint +from timm.layers.patch_embed import PatchEmbed +from torch import nn +from torch.nn.init import trunc_normal_ + +from anomalib.models.image.dinomaly.components.layers import Block, DinomalyMLP, MemEffAttention + +logger = logging.getLogger("dinov2") + + +def named_apply( + fn: Callable, + module: nn.Module, + name: str = "", + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + """Apply a function recursively to module and its children. 
+ + Args: + fn: Function to apply to each module + module: Root module to apply function to + name: Name of the current module + depth_first: Whether to apply function depth-first + include_root: Whether to include the root module + + Returns: + The modified module + """ + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + full_name = f"{name}.{child_name}" if name else child_name + named_apply(fn=fn, module=child_module, name=full_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + """A chunk of transformer blocks for efficient processing.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through all blocks in the chunk.""" + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + """DINOv2 Vision Transformer implementation for anomaly detection.""" + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + ffn_bias: bool = True, + proj_bias: bool = True, + drop_path_rate: float = 0.0, + drop_path_uniform: bool = False, + init_values: float | None = None, # for layerscale: None or 0 => no layerscale + embed_layer: type = PatchEmbed, + act_layer: type = nn.GELU, + block_fn: type | partial[Block] = Block, + block_chunks: int = 1, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + ) -> None: + """Initialize the DINOv2 Vision Transformer. + + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + strict_img_size=False, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=DinomalyMLP, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunksize = depth // block_chunks + chunked_blocks = [[nn.Identity()] * i + blocks_list[i : i + chunksize] for i in range(0, depth, chunksize)] + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self) -> None: + """Initialize weights of the vision transformer.""" + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + """Interpolate positional encodings for different input sizes. + + Args: + x: Input tensor + w: Width of the input + h: Height of the input + + Returns: + Interpolated positional encodings + """ + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + n_pos = self.pos_embed.shape[1] - 1 + if npatch == n_pos and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0_float = float(w0) + self.interpolate_offset + h0_float = float(h0) + self.interpolate_offset + + sqrt_n = math.sqrt(n_pos) + sx, sy = w0_float / sqrt_n, h0_float / sqrt_n + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_n), int(sqrt_n), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0_float) == patch_pos_embed.shape[-2] + assert int(h0_float) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor: + """Prepare tokens for transformer input with optional masking. 
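        The returned sequence is ordered as ``[CLS token, register tokens (if any), patch tokens]``. Positional embeddings are interpolated and added to the CLS and patch tokens before the register tokens are inserted, so the register tokens carry no positional embedding.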
+ + Args: + x: Input tensor + masks: Optional mask tensor + + Returns: + Prepared tokens with positional encoding + """ + _batch_size, _num_channels, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def prepare_tokens(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor: + """Prepare tokens for transformer input. + + Args: + x: Input tensor + masks: Optional mask tensor + + Returns: + Prepared tokens with positional encoding + """ + _batch_size, _num_channels, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the vision transformer. + + Args: + x: Input tensor + + Returns: + CLS token output + """ + x = self.prepare_tokens_with_masks(x) + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x[:, 0] + + def get_intermediate_layers(self, x: torch.Tensor, n: int = 1) -> list[torch.Tensor]: + """Get outputs from the last n transformer blocks. + + Args: + x: Input tensor + n: Number of last blocks to get outputs from + + Returns: + List of outputs from the last n blocks + """ + x = self.prepare_tokens_with_masks(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, block in enumerate(self.blocks): + x = block(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + def get_last_selfattention(self, x: torch.Tensor) -> torch.Tensor: + """Get self-attention from the last transformer block. + + Args: + x: Input tensor + + Returns: + Self-attention tensor from the last block + """ + x = self.prepare_tokens_with_masks(x) + for i, block in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = block(x) + else: + # return attention of the last block + return block(x, return_attention=True)[:, self.num_register_tokens + 1 :] + + # This should never be reached, but added for type safety + msg = "No blocks found in the transformer" + raise RuntimeError(msg) + + def get_all_selfattention(self, x: torch.Tensor) -> list[torch.Tensor]: + """Get self-attention matrices from every transformer layer. 
+ + Args: + x: Input tensor + + Returns: + List of self-attention tensors from all layers + """ + """Get a self-attention matrix from every layer.""" + x = self.prepare_tokens_with_masks(x) + attns = [] + + for block in self.blocks: + attn = block(x, return_attention=True) + attn = torch.cat([attn[:, :, :1, :], attn[:, :, self.num_register_tokens + 1 :, :]], dim=2) + attn = torch.cat([attn[:, :, :, :1], attn[:, :, :, self.num_register_tokens + 1 :]], dim=3) + + attns.append(attn) + x = block(x) + + return attns + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: # noqa: ARG001 + """ViT weight initialization, original timm impl (for reproducibility).""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size: int = 16, num_register_tokens: int = 0, **kwargs) -> DinoVisionTransformer: + """Create a small Vision Transformer model. + + Args: + patch_size: Size of image patches + num_register_tokens: Number of register tokens + **kwargs: Additional arguments + + Returns: + DinoVisionTransformer model instance + """ + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_base(patch_size: int = 16, num_register_tokens: int = 0, **kwargs) -> DinoVisionTransformer: + """Create a base Vision Transformer model. + + Args: + patch_size: Size of image patches + num_register_tokens: Number of register tokens + **kwargs: Additional arguments + + Returns: + DinoVisionTransformer model instance + """ + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_large(patch_size: int = 16, num_register_tokens: int = 0, **kwargs) -> DinoVisionTransformer: + """Create a large Vision Transformer model. + + Args: + patch_size: Size of image patches + num_register_tokens: Number of register tokens + **kwargs: Additional arguments + + Returns: + DinoVisionTransformer model instance + """ + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + +def vit_giant2(patch_size: int = 16, num_register_tokens: int = 0, **kwargs) -> DinoVisionTransformer: + """Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64. 
+ + Args: + patch_size: Size of image patches + num_register_tokens: Number of register tokens + **kwargs: Additional arguments + + Returns: + DinoVisionTransformer model instance + """ + return DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) diff --git a/src/anomalib/models/image/dinomaly/lightning_model.py b/src/anomalib/models/image/dinomaly/lightning_model.py new file mode 100644 index 0000000000..24b4a82f5e --- /dev/null +++ b/src/anomalib/models/image/dinomaly/lightning_model.py @@ -0,0 +1,440 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Dinomaly: Vision Transformer-based Anomaly Detection with Feature Reconstruction. + +This module implements the Dinomaly model for anomaly detection using a Vision Transformer +encoder-decoder architecture. The model leverages pre-trained DINOv2 features and employs +a reconstruction-based approach to detect anomalies by comparing encoder and decoder features. + +Dinomaly extracts features from multiple intermediate layers of a DINOv2 Vision Transformer, +compresses them through a bottleneck MLP, and reconstructs them using a Vision Transformer +decoder. Anomaly detection is performed by computing cosine similarity between encoder +and decoder features at multiple scales. + +The model is particularly effective for visual anomaly detection tasks where the goal is +to identify regions or images that deviate from normal patterns learned during training. + +Example: + >>> from anomalib.data import MVTecAD + >>> from anomalib.models import Dinomaly + >>> from anomalib.engine import Engine + + >>> datamodule = MVTecAD() + >>> model = Dinomaly() + >>> engine = Engine() + + >>> engine.fit(model, datamodule=datamodule) # doctest: +SKIP + >>> predictions = engine.predict(model, datamodule=datamodule) # doctest: +SKIP + +Notes: + - The model uses DINOv2 Vision Transformer as the backbone encoder + - Features are extracted from intermediate layers (typically layers 2-9 for base models) + - A bottleneck MLP compresses multi-layer features before reconstruction + - Anomaly maps are computed using cosine similarity between encoder-decoder features + - The model supports both unsupervised anomaly detection and localization + +See Also: + :class:`anomalib.models.image.dinomaly.torch_model.DinomalyModel`: + PyTorch implementation of the Dinomaly model. 
+""" + +import logging +from typing import Any + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler +from torch.nn.init import trunc_normal_ +from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize + +from anomalib import LearningType +from anomalib.data import Batch +from anomalib.metrics import Evaluator +from anomalib.models.components import AnomalibModule +from anomalib.models.image.dinomaly.components import CosineHardMiningLoss, StableAdamW, WarmCosineScheduler +from anomalib.models.image.dinomaly.torch_model import DinomalyModel +from anomalib.post_processing import PostProcessor +from anomalib.pre_processing import PreProcessor +from anomalib.visualization import Visualizer + +logger = logging.getLogger(__name__) + +# Training constants +DEFAULT_IMAGE_SIZE = 448 +DEFAULT_CROP_SIZE = 392 +MAX_STEPS_DEFAULT = 5000 + +# Default Training hyperparameters +TRAINING_CONFIG: dict[str, Any] = { + "progressive_loss": { + "p_final": 0.9, + "p_schedule_steps": 1000, + "loss_factor": 0.1, + }, + "optimizer": { + "lr": 2e-3, + "betas": (0.9, 0.999), + "weight_decay": 1e-4, + "amsgrad": True, + "eps": 1e-8, + }, + "scheduler": { + "base_value": 2e-3, + "final_value": 2e-4, + "total_iters": MAX_STEPS_DEFAULT, + "warmup_iters": 100, + }, + "trainer": { + "gradient_clip_val": 0.1, + "num_sanity_val_steps": 0, + "max_steps": MAX_STEPS_DEFAULT, + }, +} + + +class Dinomaly(AnomalibModule): + """Dinomaly Lightning Module for Vision Transformer-based Anomaly Detection. + + This lightning module trains the Dinomaly anomaly detection model (DinomalyModel). + During training, the decoder learns to reconstruct normal features. + During inference, the trained decoder is expected to successfully reconstruct normal + regions of feature maps, but fail to reconstruct anomalous regions as + it has not seen such patterns. + + Args: + encoder_name (str): Name of the Vision Transformer encoder to use. + Supports DINOv2 variants (small, base, large) with different patch sizes. + Defaults to "dinov2reg_vit_base_14". + bottleneck_dropout (float): Dropout rate for the bottleneck MLP layer. + Helps prevent overfitting during feature compression. Defaults to 0.2. + decoder_depth (int): Number of Vision Transformer decoder layers. + More layers allow for more complex reconstruction. Defaults to 8. + target_layers (list[int] | None): List of encoder layer indices to extract + features from. If None, uses [2, 3, 4, 5, 6, 7, 8, 9] for base models + and [4, 6, 8, 10, 12, 14, 16, 18] for large models. + fuse_layer_encoder (list[list[int]] | None): Groupings of encoder layers + for feature fusion. If None, uses [[0, 1, 2, 3], [4, 5, 6, 7]]. + fuse_layer_decoder (list[list[int]] | None): Groupings of decoder layers + for feature fusion. If None, uses [[0, 1, 2, 3], [4, 5, 6, 7]]. + mask_neighbor_size (int): Size of neighborhood for attention masking in decoder. + Set to 0 to disable masking. Defaults to 0. + remove_class_token (bool): Whether to remove class token from features + before processing. Defaults to False. + pre_processor (PreProcessor | bool, optional): Pre-processor instance or + flag to use default. Defaults to ``True``. + post_processor (PostProcessor | bool, optional): Post-processor instance + or flag to use default. Defaults to ``True``. + evaluator (Evaluator | bool, optional): Evaluator instance or flag to use + default. Defaults to ``True``. + visualizer (Visualizer | bool, optional): Visualizer instance or flag to + use default. 
Defaults to ``True``. + + Example: + >>> from anomalib.data import MVTecAD + >>> from anomalib.models import Dinomaly + >>> + >>> # Basic usage with default parameters + >>> model = Dinomaly() + >>> + >>> # Custom configuration + >>> model = Dinomaly( + ... encoder_name="dinov2reg_vit_large_14", + ... decoder_depth=12, + ... bottleneck_dropout=0.1, + ... mask_neighbor_size=3 + ... ) + >>> + >>> # Training with datamodule + >>> datamodule = MVTecAD() + >>> engine = Engine() + >>> engine.fit(model, datamodule=datamodule) + + Note: + The model requires significant GPU memory due to the Vision Transformer + architecture. Consider using gradient checkpointing or smaller model + variants for memory-constrained environments. + """ + + def __init__( + self, + encoder_name: str = "dinov2reg_vit_base_14", + bottleneck_dropout: float = 0.2, + decoder_depth: int = 8, + target_layers: list[int] | None = None, + fuse_layer_encoder: list[list[int]] | None = None, + fuse_layer_decoder: list[list[int]] | None = None, + mask_neighbor_size: int = 0, + remove_class_token: bool = False, + pre_processor: PreProcessor | bool = True, + post_processor: PostProcessor | bool = True, + evaluator: Evaluator | bool = True, + visualizer: Visualizer | bool = True, + ) -> None: + super().__init__( + pre_processor=pre_processor, + post_processor=post_processor, + evaluator=evaluator, + visualizer=visualizer, + ) + + self.model: DinomalyModel = DinomalyModel( + encoder_name=encoder_name, + bottleneck_dropout=bottleneck_dropout, + decoder_depth=decoder_depth, + target_layers=target_layers, + fuse_layer_encoder=fuse_layer_encoder, + fuse_layer_decoder=fuse_layer_decoder, + mask_neighbor_size=mask_neighbor_size, + remove_class_token=remove_class_token, + ) + + # Set the trainable parameters for the model. + # Only the bottleneck and decoder parameters are trained. + + for param in self.model.parameters(): + param.requires_grad = False + # Unfreeze bottleneck and decoder + for param in self.model.bottleneck.parameters(): + param.requires_grad = True + for param in self.model.decoder.parameters(): + param.requires_grad = True + + self.trainable_modules = torch.nn.ModuleList([self.model.bottleneck, self.model.decoder]) + self._initialize_trainable_modules(self.trainable_modules) + + # Initialize the loss function + progressive_loss_config = TRAINING_CONFIG["progressive_loss"] + self.loss_fn = CosineHardMiningLoss( + p=progressive_loss_config["p_final"], # Will be updated dynamically during training + factor=progressive_loss_config["loss_factor"], + ) + + @classmethod + def configure_pre_processor( + cls, + image_size: tuple[int, int] | None = None, + crop_size: int | None = None, + ) -> PreProcessor: + """Configure the default pre-processor for Dinomaly. + + Sets up image preprocessing pipeline including resizing, center cropping, + and normalization with ImageNet statistics. The preprocessing is optimized + for DINOv2 Vision Transformer models. + + Args: + image_size (tuple[int, int] | None): Target size for image resizing + as (height, width). Defaults to (448, 448). + crop_size (int | None): Target size for center cropping (assumes square crop). + Should be smaller than image_size. Defaults to 392. + + Returns: + PreProcessor: Configured pre-processor with transforms for Dinomaly. + + Raises: + ValueError: If crop_size is larger than the minimum dimension of image_size. 
+ + Note: + The default ImageNet normalization statistics are used: + - Mean: [0.485, 0.456, 0.406] + - Std: [0.229, 0.224, 0.225] + """ + crop_size = crop_size or DEFAULT_CROP_SIZE + image_size = image_size or (DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE) + + # Validate inputs + if crop_size > min(image_size): + msg = f"Crop size {crop_size} cannot be larger than image size {image_size}" + raise ValueError(msg) + + data_transforms = Compose([ + Resize(image_size), + CenterCrop(crop_size), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + return PreProcessor(transform=data_transforms) + + def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Training step for the Dinomaly model. + + Performs a single training iteration by computing feature reconstruction loss + between encoder and decoder features. Uses progressive cosine similarity loss + with hardest mining to focus training on difficult examples. + + Args: + batch (Batch): Input batch containing images and metadata. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + STEP_OUTPUT: Dictionary containing the computed loss value. + + Raises: + ValueError: If model output doesn't contain required features during training. + + Note: + The loss function uses progressive weight scheduling where the hardest + mining percentage increases from 0 to 0.9 over 1000 steps, focusing + on increasingly difficult examples as training progresses. + """ + del args, kwargs # These variables are not used. + model_output = self.model(batch.image) + en = model_output["encoder_features"] + de = model_output["decoder_features"] + # Progressive loss weight configuration + progressive_loss_config = TRAINING_CONFIG["progressive_loss"] + assert isinstance(progressive_loss_config, dict) + p_final = progressive_loss_config["p_final"] + p_schedule_steps = progressive_loss_config["p_schedule_steps"] + p = min(p_final * self.global_step / p_schedule_steps, p_final) + # Update the loss function's p parameter dynamically + self.loss_fn.p = p + loss = self.loss_fn(en, de) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train_p_schedule", p, on_step=True, on_epoch=False) + + return {"loss": loss} + + def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT: + """Validation step for the Dinomaly model. + + Performs inference on the validation batch to compute anomaly scores + and anomaly maps. The model operates in evaluation mode to generate + predictions for anomaly detection evaluation. + + Args: + batch (Batch): Input batch containing images and metadata. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + STEP_OUTPUT: Updated batch with pred_score (anomaly scores) and + anomaly_map (pixel-level anomaly maps) predictions. + + Raises: + Exception: If an error occurs during validation inference. + + Note: + During validation, the model returns InferenceBatch with anomaly + scores and maps computed from encoder-decoder feature comparisons. + """ + del args, kwargs # These variables are not used. + + predictions = self.model(batch.image) + return batch.update(pred_score=predictions.pred_score, anomaly_map=predictions.anomaly_map) + + def configure_optimizers(self) -> OptimizerLRScheduler: + """Configure optimizer and learning rate scheduler for Dinomaly training. 
+
+        Sets up the training configuration with frozen DINOv2 encoder and trainable
+        bottleneck and decoder components. Uses StableAdamW optimizer with warm
+        cosine learning rate scheduling.
+
+        The total number of training steps is determined dynamically from the trainer
+        configuration, supporting both max_steps and max_epochs settings.
+
+        Returns:
+            OptimizerLRScheduler: Tuple containing optimizer and scheduler configurations.
+
+        Raises:
+            ValueError: If neither max_epochs nor max_steps is defined.
+
+        Note:
+            - DINOv2 encoder parameters are frozen to preserve pre-trained features
+            - Only bottleneck MLP and decoder parameters are trained
+            - Uses truncated normal initialization for Linear layers
+            - Learning rate schedule: warmup (100 steps) + cosine decay
+            - Base learning rate: 2e-3, final learning rate: 2e-4
+            - Total steps determined from trainer's max_steps or max_epochs
+        """
+        # Determine total training steps dynamically from trainer configuration
+        # Check if the trainer has valid max_epochs and max_steps set
+        max_epochs = getattr(self.trainer, "max_epochs", -1)
+        max_steps = getattr(self.trainer, "max_steps", -1)
+
+        if max_epochs is None:
+            max_epochs = -1
+        if max_steps is None:
+            max_steps = -1
+
+        if max_epochs < 0 and max_steps < 0:
+            msg = "A finite number of steps or epochs must be defined"
+            raise ValueError(msg)
+
+        if max_epochs < 0:
+            # max_epochs not set, use max_steps directly
+            total_steps = max_steps
+        elif max_steps < 0:
+            # max_steps not set, calculate from max_epochs
+            total_steps = max_epochs * len(self.trainer.datamodule.train_dataloader())
+        else:
+            # Both are set, use the minimum (training stops at whichever comes first)
+            total_steps = min(max_steps, max_epochs * len(self.trainer.datamodule.train_dataloader()))
+
+        optimizer_config = TRAINING_CONFIG["optimizer"]
+        assert isinstance(optimizer_config, dict)
+        optimizer = StableAdamW([{"params": self.trainable_modules.parameters()}], **optimizer_config)
+
+        # Create a scheduler config with dynamically determined total steps
+        scheduler_config = TRAINING_CONFIG["scheduler"].copy()
+        assert isinstance(scheduler_config, dict)
+        scheduler_config["total_iters"] = total_steps
+
+        lr_scheduler = WarmCosineScheduler(optimizer, **scheduler_config)
+
+        return [optimizer], [lr_scheduler]
+
+    @property
+    def learning_type(self) -> LearningType:
+        """Return the learning type of the model.
+
+        Dinomaly is an unsupervised anomaly detection model that learns normal
+        data patterns without requiring anomaly labels during training.
+
+        Returns:
+            LearningType: Always returns LearningType.ONE_CLASS for unsupervised learning.
+
+        Note:
+            This property may be subject to change if supervised training support
+            is introduced in future versions.
+        """
+        return LearningType.ONE_CLASS
+
+    @property
+    def trainer_arguments(self) -> dict[str, Any]:
+        """Return Dinomaly-specific trainer arguments.
+
+        Provides configuration arguments optimized for Dinomaly training,
+        excluding max_steps to allow users to set their own training duration.
+
+        Returns:
+            dict[str, Any]: Dictionary of trainer arguments (gradient clipping and
+            sanity-check settings). Does not include max_steps so it can be set by
+            the engine or user.
+
+        Note:
+            The arguments are taken from ``TRAINING_CONFIG["trainer"]``; max_steps is
+            intentionally excluded so the user or engine controls training duration.
+        """
+        trainer_config = TRAINING_CONFIG["trainer"].copy()
+        assert isinstance(trainer_config, dict)
+        # Remove max_steps to allow user override
+        trainer_config.pop("max_steps", None)
+        return trainer_config
+
+    @staticmethod
+    def _initialize_trainable_modules(trainable_modules: torch.nn.ModuleList) -> None:
+        """Initialize trainable modules with truncated normal initialization.
+
+        Args:
+            trainable_modules: ModuleList containing modules to initialize
+        """
+        for m in trainable_modules.modules():
+            if isinstance(m, torch.nn.Linear):
+                trunc_normal_(m.weight, std=0.01, a=-0.03, b=0.03)
+                if m.bias is not None:
+                    torch.nn.init.constant_(m.bias, 0)
+            elif isinstance(m, torch.nn.LayerNorm):
+                torch.nn.init.constant_(m.bias, 0)
+                torch.nn.init.constant_(m.weight, 1.0)
diff --git a/src/anomalib/models/image/dinomaly/torch_model.py b/src/anomalib/models/image/dinomaly/torch_model.py
new file mode 100644
index 0000000000..99568c11d4
--- /dev/null
+++ b/src/anomalib/models/image/dinomaly/torch_model.py
@@ -0,0 +1,564 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""PyTorch model for the Dinomaly model implementation.
+
+Based on the PyTorch implementation of "Dinomaly" by guojiajeremy.
+Reference: https://github.com/guojiajeremy/Dinomaly
+License: MIT
+
+See Also:
+    :class:`anomalib.models.image.dinomaly.lightning_model.Dinomaly`:
+        Dinomaly Lightning model.
+"""
+
+import math
+from functools import partial
+
+import torch
+import torch.nn.functional as F  # noqa: N812
+from timm.layers.drop import DropPath
+from torch import nn
+
+from anomalib.data import InferenceBatch
+from anomalib.models.image.dinomaly.components import DinomalyMLP, LinearAttention
+from anomalib.models.image.dinomaly.components import load as load_dinov2_model
+
+# Encoder architecture configurations for DINOv2 models.
+# The target layers are the middle layers of the encoder from which features are extracted.
+DINOV2_ARCHITECTURES = {
+    "small": {"embed_dim": 384, "num_heads": 6, "target_layers": [2, 3, 4, 5, 6, 7, 8, 9]},
+    "base": {"embed_dim": 768, "num_heads": 12, "target_layers": [2, 3, 4, 5, 6, 7, 8, 9]},
+    "large": {"embed_dim": 1024, "num_heads": 16, "target_layers": [4, 6, 8, 10, 12, 14, 16, 18]},
+}
+
+# Default fusion layer configurations.
+# Instead of comparing layer to layer between encoder and decoder, dinomaly uses
+# layer groups to fuse features from multiple layers.
+# By default, the first 4 layers and the last 4 layers are fused.
+# Note that these are indices into the list of extracted encoder and decoder layers,
+# not the layer indices of the backbone itself.
+DEFAULT_FUSE_LAYERS = [[0, 1, 2, 3], [4, 5, 6, 7]]
+
+# Default values for inference processing
+DEFAULT_RESIZE_SIZE = 256
+DEFAULT_GAUSSIAN_KERNEL_SIZE = 5
+DEFAULT_GAUSSIAN_SIGMA = 4
+DEFAULT_MAX_RATIO = 0.01
+
+# Transformer architecture constants
+TRANSFORMER_CONFIG: dict[str, float | bool] = {
+    "mlp_ratio": 4.0,
+    "layer_norm_eps": 1e-8,
+    "qkv_bias": True,
+    "attn_drop": 0.0,
+}
+
+
+class DinomalyModel(nn.Module):
+    """DinomalyModel: Vision Transformer-based anomaly detection model from Dinomaly.
+
+    This is a Vision Transformer-based anomaly detection model that uses an encoder-bottleneck-decoder
+    architecture for feature reconstruction.
+
+    The architecture comprises three main components:
+
+    Encoder: A pre-trained Vision Transformer (ViT), by default a DINOv2-reg ViT-Base/14 model, which
+    extracts universal and discriminative features from input images.
+
+    Bottleneck: A simple MLP that collects feature representations from the encoder's middle-level layers.
+ + Decoder: Composed of Transformer layers (by default 8 layers), it learns to reconstruct the middle-level features. + + Args: + encoder_name (str): Name of the Vision Transformer encoder to use. + Supports DINOv2 variants like "dinov2reg_vit_base_14". + Defaults to "dinov2reg_vit_base_14". + bottleneck_dropout (float): Dropout rate for the bottleneck MLP layer. + Defaults to 0.2. + decoder_depth (int): Number of Vision Transformer decoder layers. + Defaults to 8. + target_layers (list[int] | None): List of encoder layer indices to extract features from. + If None, uses [2, 3, 4, 5, 6, 7, 8, 9] for base models. + For large models, uses [4, 6, 8, 10, 12, 14, 16, 18]. + fuse_layer_encoder (list[list[int]] | None): Layer groupings for encoder feature fusion. + If None, uses [[0, 1, 2, 3], [4, 5, 6, 7]]. + fuse_layer_decoder (list[list[int]] | None): Layer groupings for decoder feature fusion. + If None, uses [[0, 1, 2, 3], [4, 5, 6, 7]]. + mask_neighbor_size (int): Size of neighborhood masking for attention. + Set to 0 to disable masking. Defaults to 0. + remove_class_token (bool): Whether to remove class token from features + before processing. Defaults to False. + + Example: + >>> model = DinomalyModel( + ... encoder_name="dinov2reg_vit_base_14", + ... decoder_depth=8, + ... bottleneck_dropout=0.2 + ... ) + >>> features = model(batch) + """ + + def __init__( + self, + encoder_name: str = "dinov2reg_vit_base_14", + bottleneck_dropout: float = 0.2, + decoder_depth: int = 8, + target_layers: list[int] | None = None, + fuse_layer_encoder: list[list[int]] | None = None, + fuse_layer_decoder: list[list[int]] | None = None, + mask_neighbor_size: int = 0, + remove_class_token: bool = False, + ) -> None: + super().__init__() + + if target_layers is None: + # 8 middle layers of the encoder are used for feature extraction. + target_layers = [2, 3, 4, 5, 6, 7, 8, 9] + + # Instead of comparing layer to layer between encoder and decoder, dinomaly uses + # layer groups to fuse features from multiple layers. 
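+        # With the default grouping [[0, 1, 2, 3], [4, 5, 6, 7]], the first four and the last four
+        # of the eight extracted feature maps are averaged into a low-semantic and a high-semantic
+        # group, respectively, before encoder and decoder features are compared.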
+ if fuse_layer_encoder is None: + fuse_layer_encoder = DEFAULT_FUSE_LAYERS + if fuse_layer_decoder is None: + fuse_layer_decoder = DEFAULT_FUSE_LAYERS + + encoder = load_dinov2_model(encoder_name) + + # Extract architecture configuration based on model name + arch_config = self._get_architecture_config(encoder_name, target_layers) + embed_dim = arch_config["embed_dim"] + num_heads = arch_config["num_heads"] + target_layers = arch_config["target_layers"] + + # Add validation + if decoder_depth <= 1: + msg = f"decoder_depth must be greater than 1, got {decoder_depth}" + raise ValueError(msg) + + bottleneck = [] + bottle_neck_mlp = DinomalyMLP( + in_features=embed_dim, + hidden_features=embed_dim * 4, + out_features=embed_dim, + act_layer=nn.GELU, + drop=bottleneck_dropout, + bias=False, + apply_input_dropout=True, # Apply dropout to input + ) + bottleneck.append(bottle_neck_mlp) + bottleneck = nn.ModuleList(bottleneck) + + decoder = [] + for _ in range(decoder_depth): + # Extract and validate config values for type safety + mlp_ratio_val = TRANSFORMER_CONFIG["mlp_ratio"] + assert isinstance(mlp_ratio_val, float) + qkv_bias_val = TRANSFORMER_CONFIG["qkv_bias"] + assert isinstance(qkv_bias_val, bool) + layer_norm_eps_val = TRANSFORMER_CONFIG["layer_norm_eps"] + assert isinstance(layer_norm_eps_val, float) + attn_drop_val = TRANSFORMER_CONFIG["attn_drop"] + assert isinstance(attn_drop_val, float) + + decoder_block = DecoderViTBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio_val, + qkv_bias=qkv_bias_val, + norm_layer=partial(nn.LayerNorm, eps=layer_norm_eps_val), # type: ignore[arg-type] + attn_drop=attn_drop_val, + attn=LinearAttention, + ) + decoder.append(decoder_block) + decoder = nn.ModuleList(decoder) + + self.encoder = encoder + self.bottleneck = bottleneck + self.decoder = decoder + self.target_layers = target_layers + self.fuse_layer_encoder = fuse_layer_encoder + self.fuse_layer_decoder = fuse_layer_decoder + self.remove_class_token = remove_class_token + + if not hasattr(self.encoder, "num_register_tokens"): + self.encoder.num_register_tokens = 0 + self.mask_neighbor_size = mask_neighbor_size + + def get_encoder_decoder_outputs(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Extract and process features through encoder and decoder. + + This method processes input images through the DINOv2 encoder to extract + features from target layers, fuses them through a bottleneck MLP, and + reconstructs them using the decoder. Features are reshaped for spatial + anomaly map computation. + + Args: + x (torch.Tensor): Input images with shape (B, C, H, W). 
+
+        Returns:
+            tuple[list[torch.Tensor], list[torch.Tensor]]: Tuple containing:
+                - en: List of fused encoder features reshaped to spatial dimensions
+                - de: List of fused decoder features reshaped to spatial dimensions
+        """
+        x = self.encoder.prepare_tokens(x)
+
+        encoder_features = []
+        decoder_features = []
+
+        for i, block in enumerate(self.encoder.blocks):
+            if i <= self.target_layers[-1]:
+                with torch.no_grad():
+                    x = block(x)
+            else:
+                continue
+            if i in self.target_layers:
+                encoder_features.append(x)
+        side = int(math.sqrt(encoder_features[0].shape[1] - 1 - self.encoder.num_register_tokens))
+
+        if self.remove_class_token:
+            encoder_features = [e[:, 1 + self.encoder.num_register_tokens :, :] for e in encoder_features]
+
+        x = self._fuse_feature(encoder_features)
+        for _i, block in enumerate(self.bottleneck):
+            x = block(x)
+
+        attn_mask = self.generate_mask(side, x.device) if self.mask_neighbor_size > 0 else None
+
+        for _i, block in enumerate(self.decoder):
+            x = block(x, attn_mask=attn_mask)
+            decoder_features.append(x)
+        decoder_features = decoder_features[::-1]
+
+        en = [self._fuse_feature([encoder_features[idx] for idx in idxs]) for idxs in self.fuse_layer_encoder]
+        de = [self._fuse_feature([decoder_features[idx] for idx in idxs]) for idxs in self.fuse_layer_decoder]
+
+        # Process features for spatial output
+        en = self._process_features_for_spatial_output(en, side)
+        de = self._process_features_for_spatial_output(de, side)
+        return en, de
+
+    def forward(self, batch: torch.Tensor) -> dict[str, list[torch.Tensor]] | InferenceBatch:
+        """Forward pass of the Dinomaly model.
+
+        During training, the model extracts features from the encoder and decoder
+        and returns them for loss computation. During inference, it computes
+        anomaly maps by comparing encoder and decoder features using cosine similarity,
+        applies Gaussian smoothing, and returns anomaly scores and maps.
+
+        Args:
+            batch (torch.Tensor): Input batch of images with shape (B, C, H, W).
+
+        Returns:
+            dict[str, list[torch.Tensor]] | InferenceBatch:
+                - During training: Dictionary containing encoder and decoder features
+                  for loss computation.
+                - During inference: InferenceBatch with pred_score (anomaly scores)
+                  and anomaly_map (pixel-level anomaly maps).
+
+        """
+        en, de = self.get_encoder_decoder_outputs(batch)
+        image_size = batch.shape[2]
+
+        if self.training:
+            return {"encoder_features": en, "decoder_features": de}
+
+        # During inference, compute anomaly maps and predictions from the encoder and decoder features.
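+        # A clone of the full-resolution anomaly map is kept as the returned map, while the working
+        # copy is resized to DEFAULT_RESIZE_SIZE, smoothed with a Gaussian kernel, and reduced to an
+        # image-level score as the mean of the top DEFAULT_MAX_RATIO fraction of its pixel values.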
+ anomaly_map, _ = self.calculate_anomaly_maps(en, de, out_size=image_size) + anomaly_map_resized = anomaly_map.clone() + + # Resize anomaly map for processing + if DEFAULT_RESIZE_SIZE is not None: + anomaly_map = F.interpolate(anomaly_map, size=DEFAULT_RESIZE_SIZE, mode="bilinear", align_corners=False) + + # Apply Gaussian smoothing + gaussian_kernel = get_gaussian_kernel( + kernel_size=DEFAULT_GAUSSIAN_KERNEL_SIZE, + sigma=DEFAULT_GAUSSIAN_SIGMA, + device=anomaly_map.device, + ) + anomaly_map = gaussian_kernel(anomaly_map) + + # Calculate anomaly score + if DEFAULT_MAX_RATIO == 0: + sp_score = torch.max(anomaly_map.flatten(1), dim=1)[0] + else: + anomaly_map_flat = anomaly_map.flatten(1) + sp_score = torch.sort(anomaly_map_flat, dim=1, descending=True)[0][ + :, + : int(anomaly_map_flat.shape[1] * DEFAULT_MAX_RATIO), + ] + sp_score = sp_score.mean(dim=1) + pred_score = sp_score + + return InferenceBatch(pred_score=pred_score, anomaly_map=anomaly_map_resized) + + @staticmethod + def calculate_anomaly_maps( + source_feature_maps: list[torch.Tensor], + target_feature_maps: list[torch.Tensor], + out_size: int | tuple[int, int] = 392, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Calculate anomaly maps by comparing encoder and decoder features. + + Computes pixel-level anomaly maps by calculating cosine similarity between + corresponding encoder (source) and decoder (target) feature maps. Lower + cosine similarity indicates a higher anomaly likelihood. + + Args: + source_feature_maps (list[torch.Tensor]): List of encoder feature maps + from different layer groups. + target_feature_maps (list[torch.Tensor]): List of decoder feature maps + from different layer groups. + out_size (int | tuple[int, int]): Output size for anomaly maps. + Defaults to 392. + + Returns: + tuple[torch.Tensor, list[torch.Tensor]]: Tuple containing: + - anomaly_map: Combined anomaly map averaged across all feature scales + - anomaly_map_list: List of individual anomaly maps for each feature scale + """ + if not isinstance(out_size, tuple): + out_size = (out_size, out_size) + + anomaly_map_list = [] + for i in range(len(target_feature_maps)): + fs = source_feature_maps[i] + ft = target_feature_maps[i] + a_map = 1 - F.cosine_similarity(fs, ft) + a_map = torch.unsqueeze(a_map, dim=1) + a_map = F.interpolate(a_map, size=out_size, mode="bilinear", align_corners=True) + anomaly_map_list.append(a_map) + anomaly_map = torch.cat(anomaly_map_list, dim=1).mean(dim=1, keepdim=True) + return anomaly_map, anomaly_map_list + + @staticmethod + def _fuse_feature(feat_list: list[torch.Tensor]) -> torch.Tensor: + """Fuse multiple feature tensors by averaging. + + Takes a list of feature tensors and computes their element-wise average + to create a fused representation. + + Args: + feat_list (list[torch.Tensor]): List of feature tensors to fuse. + + Returns: + torch.Tensor: Averaged feature tensor. + + """ + return torch.stack(feat_list, dim=1).mean(dim=1) + + def generate_mask(self, feature_size: int, device: str = "cuda") -> torch.Tensor: + """Generate attention mask for neighborhood masking in decoder. + + Creates a square attention mask that restricts attention to local neighborhoods + for each spatial position. This helps the decoder focus on local patterns + during reconstruction. + + Args: + feature_size (int): Size of the feature map (assumes square features). + device (str): Device to create the mask on. Defaults to 'cuda'. 
+ + Returns: + torch.Tensor: Attention mask with shape (H, W, H, W) where masked + positions are filled with 0 and unmasked positions with 1. + + Note: + The mask size is controlled by self.mask_neighbor_size. Set to 0 + to disable masking. + """ + # Height and width are same, but still used separately for clarity. + height, width = feature_size, feature_size + hm, wm = self.mask_neighbor_size, self.mask_neighbor_size + + # Vectorized approach - much faster than nested loops + idx_h1 = torch.arange(height, device=device).view(-1, 1, 1, 1) + idx_w1 = torch.arange(width, device=device).view(1, -1, 1, 1) + idx_h2 = torch.arange(height, device=device).view(1, 1, -1, 1) + idx_w2 = torch.arange(width, device=device).view(1, 1, 1, -1) + + # Compute distances and apply neighborhood masking + mask_h = (idx_h1 - idx_h2).abs() > hm // 2 + mask_w = (idx_w1 - idx_w2).abs() > wm // 2 + mask = (mask_h | mask_w).float() # Direct to float, avoiding double negation + mask = mask.view(height * width, height * width) + + if self.remove_class_token: + return mask + mask_all = torch.ones( + height * width + 1 + self.encoder.num_register_tokens, + height * width + 1 + self.encoder.num_register_tokens, + device=device, + ) + mask_all[1 + self.encoder.num_register_tokens :, 1 + self.encoder.num_register_tokens :] = mask + return mask_all + + @staticmethod + def _get_architecture_config(encoder_name: str, target_layers: list[int] | None) -> dict: + """Get architecture configuration based on model name. + + Args: + encoder_name: Name of the encoder model + target_layers: Override target layers if provided + + Returns: + Dictionary containing embed_dim, num_heads, and target_layers + """ + for arch_name, config in DINOV2_ARCHITECTURES.items(): + if arch_name in encoder_name: + result = config.copy() + # Override target_layers if explicitly provided + if target_layers is not None: + result["target_layers"] = target_layers + return result + + msg = f"Architecture not supported. Encoder name must contain one of {list(DINOV2_ARCHITECTURES.keys())}" + raise ValueError(msg) + + def _process_features_for_spatial_output( + self, + features: list[torch.Tensor], + side: int, + ) -> list[torch.Tensor]: + """Process features for spatial output by removing tokens and reshaping. + + Args: + features: List of feature tensors + side: Side length for spatial reshaping + + Returns: + List of processed feature tensors with spatial dimensions + """ + # Remove class token and register tokens if not already removed + if not self.remove_class_token: + features = [f[:, 1 + self.encoder.num_register_tokens :, :] for f in features] + + # Reshape to spatial dimensions + batch_size = features[0].shape[0] + return [f.permute(0, 2, 1).reshape([batch_size, -1, side, side]).contiguous() for f in features] + + +def get_gaussian_kernel( + kernel_size: int = 3, + sigma: int = 2, + channels: int = 1, + device: torch.device | None = None, +) -> torch.nn.Conv2d: + """Create a Gaussian smoothing Conv2d layer. + + Args: + kernel_size (int): Size of the Gaussian kernel. Defaults to 3. + sigma (int): Standard deviation of the Gaussian distribution. Defaults to 2. + channels (int): Number of channels for the filter. Defaults to 1. + device (torch.device | None): Device to place the kernel on. Defaults to None. + + Returns: + torch.nn.Conv2d: Depthwise Conv2d layer with Gaussian weights for smoothing. 
+ """ + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_coord = torch.arange(kernel_size) + x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() + + mean = (kernel_size - 1) / 2.0 + variance = sigma**2.0 + + # Calculate the 2-dimensional gaussian kernel, which is + # the product of two gaussian distributions for two different + # variables (in this case called x and y) + gaussian_kernel = (1.0 / (2.0 * math.pi * variance)) * torch.exp( + -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance), + ) + + # Make sure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + + # Reshape to 2d depthwise convolutional weight + gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) + gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1) + + gaussian_filter = torch.nn.Conv2d( + in_channels=channels, + out_channels=channels, + kernel_size=kernel_size, + groups=channels, + bias=False, + padding=kernel_size // 2, + ) + + gaussian_filter.weight.data = gaussian_kernel + gaussian_filter.weight.requires_grad = False + + if device is not None: + gaussian_filter = gaussian_filter.to(device) + + return gaussian_filter + + +class DecoderViTBlock(nn.Module): + """Vision Transformer decoder block with attention and MLP layers.""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float | None = None, + qkv_bias: bool | None = None, + qk_scale: float | None = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: type[nn.Module] = nn.GELU, + norm_layer: type[nn.Module] = nn.LayerNorm, + attn: type[nn.Module] = LinearAttention, + ) -> None: + super().__init__() + + # Use default values from TRANSFORMER_CONFIG if not provided + mlp_ratio_config = TRANSFORMER_CONFIG["mlp_ratio"] + assert isinstance(mlp_ratio_config, float) + mlp_ratio = mlp_ratio if mlp_ratio is not None else mlp_ratio_config + + qkv_bias_config = TRANSFORMER_CONFIG["qkv_bias"] + assert isinstance(qkv_bias_config, bool) + qkv_bias = qkv_bias if qkv_bias is not None else qkv_bias_config + + attn_drop_config = TRANSFORMER_CONFIG["attn_drop"] + assert isinstance(attn_drop_config, float) + attn_drop = attn_drop if attn_drop is not None else attn_drop_config + + self.norm1 = norm_layer(dim) + self.attn = attn( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = DinomalyMLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + out_features=dim, + act_layer=act_layer, + drop=drop, + apply_input_dropout=False, + bias=False, + ) + + def forward( + self, + x: torch.Tensor, + return_attention: bool = False, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """Forward pass through decoder block.""" + if attn_mask is not None: + y, attn = self.attn(self.norm1(x), attn_mask=attn_mask) + else: + y, attn = self.attn(self.norm1(x)) + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + if return_attention: + return x, attn + return x
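+
+
+# Minimal usage sketch for ``DinomalyModel`` (illustrative only). It assumes the pre-trained weights
+# for "dinov2reg_vit_base_14" can be downloaded, and uses a 392x392 input purely because that matches
+# the default center-crop of the Lightning pre-processor.
+#
+# >>> model = DinomalyModel(encoder_name="dinov2reg_vit_base_14", decoder_depth=8)
+# >>> images = torch.randn(2, 3, 392, 392)
+# >>> _ = model.train()
+# >>> features = model(images)  # {"encoder_features": [...], "decoder_features": [...]}
+# >>> _ = model.eval()
+# >>> with torch.no_grad():
+# ...     predictions = model(images)  # InferenceBatch(pred_score=..., anomaly_map=...)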