8
8
import torch
9
9
from lightning .pytorch import Trainer
10
10
from lightning .pytorch .utilities .types import STEP_OUTPUT
11
+ from torchmetrics import MetricCollection
11
12
12
13
from anomalib .metrics import MinMax
13
14
from anomalib .models .components import AnomalyModule
@@ -27,13 +28,26 @@ def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None =
27
28
del trainer , stage # These variables are not used.
28
29
29
30
if not hasattr (pl_module , "normalization_metrics" ):
30
- pl_module .normalization_metrics = MinMax ().cpu ()
31
- elif not isinstance (pl_module .normalization_metrics , MinMax ):
32
- msg = f"Expected normalization_metrics to be of type MinMax, got { type (pl_module .normalization_metrics )} "
33
- raise AttributeError (
34
- msg ,
31
+ pl_module .normalization_metrics = MetricCollection (
32
+ {
33
+ "anomaly_maps" : MinMax ().cpu (),
34
+ "box_scores" : MinMax ().cpu (),
35
+ "pred_scores" : MinMax ().cpu (),
36
+ },
35
37
)
36
38
39
+ elif not isinstance (pl_module .normalization_metrics , MetricCollection ):
40
+ msg = (
41
+ f"Expected normalization_metrics to be of type MetricCollection"
42
+ f"got { type (pl_module .normalization_metrics )} "
43
+ )
44
+ raise TypeError (msg )
45
+
46
+ for name , metric in pl_module .normalization_metrics .items ():
47
+ if not isinstance (metric , MinMax ):
48
+ msg = f"Expected normalization_metric { name } to be of type MinMax, got { type (metric )} "
49
+ raise TypeError (msg )
50
+
37
51
def on_test_start (self , trainer : Trainer , pl_module : AnomalyModule ) -> None :
38
52
"""Call when the test begins."""
39
53
del trainer # `trainer` variable is not used.
@@ -42,6 +56,13 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
42
56
if metric is not None :
43
57
metric .set_threshold (0.5 )
44
58
59
+ def on_validation_epoch_start (self , trainer : Trainer , pl_module : AnomalyModule ) -> None :
60
+ """Call when the validation epoch begins."""
61
+ del trainer # `trainer` variable is not used.
62
+
63
+ if hasattr (pl_module , "normalization_metrics" ):
64
+ pl_module .normalization_metrics .reset ()
65
+
45
66
def on_validation_batch_end (
46
67
self ,
47
68
trainer : Trainer ,
@@ -55,14 +76,11 @@ def on_validation_batch_end(
55
76
del trainer , batch , batch_idx , dataloader_idx # These variables are not used.
56
77
57
78
if "anomaly_maps" in outputs :
58
- pl_module .normalization_metrics (outputs ["anomaly_maps" ])
59
- elif "box_scores" in outputs :
60
- pl_module .normalization_metrics (torch .cat (outputs ["box_scores" ]))
61
- elif "pred_scores" in outputs :
62
- pl_module .normalization_metrics (outputs ["pred_scores" ])
63
- else :
64
- msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
65
- raise ValueError (msg )
79
+ pl_module .normalization_metrics ["anomaly_maps" ](outputs ["anomaly_maps" ])
80
+ if "box_scores" in outputs :
81
+ pl_module .normalization_metrics ["box_scores" ](torch .cat (outputs ["box_scores" ]))
82
+ if "pred_scores" in outputs :
83
+ pl_module .normalization_metrics ["pred_scores" ](outputs ["pred_scores" ])
66
84
67
85
def on_test_batch_end (
68
86
self ,
@@ -97,12 +115,14 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: A
97
115
"""Normalize a batch of predictions."""
98
116
image_threshold = pl_module .image_threshold .value .cpu ()
99
117
pixel_threshold = pl_module .pixel_threshold .value .cpu ()
100
- stats = pl_module .normalization_metrics .cpu ()
101
118
if "pred_scores" in outputs :
119
+ stats = pl_module .normalization_metrics ["pred_scores" ].cpu ()
102
120
outputs ["pred_scores" ] = normalize (outputs ["pred_scores" ], image_threshold , stats .min , stats .max )
103
121
if "anomaly_maps" in outputs :
122
+ stats = pl_module .normalization_metrics ["anomaly_maps" ].cpu ()
104
123
outputs ["anomaly_maps" ] = normalize (outputs ["anomaly_maps" ], pixel_threshold , stats .min , stats .max )
105
124
if "box_scores" in outputs :
125
+ stats = pl_module .normalization_metrics ["box_scores" ].cpu ()
106
126
outputs ["box_scores" ] = [
107
127
normalize (scores , pixel_threshold , stats .min , stats .max ) for scores in outputs ["box_scores" ]
108
128
]
0 commit comments