Commit ce4bc0c

fix(callbacks): fix wandb callbacks, changelog mod

1 parent 2719ef7 commit ce4bc0c
File tree

4 files changed (+92, -107 lines)

- CHANGELOG.md
- cellseg_models_pytorch/training/callbacks/metric_callbacks.py
- cellseg_models_pytorch/training/callbacks/wandb_callbacks.py
- cellseg_models_pytorch/training/lit/lightning_experiment.py

CHANGELOG.md

Lines changed: 16 additions & 17 deletions
@@ -32,51 +32,50 @@

 ## Test

-- Update tests for Inferes and mask utils.
-
-- Add tests for the benchmarkers.
+- Update tests for Inferes and mask utils.
+- Add tests for the benchmarkers.

 ## Fixes

-- init and typing fixes
+- init and typing fixes

 ## Docs

-- Typo fies in docs
+- Typo fies in docs

 ## Features

-- Add numba parallellized median filter and majority voting for post-processing
-- Add support for own semantic and type seg post-proc funcs in Inferers
+- Add numba parallellized median filter and majority voting for post-processing
+- Add support for own semantic and type seg post-proc funcs in Inferers

-- Add segmentation performance benchmarking helper class.
-- Add segmentation latency benchmarking helper class.
+- Add segmentation performance benchmarking helper class.
+- Add segmentation latency benchmarking helper class.

 <a id='changelog-0.1.2'></a>

 # 0.1.2 — 2022-09-09

 ## Fixes

-- **datasets.writers**: Update `save2db` & `save2folder` for optional type_map and sem_map args.
-- **datasets.writers**: Pre-processing (`pre-proc`) callable arg for `_get_tiles` method. This enables the Lizard datamodule.
-- **inference**: Fix- padding bug with sliding window inference.
+- Update `save2db` & `save2folder` for optional type_map and sem_map args.
+- Pre-processing (`pre-proc`) callable arg for `_get_tiles` method. This enables the Lizard datamodule.
+- Fix- padding bug with sliding window inference.

 ## Features

-- **datamodules**: Lizard datamodule (https://arxiv.org/abs/2108.11195)
+- Lizard datamodule (https://arxiv.org/abs/2108.11195)

-- **models**: Add a universal multi-task U-net model builder (experimental)
+- Add a universal multi-task U-net model builder (experimental)

 ## Test

-- **dataset**: Update dataset tests.
+- Update dataset tests.

-- **models**: Update tests for multi-task U-Net
+- Update tests for multi-task U-Net

 ## Type Hints

-- **models**: Fix incorrect type hints.
+- Fix incorrect type hints.

 ## Examples


cellseg_models_pytorch/training/callbacks/metric_callbacks.py

Lines changed: 0 additions & 4 deletions
@@ -104,7 +104,6 @@ def __init__(
         dist_sync_on_step: bool = False,
         progress_grouo: Any = None,
         dist_sync_func: Callable = None,
-        num_classes: int = None,
         **kwargs
     ) -> None:
         """Create a custom torchmetrics mIoU callback.
@@ -121,9 +120,6 @@ def __init__(
         dist_sync_func : Callable, optional
             Callback that performs the allgather operation on the metric state.
             When None, DDP will be used to perform the allgather.
-        num_classes : int, optional
-            If not None, multi-class miou will be returned.
-
         """
         super().__init__(
             compute_on_step=compute_on_step,
cellseg_models_pytorch/training/callbacks/wandb_callbacks.py

Lines changed: 76 additions & 84 deletions
@@ -12,7 +12,7 @@

 from ..functional import iou

-__all__ = ["WandbImageCallback", "WandbClassMetricCallback"]
+__all__ = ["WandbImageCallback", "WandbClassBarCallback", "WandbClassLineCallback"]


 class WandbImageCallback(pl.Callback):
@@ -104,26 +104,22 @@ def on_validation_batch_end(
         trainer.logger.experiment.log(log_dict)


-class WandbClassMetricCallback(pl.Callback):
+class WandbIoUCallback(pl.Callback):
     def __init__(
         self,
         type_classes: Dict[str, int],
         sem_classes: Optional[Dict[str, int]],
         freq: int = 100,
-        return_series: bool = True,
-        return_bar: bool = True,
-        return_table: bool = False,
     ) -> None:
-        """Call back to compute per-class ious and log them to wandb."""
+        """Create a base class for IoU wandb callbacks."""
         super().__init__()
         self.type_classes = type_classes
         self.sem_classes = sem_classes
         self.freq = freq
-        self.return_series = return_series
-        self.return_bar = return_bar
-        self.return_table = return_table
-        self.cell_ious = np.empty(0)
-        self.sem_ious = np.empty(0)
+
+    def batch_end(self) -> None:
+        """Abstract batch end method."""
+        raise NotImplementedError

     def compute(
         self,
@@ -139,36 +135,47 @@ def compute(
         met = iou(pred, target).mean(dim=0)
         return met.to("cpu").numpy()

-    def get_table(
-        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str]
-    ) -> wandb.Table:
-        """Return a wandb Table with step, iou and label values for every step."""
-        batch_data = [
-            [xi * self.freq, c, np.round(ious[xi, i], 4)]
-            for i, c, in classes.items()
-            for xi in x
-        ]
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")

-        return wandb.Table(data=batch_data, columns=["step", "label", "value"])
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int,
+    ) -> None:
+        """Log the inputs and outputs of the model to wandb."""
+        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
+
+
+class WandbClassBarCallback(WandbIoUCallback):
+    def __init__(
+        self,
+        type_classes: Dict[str, int],
+        sem_classes: Optional[Dict[str, int]],
+        freq: int = 100,
+    ) -> None:
+        """Create a wandb callback that logs per-class mIoU at batch ends."""
+        super().__init__(type_classes, sem_classes, freq)

     def get_bar(self, iou: np.ndarray, classes: Dict[int, str], title: str) -> Any:
         """Return a wandb bar plot object of the current per class iou values."""
         batch_data = [[lab, val] for lab, val in zip(list(classes.values()), iou)]
         table = wandb.Table(data=batch_data, columns=["label", "value"])
         return wandb.plot.bar(table, "label", "value", title=title)

-    def get_series(
-        self, ious: np.ndarray, x: np.ndarray, classes: Dict[int, str], title: str
-    ) -> Any:
-        """Return a wandb series plot obj of the per class iou values over timesteps."""
-        return wandb.plot.line_series(
-            xs=x.tolist(),
-            ys=[ious[:, c].tolist() for c in classes.keys()],
-            keys=list(classes.values()),
-            title=title,
-            xname="step",
-        )
-
     def batch_end(
         self,
         trainer: pl.Trainer,
@@ -182,69 +189,54 @@ def batch_end(
         log_dict = {}
         if "type" in list(batch.keys()):
             iou = self.compute("type", outputs, batch)
-            self.cell_ious = np.append(self.cell_ious, iou)
-            cell_ious = self.cell_ious.reshape(-1, len(self.type_classes))
-            x = np.arange(cell_ious.shape[0])
-
-            if self.return_table:
-                log_dict[f"{phase}/type_ious_table"] = self.get_table(
-                    cell_ious, x, self.type_classes
-                )
-
-            if self.return_series:
-                log_dict[f"{phase}/type_ious_per_class"] = self.get_series(
-                    cell_ious, x, self.type_classes, title="Per type class mIoU"
-                )
-
-            if self.return_bar:
-                log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
-                    list(iou), self.type_classes, title="Cell class mIoUs"
-                )
+            log_dict[f"{phase}/type_ious_bar"] = self.get_bar(
+                list(iou), self.type_classes, title="Cell class mIoUs"
+            )

         if "sem" in list(batch.keys()):
             iou = self.compute("sem", outputs, batch)
-
-            self.sem_ious = np.append(self.sem_ious, iou)
-            sem_ious = self.sem_ious.reshape(-1, len(self.sem_classes))
-            x = np.arange(sem_ious.shape[0])
-
-            if self.return_table:
-                log_dict[f"{phase}/sem_ious_table"] = self.get_table(
-                    cell_ious, x, self.type_classes
-                )
-
-            if self.return_series:
-                log_dict[f"{phase}/sem_ious_per_class"] = self.get_series(
-                    cell_ious, x, self.type_classes, title="Per sem class mIoU"
-                )
-
-            if self.return_bar:
-                log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
-                    list(iou), self.type_classes, title="Sem class mIoUs"
-                )
+            log_dict[f"{phase}/sem_ious_bar"] = self.get_bar(
+                list(iou), self.sem_classes, title="Sem class mIoUs"
+            )

         trainer.logger.experiment.log(log_dict)

-    def on_train_batch_end(
+
+class WandbClassLineCallback(WandbIoUCallback):
+    def __init__(
         self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: Dict[str, torch.Tensor],
-        batch: Dict[str, torch.Tensor],
-        batch_idx: int,
-        dataloader_idx: int,
+        type_classes: Dict[str, int],
+        sem_classes: Optional[Dict[str, int]],
+        freq: int = 100,
     ) -> None:
-        """Log the inputs and outputs of the model to wandb."""
-        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="train")
+        """Create a wandb callback that logs per-class mIoU at batch ends."""
+        super().__init__(type_classes, sem_classes, freq)

-    def on_validation_batch_end(
+    def get_points(self, iou: np.ndarray, classes: Dict[int, str]) -> Any:
+        """Return a wandb bar plot object of the current per class iou values."""
+        return {lab: val for lab, val in zip(list(classes.values()), iou)}
+
+    def batch_end(
         self,
         trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
         outputs: Dict[str, torch.Tensor],
         batch: Dict[str, torch.Tensor],
         batch_idx: int,
-        dataloader_idx: int,
+        phase: str,
     ) -> None:
-        """Log the inputs and outputs of the model to wandb."""
-        self.batch_end(trainer, outputs["soft_masks"], batch, batch_idx, phase="val")
+        """Log metrics at every 100th step to wandb."""
+        if batch_idx % self.freq == 0:
+            log_dict = {}
+            if "type" in list(batch.keys()):
+                iou = self.compute("type", outputs, batch)
+                log_dict[f"{phase}/type_ious_points"] = self.get_points(
+                    list(iou), self.type_classes
+                )
+
+            if "sem" in list(batch.keys()):
+                iou = self.compute("sem", outputs, batch)
+                log_dict[f"{phase}/sem_ious_points"] = self.get_points(
+                    list(iou), self.sem_classes
+                )
+
+            trainer.logger.experiment.log(log_dict)
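
Taken together, this diff replaces the old WandbClassMetricCallback with a WandbIoUCallback base class plus two concrete subclasses: WandbClassBarCallback (per-class IoU as bar plots) and WandbClassLineCallback (per-class IoU as scalar points, gated to every freq:th batch). A minimal usage sketch follows; the import path mirrors the file layout in this commit, while the class maps, project name, and trainer settings are illustrative assumptions, not part of the diff:

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from cellseg_models_pytorch.training.callbacks.wandb_callbacks import (
    WandbClassBarCallback,
    WandbClassLineCallback,
)

# Hypothetical class-index maps; substitute the ones your datamodule uses.
type_classes = {"bg": 0, "neoplastic": 1, "inflammatory": 2}
sem_classes = {"bg": 0, "stroma": 1, "tumor": 2}

# Both callbacks read outputs["soft_masks"] at batch ends and log per-class
# IoU to the attached wandb logger via trainer.logger.experiment.log(...).
trainer = pl.Trainer(
    max_epochs=10,
    logger=WandbLogger(project="cellseg"),  # assumed project name
    callbacks=[
        WandbClassBarCallback(type_classes, sem_classes, freq=100),
        WandbClassLineCallback(type_classes, sem_classes, freq=100),
    ],
)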

cellseg_models_pytorch/training/lit/lightning_experiment.py

Lines changed: 0 additions & 2 deletions
@@ -73,8 +73,6 @@ def __init__(
         scheduler_params : Dict[str, Any]
             Params dict for the scheduler. Refer to torch lr_scheduler docs
             for the possible scheduler arguments.
-        return_soft_masks : bool, default=True
-            Return the model outputs for logging if True. Saves mem if set to False.
         log_freq : int, default=100
             Return soft masks every every n batches for callbacks and logging.

