Skip to content

Commit 7baf8e2

Browse files
committed
feat(training): Add wandb test callback.
1 parent 3de8176 commit 7baf8e2

File tree

2 files changed

+264
-39
lines changed

2 files changed

+264
-39
lines changed

cellseg_models_pytorch/training/callbacks/wandb_callbacks.py

Lines changed: 256 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Dict, List, Optional
22

33
import numpy as np
44
import pytorch_lightning as pl
55
import torch
66
import torch.nn.functional as F
7+
from skimage.color import label2rgb
8+
from tqdm import tqdm
79

810
try:
911
import wandb
1012
except ImportError:
1113
raise ImportError("wandb required. `pip install wandb`")
1214

15+
from ...inference import PostProcessor
16+
from ...metrics.functional import iou_multiclass, panoptic_quality
17+
from ...utils import get_type_instances, remap_label
1318
from ..functional import iou
1419

15-
__all__ = ["WandbImageCallback", "WandbClassBarCallback", "WandbClassLineCallback"]
20+
__all__ = ["WandbImageCallback", "WandbClassLineCallback", "WandbGetExamplesCallback"]
1621

1722

1823
class WandbImageCallback(pl.Callback):
@@ -135,7 +140,7 @@ def compute(
135140
met = iou(pred, target).mean(dim=0)
136141
return met.to("cpu").numpy()
137142

138-
def on_train_batch_end(
143+
def train_batch_end(
139144
self,
140145
trainer: pl.Trainer,
141146
pl_module: pl.LightningModule,
@@ -147,7 +152,7 @@ def on_train_batch_end(
147152
"""Log the inputs and outputs of the model to wandb."""
148153
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
149154

150-
def on_validation_batch_end(
155+
def validation_batch_end(
151156
self,
152157
trainer: pl.Trainer,
153158
pl_module: pl.LightningModule,
@@ -159,47 +164,17 @@ def on_validation_batch_end(
159164
"""Log the inputs and outputs of the model to wandb."""
160165
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
161166

162-
163-
class WandbClassBarCallback(WandbIoUCallback):
164-
def __init__(
165-
self,
166-
type_classes: Dict[str, int],
167-
sem_classes: Optional[Dict[str, int]],
168-
freq: int = 100,
169-
) -> None:
170-
"""Create a wandb callback that logs per-class mIoU at batch ends."""
171-
super().__init__(type_classes, sem_classes, freq)
172-
173-
def get_bar(self, iou: np.ndarray, classes: Dict[int, str], title: str) -> Any:
174-
"""Return a wandb bar plot object of the current per class iou values."""
175-
batch_data = [[lab, val] for lab, val in zip(list(classes.values()), iou)]
176-
table = wandb.Table(data=batch_data, columns=["label", "value"])
177-
return wandb.plot.bar(table, "label", "value", title=title)
178-
179-
def batch_end(
167+
def test_batch_end(
180168
self,
181169
trainer: pl.Trainer,
170+
pl_module: pl.LightningModule,
182171
outputs: Dict[str, torch.Tensor],
183172
batch: Dict[str, torch.Tensor],
184173
batch_idx: int,
185-
phase: str,
174+
dataloader_idx: int,
186175
) -> None:
187-
"""Log metrics at every 100th step to wandb."""
188-
if batch_idx % self.freq == 0:
189-
log_dict = {}
190-
if "type" in list(batch.keys()):
191-
iou = self.compute("type", outputs, batch)
192-
log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
193-
list(iou), self.type_classes, title="Cell class mIoUs"
194-
)
195-
196-
if "sem" in list(batch.keys()):
197-
iou = self.compute("sem", outputs, batch)
198-
log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
199-
list(iou), self.sem_classes, title="Sem class mIoUs"
200-
)
201-
202-
trainer.logger.experiment.log(log_dict)
176+
"""Log the inputs and outputs of the model to wandb."""
177+
self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="test")
203178

204179

205180
class WandbClassLineCallback(WandbIoUCallback):
@@ -240,3 +215,245 @@ def batch_end(
240215
)
241216

242217
trainer.logger.experiment.log(log_dict)
218+
219+
def on_validation_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    outputs: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Call the callback at val time.

    Lightning hook that simply forwards every argument to
    `self.validation_batch_end`.

    Parameters
    ----------
        trainer : pl.Trainer
            The trainer running the validation loop.
        pl_module : pl.LightningModule
            The module being validated.
        outputs : Dict[str, torch.Tensor]
            Outputs of the validation step.
        batch : Dict[str, torch.Tensor]
            The current input batch.
        batch_idx : int
            Index of the current batch.
        dataloader_idx : int, default=0
            Index of the current dataloader. Defaults to 0 so the hook also
            works when the trainer does not pass it.
    """
    self.validation_batch_end(
        trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    )
232+
233+
def on_train_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    outputs: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Call the callback at train time.

    Lightning hook that simply forwards every argument to
    `self.train_batch_end`.

    Parameters
    ----------
        trainer : pl.Trainer
            The trainer running the training loop.
        pl_module : pl.LightningModule
            The module being trained.
        outputs : Dict[str, torch.Tensor]
            Outputs of the training step.
        batch : Dict[str, torch.Tensor]
            The current input batch.
        batch_idx : int
            Index of the current batch.
        dataloader_idx : int, default=0
            Index of the current dataloader. Defaults to 0 so the hook also
            works when the trainer does not pass it.
    """
    self.train_batch_end(
        trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    )
246+
247+
248+
class WandbGetExamplesCallback(pl.Callback):
    def __init__(
        self,
        type_classes: Dict[int, str],
        sem_classes: Optional[Dict[int, str]],
        instance_postproc: str,
        inst_key: str,
        aux_key: str,
        inst_act: str = "softmax",
        aux_act: Optional[str] = None,
    ) -> None:
        """Create a wandb callback that logs img examples at test time.

        Parameters
        ----------
            type_classes : Dict[int, str]
                Cell type classes, mapping class id -> class name. Id 0 is
                skipped when metrics and wandb class sets are built.
            sem_classes : Dict[int, str], optional
                Semantic classes, mapping class id -> class name. Id 0 is
                skipped the same way.
            instance_postproc : str
                Forwarded to `PostProcessor` as `instance_postproc`.
            inst_key : str
                Key of the instance output map in the model outputs.
            aux_key : str
                Key of the auxiliary output map in the model outputs.
            inst_act : str, default="softmax"
                Activation applied to the instance output. "softmax" and
                "sigmoid" are recognised; any other value leaves it as is.
            aux_act : str, optional
                Activation applied to the auxiliary output. Only "tanh" is
                recognised; any other value leaves it as is.
        """
        super().__init__()
        self.type_classes = type_classes
        self.sem_classes = sem_classes
        self.inst_key = inst_key
        self.aux_key = aux_key

        self.inst_act = inst_act
        self.aux_act = aux_act

        self.postprocessor = PostProcessor(
            instance_postproc=instance_postproc, inst_key=inst_key, aux_key=aux_key
        )

    def post_proc(
        self, outputs: Dict[str, torch.Tensor]
    ) -> List[Dict[str, np.ndarray]]:
        """Apply post processing to the outputs.

        Parameters
        ----------
            outputs : Dict[str, torch.Tensor]
                Model outputs. Must contain `self.inst_key` (a 4-D tensor) and
                `self.aux_key`; "sem" and "type" are used when present.

        Returns
        -------
            List[Dict[str, np.ndarray]]:
                One post-processed result per sample of the batch, as returned
                by `PostProcessor.post_proc_pipeline`.
        """
        # The 4-way unpack doubles as a check that the instance map is 4-D.
        batch_size, _, _, _ = outputs[self.inst_key].shape

        inst = outputs[self.inst_key].detach()
        if self.inst_act == "softmax":
            inst = F.softmax(inst, dim=1)
        elif self.inst_act == "sigmoid":
            inst = torch.sigmoid(inst)

        aux = outputs[self.aux_key].detach()
        if self.aux_act == "tanh":
            aux = torch.tanh(aux)

        sem = None
        if "sem" in outputs.keys():
            sem = F.softmax(outputs["sem"].detach(), dim=1).cpu().numpy()

        typ = None
        if "type" in outputs.keys():
            typ = F.softmax(outputs["type"].detach(), dim=1).cpu().numpy()

        inst = inst.cpu().numpy()
        aux = aux.cpu().numpy()

        outmaps = []
        for i in range(batch_size):
            maps = {self.inst_key: inst[i], self.aux_key: aux[i]}
            if sem is not None:
                maps["sem"] = sem[i]
            if typ is not None:
                maps["type"] = typ[i]

            outmaps.append(self.postprocessor.post_proc_pipeline(maps))

        return outmaps

    def count_pixels(self, img: np.ndarray, shape: int) -> List[float]:
        """Compute pixel proportions per class.

        Each bin count of `img` is divided by `shape**2`, so the values are
        proportions only if the image is square with side `shape`.
        """
        counts = np.bincount(img.astype(int).flatten())
        return [float(count) / shape**2 for count in counts]

    def _table_columns(self, outheads) -> List[str]:
        """Build the wandb table column names for the given model heads."""
        cols = ["id", "inst_gt", "inst_pred", "bPQ"]

        # Classes are filtered by id (0 is skipped), exactly like the per-class
        # values added to each row. Filtering the columns by name instead can
        # leave the column count out of sync with the row length.
        if "type" in outheads:
            cols.append("cell_types")
            cols.extend(
                f"pq_{name}" for cid, name in self.type_classes.items() if cid != 0
            )
        if "sem" in outheads:
            cols.append("tissue_types")
            cols.extend(
                f"iou_{name}" for cid, name in self.sem_classes.items() if cid != 0
            )

        return cols

    def _classes_set(self, classes: Dict[int, str]) -> Any:
        """Return a `wandb.Classes` object of all classes except id 0."""
        return wandb.Classes(
            [
                {"name": name, "id": cid}
                for cid, name in classes.items()
                if cid != 0
            ]
        )

    def _mask_image(
        self, im: np.ndarray, classes: Dict[int, str], target: Any, pred: Any
    ) -> Any:
        """Return a wandb image of `im` overlaid with gt and predicted masks."""
        return wandb.Image(
            # Move the first axis last (presumably channels-first -> last).
            im.transpose(1, 2, 0),
            classes=self._classes_set(classes),
            masks={
                "ground_truth": {"mask_data": target},
                "pred": {"mask_data": pred},
            },
        )

    def _type_row(
        self,
        im: np.ndarray,
        inst_targ: np.ndarray,
        inst_pred: np.ndarray,
        type_target: np.ndarray,
        type_pred: np.ndarray,
    ) -> List[Any]:
        """Return the cell type cells of a row: mask image + per-class PQ."""
        per_class_pq = [
            panoptic_quality(
                remap_label(get_type_instances(inst_targ, type_target, cid)),
                remap_label(get_type_instances(inst_pred, type_pred, cid)),
            )["pq"]
            for cid in self.type_classes.keys()
            if cid != 0
        ]
        wb_type = self._mask_image(im, self.type_classes, type_target, type_pred)

        return [wb_type, *per_class_pq]

    def _sem_row(
        self, im: np.ndarray, sem_target: np.ndarray, sem_pred: np.ndarray
    ) -> List[Any]:
        """Return the semantic cells of a row: mask image + per-class IoU."""
        per_class_iou = list(
            iou_multiclass(sem_target, sem_pred, len(self.sem_classes))
        )
        wb_sem = self._mask_image(im, self.sem_classes, sem_target, sem_pred)

        # The first IoU value is dropped, matching the skipped id 0 column.
        return [wb_sem, *per_class_iou[1:]]

    def epoch_end(self, trainer, pl_module) -> None:
        """Log metrics at the epoch end.

        Runs the model over `trainer.datamodule.test_dataloader()`, post
        processes the predictions and logs one table row per image (instance
        maps, binary PQ and, when available, cell type and semantic results)
        to wandb as an artifact.
        """
        outheads = {
            head for decoder in pl_module.heads.values() for head in decoder.keys()
        }

        loader = trainer.datamodule.test_dataloader()

        # Create the artifact and the table that is attached to it.
        run_id = trainer.logger.experiment.id
        test_res_at = wandb.Artifact("test_pred_" + run_id, "test_preds")
        model_res_table = wandb.Table(columns=self._table_columns(outheads))

        # NOTE(review): the columns depend on the model heads while the row
        # contents depend on the batch keys; these are assumed to agree.
        with tqdm(loader, unit="batch") as progress, torch.no_grad():
            for batch_idx, batch in enumerate(progress):
                soft_masks = pl_module.forward(batch["image"].to(pl_module.device))
                imgs = list(batch["image"].detach().cpu().numpy())
                inst_targets = list(batch["inst_map"].detach().cpu().numpy())

                outmaps = self.post_proc(soft_masks)

                type_targets = None
                if "type" in batch.keys():
                    type_targets = list(batch["type"].detach().cpu().numpy())

                sem_targets = None
                if "sem" in batch.keys():
                    sem_targets = list(batch["sem"].detach().cpu().numpy())

                # loop the images in batch
                for i, (pred, im, inst_target) in enumerate(
                    zip(outmaps, imgs, inst_targets)
                ):
                    inst_targ = remap_label(inst_target)
                    inst_pred = remap_label(pred["inst"])

                    row = [
                        f"test_batch_{batch_idx}_{i}",
                        wandb.Image(label2rgb(inst_targ, bg_label=0)),
                        wandb.Image(label2rgb(inst_pred, bg_label=0)),
                        panoptic_quality(inst_targ, inst_pred)["pq"],
                    ]

                    if type_targets is not None:
                        row += self._type_row(
                            im, inst_targ, inst_pred, type_targets[i], pred["type"]
                        )

                    if sem_targets is not None:
                        row += self._sem_row(im, sem_targets[i], pred["sem"])

                    model_res_table.add_data(*row)

        test_res_at.add(model_res_table, "model_batch_results")
        trainer.logger.experiment.log_artifact(test_res_at)

    def on_test_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        """Call the callback at test time."""
        self.epoch_end(trainer, pl_module)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## Fixes
2+
3+
- Add options to return binary and instance-labelled masks from the dataloader. Previously, a binary mask was returned with the `return_inst` flag, which was confusing.
4+
- Fix the `SegmentationExperiment` to return preds and masks at test time.
5+
6+
## Features
7+
8+
- Add a Wandb artifact table callback for logging a table of test data metrics and insights to wandb.

0 commit comments

Comments
 (0)