Skip to content

Commit d1f824a

Browse files
Fix normalization (#2130)
* fix normalization * precommit config... * reset normalization metrics on validation start * fix model loading and saving normalitzation metrics * Update src/anomalib/callbacks/normalization/min_max_normalization.py * Update src/anomalib/callbacks/normalization/min_max_normalization.py --------- Co-authored-by: Samet Akcay <samet.akcay@intel.com>
1 parent d094d4b commit d1f824a

File tree

5 files changed

+64
-30
lines changed

5 files changed

+64
-30
lines changed

src/anomalib/callbacks/normalization/min_max_normalization.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from lightning.pytorch import Trainer
1010
from lightning.pytorch.utilities.types import STEP_OUTPUT
11+
from torchmetrics import MetricCollection
1112

1213
from anomalib.metrics import MinMax
1314
from anomalib.models.components import AnomalyModule
@@ -27,13 +28,26 @@ def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None =
2728
del trainer, stage # These variables are not used.
2829

2930
if not hasattr(pl_module, "normalization_metrics"):
30-
pl_module.normalization_metrics = MinMax().cpu()
31-
elif not isinstance(pl_module.normalization_metrics, MinMax):
32-
msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
33-
raise AttributeError(
34-
msg,
31+
pl_module.normalization_metrics = MetricCollection(
32+
{
33+
"anomaly_maps": MinMax().cpu(),
34+
"box_scores": MinMax().cpu(),
35+
"pred_scores": MinMax().cpu(),
36+
},
3537
)
3638

39+
elif not isinstance(pl_module.normalization_metrics, MetricCollection):
40+
msg = (
41+
f"Expected normalization_metrics to be of type MetricCollection"
42+
f"got {type(pl_module.normalization_metrics)}"
43+
)
44+
raise TypeError(msg)
45+
46+
for name, metric in pl_module.normalization_metrics.items():
47+
if not isinstance(metric, MinMax):
48+
msg = f"Expected normalization_metric {name} to be of type MinMax, got {type(metric)}"
49+
raise TypeError(msg)
50+
3751
def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
3852
"""Call when the test begins."""
3953
del trainer # `trainer` variable is not used.
@@ -42,6 +56,13 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
4256
if metric is not None:
4357
metric.set_threshold(0.5)
4458

59+
def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
60+
"""Call when the validation epoch begins."""
61+
del trainer # `trainer` variable is not used.
62+
63+
if hasattr(pl_module, "normalization_metrics"):
64+
pl_module.normalization_metrics.reset()
65+
4566
def on_validation_batch_end(
4667
self,
4768
trainer: Trainer,
@@ -55,14 +76,11 @@ def on_validation_batch_end(
5576
del trainer, batch, batch_idx, dataloader_idx # These variables are not used.
5677

5778
if "anomaly_maps" in outputs:
58-
pl_module.normalization_metrics(outputs["anomaly_maps"])
59-
elif "box_scores" in outputs:
60-
pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
61-
elif "pred_scores" in outputs:
62-
pl_module.normalization_metrics(outputs["pred_scores"])
63-
else:
64-
msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
65-
raise ValueError(msg)
79+
pl_module.normalization_metrics["anomaly_maps"](outputs["anomaly_maps"])
80+
if "box_scores" in outputs:
81+
pl_module.normalization_metrics["box_scores"](torch.cat(outputs["box_scores"]))
82+
if "pred_scores" in outputs:
83+
pl_module.normalization_metrics["pred_scores"](outputs["pred_scores"])
6684

6785
def on_test_batch_end(
6886
self,
@@ -97,12 +115,14 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: A
97115
"""Normalize a batch of predictions."""
98116
image_threshold = pl_module.image_threshold.value.cpu()
99117
pixel_threshold = pl_module.pixel_threshold.value.cpu()
100-
stats = pl_module.normalization_metrics.cpu()
101118
if "pred_scores" in outputs:
119+
stats = pl_module.normalization_metrics["pred_scores"].cpu()
102120
outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
103121
if "anomaly_maps" in outputs:
122+
stats = pl_module.normalization_metrics["anomaly_maps"].cpu()
104123
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
105124
if "box_scores" in outputs:
125+
stats = pl_module.normalization_metrics["box_scores"].cpu()
106126
outputs["box_scores"] = [
107127
normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
108128
]

src/anomalib/deploy/inferencers/base_inferencer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,19 @@ def _normalize(
101101
visualized and predicted scores.
102102
"""
103103
# min max normalization
104-
if "min" in metadata and "max" in metadata:
105-
if anomaly_maps is not None:
104+
if "pred_scores.min" in metadata and "pred_scores.max" in metadata:
105+
if anomaly_maps is not None and "anomaly_maps.max" in metadata:
106106
anomaly_maps = normalize_min_max(
107107
anomaly_maps,
108108
metadata["pixel_threshold"],
109-
metadata["min"],
110-
metadata["max"],
109+
metadata["anomaly_maps.min"],
110+
metadata["anomaly_maps.max"],
111111
)
112112
pred_scores = normalize_min_max(
113113
pred_scores,
114114
metadata["image_threshold"],
115-
metadata["min"],
116-
metadata["max"],
115+
metadata["pred_scores.min"],
116+
metadata["pred_scores.max"],
117117
)
118118

119119
return anomaly_maps, float(pred_scores)

src/anomalib/models/components/base/anomaly_module.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from lightning.pytorch.trainer.states import TrainerFn
1616
from lightning.pytorch.utilities.types import STEP_OUTPUT
1717
from torch import nn
18+
from torchmetrics import MetricCollection
1819
from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform
1920

2021
from anomalib import LearningType
@@ -25,7 +26,6 @@
2526

2627
if TYPE_CHECKING:
2728
from lightning.pytorch.callbacks import Callback
28-
from torchmetrics import Metric
2929

3030

3131
logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ def __init__(self) -> None:
4949
self.image_threshold: BaseThreshold
5050
self.pixel_threshold: BaseThreshold
5151

52-
self.normalization_metrics: Metric
52+
self.normalization_metrics: MetricCollection
5353

5454
self.image_metrics: AnomalibMetricCollection
5555
self.pixel_metrics: AnomalibMetricCollection
@@ -155,8 +155,9 @@ def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars:
155155
f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}"
156156
)
157157
if hasattr(self, "normalization_metrics"):
158-
normalization_class = self.normalization_metrics.__class__
159-
destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}"
158+
for metric in self.normalization_metrics:
159+
metric_class = self.normalization_metrics[metric].__class__
160+
destination[f"{metric}_normalization_class"] = f"{metric_class.__module__}.{metric_class.__name__}"
160161

161162
return super()._save_to_state_dict(destination, prefix, keep_vars)
162163

@@ -166,8 +167,21 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True
166167
self.image_threshold = self._get_instance(state_dict, "image_threshold_class")
167168
if "pixel_threshold_class" in state_dict:
168169
self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class")
169-
if "normalization_class" in state_dict:
170-
self.normalization_metrics = self._get_instance(state_dict, "normalization_class")
170+
171+
if "anomaly_maps_normalization_class" in state_dict:
172+
self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class")
173+
if "box_scores_normalization_class" in state_dict:
174+
self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class")
175+
if "pred_scores_normalization_class" in state_dict:
176+
self.pred_scores_normalization_metrics = self._get_instance(state_dict, "pred_scores_normalization_class")
177+
178+
self.normalization_metrics = MetricCollection(
179+
{
180+
"anomaly_maps": self.anomaly_maps_normalization_metrics,
181+
"box_scores": self.box_scores_normalization_metrics,
182+
"pred_scores": self.pred_scores_normalization_metrics,
183+
},
184+
)
171185
# Used to load metrics if there is any related data in state_dict
172186
self._load_metrics(state_dict)
173187

src/anomalib/models/image/csflow/lightning_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
100100
"""
101101
del args, kwargs # These variables are not used.
102102

103-
anomaly_maps, anomaly_scores = self.model(batch["image"])
104-
batch["anomaly_maps"] = anomaly_maps
105-
batch["pred_scores"] = anomaly_scores
103+
output = self.model(batch["image"])
104+
batch["anomaly_maps"] = output["anomaly_map"]
105+
batch["pred_scores"] = output["pred_score"]
106106
return batch
107107

108108
@property

src/anomalib/models/image/csflow/torch_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
588588
z_dist, _ = self.graph(features) # Ignore Jacobians
589589
anomaly_scores = self._compute_anomaly_scores(z_dist)
590590
anomaly_maps = self.anomaly_map_generator(z_dist)
591-
output = anomaly_maps, anomaly_scores
591+
output = {"anomaly_map": anomaly_maps, "pred_score": anomaly_scores}
592592
return output
593593

594594
def _compute_anomaly_scores(self, z_dists: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)