Skip to content

πŸ”„ refactor(model): memory usage optimisation #2813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
8b4fca4
inital commit: Don't re-init buffer, inplace fill
alfieroddan Jul 2, 2025
0c603c3
implement patchcore memory savings
alfieroddan Jul 2, 2025
73fd2ee
deprecation warning for internal subsample_embedding of patchcore
alfieroddan Jul 3, 2025
f2dca31
PaDiM gpu memory usage reduction
alfieroddan Jul 3, 2025
718b6ea
Merge branch 'open-edge-platform:main' into refactor/model/optimise-m…
alfieroddan Jul 3, 2025
0275ce9
refactor memory bank tensor assignment for readability. dfkdde refactor
alfieroddan Jul 3, 2025
a0bdbf2
unify memory bank models trainer arguments. refactor dfm to fit new m…
alfieroddan Jul 3, 2025
a386732
bugfix: padim device count set to zero
alfieroddan Jul 3, 2025
9e5adf8
Merge branch 'main' into refactor/model/optimise-memorybank-memory_gp…
samet-akcay Jul 4, 2025
572dc5e
ensure buffer is not replaced but instead resized and filled
alfieroddan Jul 4, 2025
183f630
Merge branch 'refactor/model/optimise-memorybank-memory_gpu_usage' of…
alfieroddan Jul 4, 2025
c714b85
give memory bank type
alfieroddan Jul 7, 2025
16fabb3
revert memory bank mixin back
alfieroddan Jul 7, 2025
b292713
to memory bank type and device
alfieroddan Jul 7, 2025
43a4898
Merge branch 'main' into refactor/model/optimise-memorybank-memory_gp…
alfieroddan Jul 10, 2025
585c81a
update return type, add comment about deprecation, duplicated copyright
alfieroddan Jul 11, 2025
9267dd1
new deprecation wrapper, now handles arg deprecation
alfieroddan Jul 15, 2025
d56a7bc
add flexible tests for new deprecation warning
alfieroddan Jul 15, 2025
9072bbd
update warnings to deprecations. Resolve comments re-fit vs fit_guass…
alfieroddan Jul 15, 2025
f123996
add type hint for dfkde
alfieroddan Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/anomalib/models/image/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
visualizer=visualizer,
)

self.model = DfkdeModel(
self.model: DfkdeModel = DfkdeModel(
layers=layers,
backbone=backbone,
pre_trained=pre_trained,
Expand All @@ -113,8 +113,6 @@ def __init__(
max_training_points=max_training_points,
)

self.embeddings: list[torch.Tensor] = []

@staticmethod
def configure_optimizers() -> None: # pylint: disable=arguments-differ
"""DFKDE doesn't require optimization, therefore returns no optimizers."""
Expand All @@ -133,18 +131,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
"""
del args, kwargs # These variables are not used.

embedding = self.model(batch.image)
self.embeddings.append(embedding)
_ = self.model(batch.image)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Fit KDE model to collected embeddings from the training set."""
embeddings = torch.vstack(self.embeddings)

logger.info("Fitting a KDE model to the embedding collected from the training set.")
self.model.classifier.fit(embeddings)
self.model.fit()

def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Perform validation by computing anomaly scores.
Expand All @@ -167,9 +162,13 @@ def trainer_arguments(self) -> dict[str, Any]:
"""Get DFKDE-specific trainer arguments.

Returns:
dict[str, Any]: Dictionary of trainer arguments.
dict[str, Any]: Trainer arguments
- ``gradient_clip_val``: ``0`` (no gradient clipping needed)
- ``max_epochs``: ``1`` (single pass through training data)
- ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
- ``devices``: ``1`` (only single gpu supported)
"""
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}

@property
def learning_type(self) -> LearningType:
Expand Down
27 changes: 27 additions & 0 deletions src/anomalib/models/image/dfkde/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
feature_scaling_method=feature_scaling_method,
max_training_points=max_training_points,
)
self.memory_bank = torch.empty(0)

def get_features(self, batch: torch.Tensor) -> torch.Tensor:
"""Extract features from the pre-trained backbone network.
Expand Down Expand Up @@ -141,8 +142,34 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
# 1. apply feature extraction
features = self.get_features(batch)
if self.training:
if self.memory_bank.size(0) == 0:
self.memory_bank = features
else:
new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank)
self.memory_bank = new_bank
return features

# 2. apply density estimation
scores = self.classifier(features)
return InferenceBatch(pred_score=scores)

def fit(self) -> None:
"""Fits the classifier using the current contents of the memory bank.

This method is typically called after the memory bank has been populated
during training.

After fitting, the memory bank is cleared to reduce GPU memory usage.

Raises:
ValueError: If the memory bank is empty.
"""
if self.memory_bank.size(0) == 0:
msg = "Memory bank is empty. Cannot perform coreset selection."
raise ValueError(msg)

# fit gaussian
self.classifier.fit(self.memory_bank)

# clear memory bank, redcues gpu size
self.memory_bank = torch.empty(0).to(self.memory_bank)
20 changes: 8 additions & 12 deletions src/anomalib/models/image/dfm/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def __init__(
n_comps=pca_level,
score_type=score_type,
)
self.embeddings: list[torch.Tensor] = []
self.score_type = score_type

@staticmethod
Expand All @@ -137,8 +136,7 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
"""
del args, kwargs # These variables are not used.

embedding = self.model.get_features(batch.image).squeeze()
self.embeddings.append(embedding)
_ = self.model(batch.image)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)
Expand All @@ -149,11 +147,8 @@ def fit(self) -> None:
The method aggregates embeddings collected during training and fits
both the PCA transformation and Gaussian model used for scoring.
"""
logger.info("Aggregating the embedding extracted from the training set.")
embeddings = torch.vstack(self.embeddings)

logger.info("Fitting a PCA and a Gaussian model to dataset.")
self.model.fit(embeddings)
self.model.fit()

def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Compute predictions for the input batch during validation.
Expand All @@ -176,12 +171,13 @@ def trainer_arguments(self) -> dict[str, Any]:
"""Get DFM-specific trainer arguments.

Returns:
dict[str, Any]: Dictionary of trainer arguments:
- ``gradient_clip_val`` (int): Disable gradient clipping
- ``max_epochs`` (int): Train for one epoch only
- ``num_sanity_val_steps`` (int): Skip validation sanity checks
dict[str, Any]: Trainer arguments
- ``gradient_clip_val``: ``0`` (no gradient clipping needed)
- ``max_epochs``: ``1`` (single pass through training data)
- ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
- ``devices``: ``1`` (only single gpu supported)
"""
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}

@property
def learning_type(self) -> LearningType:
Expand Down
52 changes: 36 additions & 16 deletions src/anomalib/models/image/dfm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from anomalib.data import InferenceBatch
from anomalib.models.components import PCA, DynamicBufferMixin, TimmFeatureExtractor
from anomalib.utils.deprecation import deprecate


class SingleClassGaussian(DynamicBufferMixin):
Expand Down Expand Up @@ -153,18 +154,29 @@ def __init__(
layers=[layer],
).eval()

def fit(self, dataset: torch.Tensor) -> None:
self.memory_bank = torch.empty(0)

@deprecate(args={"dataset": None}, since="2.1.0")
def fit(self, dataset: torch.Tensor = None) -> None:
"""Fit PCA and Gaussian model to dataset.

Args:
dataset (torch.Tensor): Input dataset with shape
``(n_samples, n_features)``.
dataset (torch.Tensor, optional): **[DEPRECATED]**
Input dataset with shape ``(n_samples, n_features)``.
This argument is deprecated and will be ignored. Instead,
the method now uses `self.memory_bank` as the input dataset.
"""
self.pca_model.fit(dataset)
if dataset is not None:
del dataset

self.pca_model.fit(self.memory_bank)
if self.score_type == "nll":
features_reduced = self.pca_model.transform(dataset)
features_reduced = self.pca_model.transform(self.memory_bank)
self.gaussian_model.fit(features_reduced.T)

# clear memory bank, reduces GPU size
self.memory_bank = torch.empty(0).to(self.memory_bank)

def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
"""Compute anomaly scores.

Expand Down Expand Up @@ -194,25 +206,24 @@ def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:

return (score, None) if self.score_type == "nll" else (score, score_map)

def get_features(self, batch: torch.Tensor) -> torch.Tensor:
def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Size]:
"""Extract features from the pretrained network.

Args:
batch (torch.Tensor): Input images with shape
``(batch_size, channels, height, width)``.

Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Size]]: Features during
training, or tuple of (features, feature_shapes) during inference.
tuple of (features, feature_shapes).
"""
self.feature_extractor.eval()
features = self.feature_extractor(batch)[self.layer]
batch_size = len(features)
if self.pooling_kernel_size > 1:
features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
feature_shapes = features.shape
features = features.view(batch_size, -1).detach()
return features if self.training else (features, feature_shapes)
with torch.no_grad():
features = self.feature_extractor(batch)[self.layer]
batch_size = len(features)
if self.pooling_kernel_size > 1:
features = F.avg_pool2d(input=features, kernel_size=self.pooling_kernel_size)
feature_shapes = features.shape
features = features.view(batch_size, -1)
return features, feature_shapes

def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
"""Compute anomaly predictions from input images.
Expand All @@ -227,6 +238,15 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
``InferenceBatch`` with prediction scores and anomaly maps.
"""
feature_vector, feature_shapes = self.get_features(batch)

if self.training:
if self.memory_bank.size(0) == 0:
self.memory_bank = feature_vector
else:
new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank)
self.memory_bank = new_bank
return feature_vector

pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes)
if anomaly_map is not None:
anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False)
Expand Down
25 changes: 9 additions & 16 deletions src/anomalib/models/image/padim/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ def __init__(
n_features=n_features,
)

self.stats: list[torch.Tensor] = []
self.embeddings: list[torch.Tensor] = []

@staticmethod
def configure_optimizers() -> None:
"""PADIM doesn't require optimization, therefore returns no optimizers."""
Expand All @@ -154,19 +151,15 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
"""
del args, kwargs # These variables are not used.

embedding = self.model(batch.image)
self.embeddings.append(embedding)
_ = self.model(batch.image)

# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Fit a Gaussian to the embedding collected from the training set."""
logger.info("Aggregating the embedding extracted from the training set.")
embeddings = torch.vstack(self.embeddings)

logger.info("Fitting a Gaussian to the embedding collected from the training set.")
self.stats = self.model.gaussian.fit(embeddings)
self.model.fit()

def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Perform a validation step of PADIM.
Expand All @@ -190,16 +183,16 @@ def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:

@property
def trainer_arguments(self) -> dict[str, int | float]:
"""Return PADIM trainer arguments.

Since the model does not require training, we limit the max_epochs to 1.
Since we need to run training epoch before validation, we also set the
sanity steps to 0.
"""Get default trainer arguments for Padim.

Returns:
dict[str, int | float]: Dictionary of trainer arguments
dict[str, Any]: Trainer arguments
- ``max_epochs``: ``1`` (single pass through training data)
- ``val_check_interval``: ``1.0`` (check validation every 1 step)
- ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
- ``devices``: ``1`` (only single gpu supported)
"""
return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0}
return {"max_epochs": 1, "val_check_interval": 1.0, "num_sanity_val_steps": 0, "devices": 1}

@property
def learning_type(self) -> LearningType:
Expand Down
26 changes: 26 additions & 0 deletions src/anomalib/models/image/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
self.anomaly_map_generator = AnomalyMapGenerator()

self.gaussian = MultiVariateGaussian()
self.memory_bank = torch.empty(0)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
"""Forward-pass image-batch (N, C, H, W) into model to extract features.
Expand Down Expand Up @@ -182,6 +183,11 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
embeddings = self.tiler.untile(embeddings)

if self.training:
if self.memory_bank.size(0) == 0:
self.memory_bank = embeddings
else:
new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank)
self.memory_bank = new_bank
return embeddings

anomaly_map = self.anomaly_map_generator(
Expand Down Expand Up @@ -217,3 +223,23 @@ def generate_embedding(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
# subsample embeddings
idx = self.idx.to(embeddings.device)
return torch.index_select(embeddings, 1, idx)

def fit(self) -> None:
"""Fits a Gaussian model to the current contents of the memory bank.

This method is typically called after the memory bank has been filled during training.

After fitting, the memory bank is cleared to free GPU memory before validation or testing.

Raises:
ValueError: If the memory bank is empty.
"""
if self.memory_bank.size(0) == 0:
msg = "Memory bank is empty. Cannot perform coreset selection."
raise ValueError(msg)

# fit gaussian
self.gaussian.fit(self.memory_bank)

# clear memory bank, redcues gpu usage
self.memory_bank = torch.empty(0).to(self.memory_bank)
16 changes: 5 additions & 11 deletions src/anomalib/models/image/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def __init__(
num_neighbors=num_neighbors,
)
self.coreset_sampling_ratio = coreset_sampling_ratio
self.embeddings: list[torch.Tensor] = []

@classmethod
def configure_pre_processor(
Expand Down Expand Up @@ -235,24 +234,18 @@ def training_step(self, batch: Batch, *args, **kwargs) -> None:
``fit()``.
"""
del args, kwargs # These variables are not used.

embedding = self.model(batch.image)
self.embeddings.append(embedding)
_ = self.model(batch.image)
# Return a dummy loss tensor
return torch.tensor(0.0, requires_grad=True, device=self.device)

def fit(self) -> None:
"""Apply subsampling to the embedding collected from the training set.

This method:
1. Aggregates embeddings from all training batches
2. Applies coreset subsampling to reduce memory requirements
1. Applies coreset subsampling to reduce memory requirements
"""
logger.info("Aggregating the embedding extracted from the training set.")
embeddings = torch.vstack(self.embeddings)

logger.info("Applying core-set subsampling to get the embedding.")
self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)
self.model.subsample_embedding(self.coreset_sampling_ratio)

def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
"""Generate predictions for a batch of images.
Expand Down Expand Up @@ -286,8 +279,9 @@ def trainer_arguments(self) -> dict[str, Any]:
- ``gradient_clip_val``: ``0`` (no gradient clipping needed)
- ``max_epochs``: ``1`` (single pass through training data)
- ``num_sanity_val_steps``: ``0`` (skip validation sanity checks)
- ``devices``: ``1`` (only single gpu supported)
"""
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0}
return {"gradient_clip_val": 0, "max_epochs": 1, "num_sanity_val_steps": 0, "devices": 1}

@property
def learning_type(self) -> LearningType:
Expand Down
Loading