
Commit ea99a70

feat(training): Add metric and wandb callbacks
1 parent 2ead010 commit ea99a70

6 files changed: +355, -128 lines

cellseg_models_pytorch/training/callbacks/metric_callbacks.py

Lines changed: 26 additions & 6 deletions
@@ -5,7 +5,13 @@
 from ..functional.train_metrics import accuracy, iou
 
 try:
-    from torchmetrics import Metric
+    from torchmetrics import (
+        MeanSquaredError,
+        Metric,
+        StructuralSimilarityIndexMeasure,
+        UniversalImageQualityIndex,
+    )
 except ModuleNotFoundError:
     raise ModuleNotFoundError(
         "`torchmetrics` package is required when using metric callbacks. "
@@ -26,6 +32,7 @@ def __init__(
         dist_sync_on_step: bool = False,
         progress_group: Any = None,
         dist_sync_func: Callable = None,
+        **kwargs
     ) -> None:
         """Create a custom torchmetrics accuracy callback.
 
@@ -57,9 +64,10 @@ def __init__(
 
     def update(
         self,
-        pred: torch.Tensor,
+        preds: torch.Tensor,
         target: torch.Tensor,
         activation: str = "softmax",
+        **kwargs
     ) -> None:
         """Update the batch accuracy list with one batch accuracy value.
 
@@ -72,7 +80,7 @@ def update(
         activation : str, default="softmax"
             The activation function. One of: "softmax", "sigmoid" or None.
         """
-        batch_acc = accuracy(pred, target, activation)
+        batch_acc = accuracy(preds, target, activation)
         self.batch_accuracies += batch_acc
         self.n_batches += 1
 
@@ -96,6 +104,8 @@ def __init__(
         dist_sync_on_step: bool = False,
         progress_group: Any = None,
         dist_sync_func: Callable = None,
+        num_classes: int = None,
+        **kwargs
     ) -> None:
         """Create a custom torchmetrics mIoU callback.
 
@@ -111,6 +121,9 @@ def __init__(
         dist_sync_func : Callable, optional
             Callback that performs the allgather operation on the metric state.
             When None, DDP will be used to perform the allgather.
+        num_classes : int, optional
+            If not None, multi-class mIoU will be returned.
+
         """
         super().__init__(
             compute_on_step=compute_on_step,
@@ -124,9 +137,10 @@ def __init__(
 
     def update(
        self,
-        pred: torch.Tensor,
+        preds: torch.Tensor,
         target: torch.Tensor,
         activation: str = "softmax",
+        **kwargs
     ) -> None:
         """Update the batch IoU list with one batch IoU matrix.
 
@@ -139,7 +153,7 @@ def update(
         activation : str, default="softmax"
             The activation function. One of: "softmax", "sigmoid" or None.
         """
-        batch_iou = iou(pred, target, activation)
+        batch_iou = iou(preds, target, activation)
         self.batch_ious += batch_iou.mean()
         self.n_batches += 1
 
@@ -153,4 +167,10 @@ def compute(self) -> torch.Tensor:
         return self.batch_ious / self.n_batches
 
 
-METRIC_LOOKUP = {"acc": Accuracy, "miou": MeanIoU}
+METRIC_LOOKUP = {
+    "acc": Accuracy,
+    "miou": MeanIoU,
+    "mse": MeanSquaredError,
+    "ssim": StructuralSimilarityIndexMeasure,
+    "iqi": UniversalImageQualityIndex,
+}
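
For orientation: METRIC_LOOKUP maps short config keys to metric classes, mixing the repo's own Accuracy/MeanIoU callbacks with stock torchmetrics classes. Below is a minimal sketch of how the lookup might be consumed; it is not part of this commit, `metric_names` is an illustrative variable, and it assumes `torchmetrics` is installed.

# Sketch only: instantiate metrics by config key via METRIC_LOOKUP.
from cellseg_models_pytorch.training.callbacks.metric_callbacks import METRIC_LOOKUP

metric_names = ["acc", "miou", "ssim"]  # hypothetical config values
metrics = {name: METRIC_LOOKUP[name]() for name in metric_names}

# Each object follows the torchmetrics API: update() per batch, compute() at
# epoch end, e.g. metrics["acc"].update(preds, target, activation="softmax").

The `**kwargs` added to the custom callbacks' signatures lets all lookup entries be driven through a shared call site like this without the custom classes choking on extra keyword arguments.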
Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
+from typing import Any, Dict, Optional
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+
+try:
+    import wandb
+except ImportError:
+    raise ImportError("wandb required. `pip install wandb`")
+
+from ..functional import iou
+
+__all__ = ["WandbImageCallback", "WandbClassMetricCallback"]
+
+
+class WandbImageCallback(pl.Callback):
+    def __init__(
+        self,
+        type_classes: Dict[int, str],
+        sem_classes: Optional[Dict[int, str]],
+        freq: int = 100,
+    ) -> None:
+        """Create a callback that logs prediction masks to wandb."""
+        super().__init__()
+        self.freq = freq
+        self.type_classes = type_classes
+        self.sem_classes = sem_classes
+
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        if batch_idx % self.freq == 0:
+            outputs = outputs["soft_masks"]
+
+            log_dict = {
+                "global_step": trainer.global_step,
+                "epoch": trainer.current_epoch,
+            }
+
+            img = batch["image"].detach().to("cpu").numpy()
+
+            if "type" in list(batch.keys()):
+                type_target = batch["type"].detach().to("cpu").numpy()
+                soft_types = outputs["type"].detach().to("cpu")
+                types = torch.argmax(F.softmax(soft_types, dim=1), dim=1).numpy()
+
+                log_dict["val/cell_types"] = [
+                    wandb.Image(
+                        im.transpose(1, 2, 0),
+                        masks={
+                            "predictions": {
+                                "mask_data": t,
+                                "class_labels": self.type_classes,
+                            },
+                            "ground_truth": {
+                                "mask_data": tt,
+                                "class_labels": self.type_classes,
+                            },
+                        },
+                    )
+                    for im, t, tt in zip(img, types, type_target)
+                ]
+
+            if "sem" in list(batch.keys()):
+                sem_target = batch["sem"].detach().to("cpu").numpy()
+                soft_sem = outputs["sem"].detach().to(device="cpu")
+                sem = torch.argmax(F.softmax(soft_sem, dim=1), dim=1).numpy()
+
+                log_dict["val/tissue_areas"] = [
+                    wandb.Image(
+                        im.transpose(1, 2, 0),
+                        masks={
+                            "predictions": {
+                                "mask_data": s,
+                                "class_labels": self.sem_classes,
+                            },
+                            "ground_truth": {
+                                "mask_data": st,
+                                "class_labels": self.sem_classes,
+                            },
+                        },
+                    )
+                    for im, s, st in zip(img, sem, sem_target)
+                ]
+
+            for m in list(batch.keys()):
+                if m not in ("sem", "type", "inst", "image"):
+                    aux = outputs[m].detach().to(device="cpu")
+                    log_dict[f"val/{m}"] = [
+                        wandb.Image(a[i, ...], caption=f"{m} maps")
+                        for a in aux
+                        for i in range(a.shape[0])
+                    ]
+
+            trainer.logger.experiment.log(log_dict)
+
+
+class WandbClassMetricCallback(pl.Callback):
+    def __init__(
+        self,
+        type_classes: Dict[int, str],
+        sem_classes: Optional[Dict[int, str]],
+        freq: int = 100,
+        return_series: bool = True,
+        return_bar: bool = True,
+        return_table: bool = False,
+    ) -> None:
+        """Callback to compute per-class IoUs and log them to wandb."""
+        super().__init__()
+        self.type_classes = type_classes
+        self.sem_classes = sem_classes
+        self.freq = freq
+        self.return_series = return_series
+        self.return_bar = return_bar
+        self.return_table = return_table
+        self.cell_ious = np.empty(0)
+        self.sem_ious = np.empty(0)
+
+    def compute(
+        self,
+        key: str,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+    ) -> np.ndarray:
+        """Compute the IoU per class."""
+        target = batch[key].detach()
+        soft_types = outputs[key].detach()
+        pred = F.softmax(soft_types, dim=1)
+
+        met = iou(pred, target).mean(dim=0)
+        return met.to("cpu").numpy()
+
+    def get_table(
+        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str]
+    ) -> wandb.Table:
+        """Return a wandb Table with step, IoU and label values for every step."""
+        batch_data = [
+            [xi * self.freq, c, np.round(ious[xi, i], 4)]
+            for i, c in classes.items()
+            for xi in x
+        ]
+
+        return wandb.Table(data=batch_data, columns=["step", "label", "value"])
+
+    def get_bar(self, iou: np.ndarray, classes: Dict[int, str], title: str) -> Any:
+        """Return a wandb bar plot object of the current per-class IoU values."""
+        batch_data = [[lab, val] for lab, val in zip(list(classes.values()), iou)]
+        table = wandb.Table(data=batch_data, columns=["label", "value"])
+        return wandb.plot.bar(table, "label", "value", title=title)
+
+    def get_series(
+        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str], title: str
+    ) -> Any:
+        """Return a wandb line-series plot of the per-class IoU values over steps."""
+        return wandb.plot.line_series(
+            xs=x.tolist(),
+            ys=[ious[:, c].tolist() for c in classes.keys()],
+            keys=list(classes.values()),
+            title=title,
+            xname="step",
+        )
+
+    def batch_end(
+        self,
+        trainer: pl.Trainer,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        phase: str,
+    ) -> None:
+        """Log metrics to wandb every `freq` batches."""
+        if batch_idx % self.freq == 0:
+            log_dict = {}
+            if "type" in list(batch.keys()):
+                iou = self.compute("type", outputs, batch)
+                self.cell_ious = np.append(self.cell_ious, iou)
+                cell_ious = self.cell_ious.reshape(-1, len(self.type_classes))
+                x = np.arange(cell_ious.shape[0])
+
+                if self.return_table:
+                    log_dict[f"{phase}/type_ious_table"] = self.get_table(
+                        cell_ious, x, self.type_classes
+                    )
+
+                if self.return_series:
+                    log_dict[f"{phase}/type_ious_per_class"] = self.get_series(
+                        cell_ious, x, self.type_classes, title="Per type class mIoU"
+                    )
+
+                if self.return_bar:
+                    log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
+                        list(iou), self.type_classes, title="Cell class mIoUs"
+                    )
+
+            if "sem" in list(batch.keys()):
+                iou = self.compute("sem", outputs, batch)
+
+                self.sem_ious = np.append(self.sem_ious, iou)
+                sem_ious = self.sem_ious.reshape(-1, len(self.sem_classes))
+                x = np.arange(sem_ious.shape[0])
+
+                if self.return_table:
+                    log_dict[f"{phase}/sem_ious_table"] = self.get_table(
+                        sem_ious, x, self.sem_classes
+                    )
+
+                if self.return_series:
+                    log_dict[f"{phase}/sem_ious_per_class"] = self.get_series(
+                        sem_ious, x, self.sem_classes, title="Per sem class mIoU"
+                    )
+
+                if self.return_bar:
+                    log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
+                        list(iou), self.sem_classes, title="Sem class mIoUs"
+                    )
+
+            trainer.logger.experiment.log(log_dict)
+
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
+
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
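
Both Lightning hooks delegate to batch_end, so train and validation logging share one code path. Below is a minimal wiring sketch, not from the commit, showing how these callbacks could be attached to a Lightning Trainer; the import path, SegModel, the datamodule, the project name, and the label maps are all assumptions, since neither the new module's filename nor the training entrypoint appears in this diff.

# Sketch only: attach the wandb callbacks to a pl.Trainer.
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Assumed module path; the diff doesn't show the new file's name.
from cellseg_models_pytorch.training.callbacks.wandb_callbacks import (
    WandbClassMetricCallback,
    WandbImageCallback,
)

# Illustrative label maps; wandb mask overlays expect {int: str}.
type_classes = {0: "background", 1: "neoplastic", 2: "inflammatory"}
sem_classes = {0: "background", 1: "stroma", 2: "tumor"}

trainer = pl.Trainer(
    max_epochs=50,
    logger=WandbLogger(project="cellseg"),  # hypothetical project name
    callbacks=[
        WandbImageCallback(type_classes, sem_classes, freq=100),
        WandbClassMetricCallback(type_classes, sem_classes, freq=100),
    ],
)
# trainer.fit(SegModel(...), datamodule=dm)  # hypothetical model and datamodule

With freq=100, mask overlays are logged on every 100th validation batch, and per-class IoU plots on every 100th training and validation batch, under the val/ and train/ keys built above.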
