1
- from typing import Any , Dict , Optional
1
+ from typing import Any , Dict , List , Optional
2
2
3
3
import numpy as np
4
4
import pytorch_lightning as pl
5
5
import torch
6
6
import torch .nn .functional as F
7
+ from skimage .color import label2rgb
8
+ from tqdm import tqdm
7
9
8
10
try :
9
11
import wandb
10
12
except ImportError :
11
13
raise ImportError ("wandb required. `pip install wandb`" )
12
14
15
+ from ...inference import PostProcessor
16
+ from ...metrics .functional import iou_multiclass , panoptic_quality
17
+ from ...utils import get_type_instances , remap_label
13
18
from ..functional import iou
14
19
15
- __all__ = ["WandbImageCallback" , "WandbClassBarCallback " , "WandbClassLineCallback " ]
20
+ __all__ = ["WandbImageCallback" , "WandbClassLineCallback " , "WandbGetExamplesCallback " ]
16
21
17
22
18
23
class WandbImageCallback (pl .Callback ):
@@ -135,7 +140,7 @@ def compute(
135
140
met = iou (pred , target ).mean (dim = 0 )
136
141
return met .to ("cpu" ).numpy ()
137
142
138
- def on_train_batch_end (
143
+ def train_batch_end (
139
144
self ,
140
145
trainer : pl .Trainer ,
141
146
pl_module : pl .LightningModule ,
@@ -147,7 +152,7 @@ def on_train_batch_end(
147
152
"""Log the inputs and outputs of the model to wandb."""
148
153
self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "train" )
149
154
150
- def on_validation_batch_end (
155
+ def validation_batch_end (
151
156
self ,
152
157
trainer : pl .Trainer ,
153
158
pl_module : pl .LightningModule ,
@@ -159,47 +164,17 @@ def on_validation_batch_end(
159
164
"""Log the inputs and outputs of the model to wandb."""
160
165
self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
161
166
162
-
163
- class WandbClassBarCallback (WandbIoUCallback ):
164
- def __init__ (
165
- self ,
166
- type_classes : Dict [str , int ],
167
- sem_classes : Optional [Dict [str , int ]],
168
- freq : int = 100 ,
169
- ) -> None :
170
- """Create a wandb callback that logs per-class mIoU at batch ends."""
171
- super ().__init__ (type_classes , sem_classes , freq )
172
-
173
- def get_bar (self , iou : np .ndarray , classes : Dict [int , str ], title : str ) -> Any :
174
- """Return a wandb bar plot object of the current per class iou values."""
175
- batch_data = [[lab , val ] for lab , val in zip (list (classes .values ()), iou )]
176
- table = wandb .Table (data = batch_data , columns = ["label" , "value" ])
177
- return wandb .plot .bar (table , "label" , "value" , title = title )
178
-
179
- def batch_end (
167
+ def test_batch_end (
180
168
self ,
181
169
trainer : pl .Trainer ,
170
+ pl_module : pl .LightningModule ,
182
171
outputs : Dict [str , torch .Tensor ],
183
172
batch : Dict [str , torch .Tensor ],
184
173
batch_idx : int ,
185
- phase : str ,
174
+ dataloader_idx : int ,
186
175
) -> None :
187
- """Log metrics at every 100th step to wandb."""
188
- if batch_idx % self .freq == 0 :
189
- log_dict = {}
190
- if "type" in list (batch .keys ()):
191
- iou = self .compute ("type" , outputs , batch )
192
- log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
193
- list (iou ), self .type_classes , title = "Cell class mIoUs"
194
- )
195
-
196
- if "sem" in list (batch .keys ()):
197
- iou = self .compute ("sem" , outputs , batch )
198
- log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
199
- list (iou ), self .sem_classes , title = "Sem class mIoUs"
200
- )
201
-
202
- trainer .logger .experiment .log (log_dict )
176
+ """Log the inputs and outputs of the model to wandb."""
177
+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "test" )
203
178
204
179
205
180
class WandbClassLineCallback (WandbIoUCallback ):
@@ -240,3 +215,245 @@ def batch_end(
240
215
)
241
216
242
217
trainer .logger .experiment .log (log_dict )
218
+
219
+ def on_validation_batch_end (
220
+ self ,
221
+ trainer : pl .Trainer ,
222
+ pl_module : pl .LightningModule ,
223
+ outputs : Dict [str , torch .Tensor ],
224
+ batch : Dict [str , torch .Tensor ],
225
+ batch_idx : int ,
226
+ dataloader_idx : int ,
227
+ ) -> None :
228
+ """Call the callback at val time."""
229
+ self .validation_batch_end (
230
+ trainer , pl_module , outputs , batch , batch_idx , dataloader_idx
231
+ )
232
+
233
+ def on_train_batch_end (
234
+ self ,
235
+ trainer : pl .Trainer ,
236
+ pl_module : pl .LightningModule ,
237
+ outputs : Dict [str , torch .Tensor ],
238
+ batch : Dict [str , torch .Tensor ],
239
+ batch_idx : int ,
240
+ dataloader_idx : int ,
241
+ ) -> None :
242
+ """Call the callback at val time."""
243
+ self .train_batch_end (
244
+ trainer , pl_module , outputs , batch , batch_idx , dataloader_idx
245
+ )
246
+
247
+
248
+ class WandbGetExamplesCallback (pl .Callback ):
249
+ def __init__ (
250
+ self ,
251
+ type_classes : Dict [str , int ],
252
+ sem_classes : Optional [Dict [str , int ]],
253
+ instance_postproc : str ,
254
+ inst_key : str ,
255
+ aux_key : str ,
256
+ inst_act : str = "softmax" ,
257
+ aux_act : str = None ,
258
+ ) -> None :
259
+ """Create a wandb callback that logs img examples at test time."""
260
+ super ().__init__ ()
261
+ self .type_classes = type_classes
262
+ self .sem_classes = sem_classes
263
+ self .inst_key = inst_key
264
+ self .aux_key = aux_key
265
+
266
+ self .inst_act = inst_act
267
+ self .aux_act = aux_act
268
+
269
+ self .postprocessor = PostProcessor (
270
+ instance_postproc = instance_postproc , inst_key = inst_key , aux_key = aux_key
271
+ )
272
+
273
+ def post_proc (
274
+ self , outputs : Dict [str , torch .Tensor ]
275
+ ) -> List [Dict [str , np .ndarray ]]:
276
+ """Apply post processing to the outputs."""
277
+ B , _ , _ , _ = outputs [self .inst_key ].shape
278
+
279
+ inst = outputs [self .inst_key ].detach ()
280
+ if self .inst_act == "softmax" :
281
+ inst = F .softmax (inst , dim = 1 )
282
+ if self .inst_act == "sigmoid" :
283
+ inst = torch .sigmoid (inst )
284
+
285
+ aux = outputs [self .aux_key ].detach ()
286
+ if self .aux_act == "tanh" :
287
+ aux = torch .tanh (aux )
288
+
289
+ sem = None
290
+ if "sem" in outputs .keys ():
291
+ sem = outputs ["sem" ].detach ()
292
+ sem = F .softmax (sem , dim = 1 ).cpu ().numpy ()
293
+
294
+ typ = None
295
+ if "type" in outputs .keys ():
296
+ typ = outputs ["type" ].detach ()
297
+ typ = F .softmax (typ , dim = 1 ).cpu ().numpy ()
298
+
299
+ inst = inst .cpu ().numpy ()
300
+ aux = aux .cpu ().numpy ()
301
+ outmaps = []
302
+ for i in range (B ):
303
+ maps = {
304
+ self .inst_key : inst [i ],
305
+ self .aux_key : aux [i ],
306
+ }
307
+ if sem is not None :
308
+ maps ["sem" ] = sem [i ]
309
+ if typ is not None :
310
+ maps ["type" ] = typ [i ]
311
+
312
+ out = self .postprocessor .post_proc_pipeline (maps )
313
+ outmaps .append (out )
314
+
315
+ return outmaps
316
+
317
+ def count_pixels (self , img : np .ndarray , shape : int ):
318
+ """Compute pixel proportions per class."""
319
+ return [float (p ) / shape ** 2 for p in np .bincount (img .astype (int ).flatten ())]
320
+
321
+ def epoch_end (self , trainer , pl_module ) -> None :
322
+ """Log metrics at the epoch end."""
323
+ decs = [list (it .keys ()) for it in pl_module .heads .values ()]
324
+ outheads = [item for sublist in decs for item in sublist ]
325
+
326
+ loader = trainer .datamodule .test_dataloader ()
327
+ runid = trainer .logger .experiment .id
328
+ test_res_at = wandb .Artifact ("test_pred_" + runid , "test_preds" )
329
+
330
+ # Create artifact
331
+ runid = trainer .logger .experiment .id
332
+ test_res_at = wandb .Artifact ("test_pred_" + runid , "test_preds" )
333
+
334
+ # Init wb table
335
+ cols = ["id" , "inst_gt" , "inst_pred" , "bPQ" ]
336
+
337
+ if "type" in outheads :
338
+ cols += [
339
+ "cell_types" ,
340
+ * [f"pq_{ c } " for c in self .type_classes .values () if c != "bg" ],
341
+ ]
342
+ if "sem" in outheads :
343
+ cols += [
344
+ "tissue_types" ,
345
+ * [f"iou_{ c } " for c in self .sem_classes .values () if c != "bg" ],
346
+ ]
347
+
348
+ model_res_table = wandb .Table (columns = cols )
349
+
350
+ #
351
+ with tqdm (loader , unit = "batch" ) as loader :
352
+ with torch .no_grad ():
353
+ for batch_idx , batch in enumerate (loader ):
354
+ soft_masks = pl_module .forward (batch ["image" ].to (pl_module .device ))
355
+ imgs = list (batch ["image" ].detach ().cpu ().numpy ()) # [(C, H, W)*B]
356
+ inst_targets = list (batch ["inst_map" ].detach ().cpu ().numpy ())
357
+
358
+ outmaps = self .post_proc (soft_masks )
359
+
360
+ type_targets = None
361
+ if "type" in list (batch .keys ()):
362
+ type_targets = list (
363
+ batch ["type" ].detach ().cpu ().numpy ()
364
+ ) # [(C, H, W)*B]
365
+
366
+ sem_targets = None
367
+ if "sem" in list (batch .keys ()):
368
+ sem_targets = list (
369
+ batch ["sem" ].detach ().cpu ().numpy ()
370
+ ) # [(C, H, W)*B]
371
+
372
+ # loop the images in batch
373
+ for i , (pred , im , inst_target ) in enumerate (
374
+ zip (outmaps , imgs , inst_targets )
375
+ ):
376
+ inst_targ = remap_label (inst_target )
377
+ inst_pred = remap_label (pred ["inst" ])
378
+
379
+ wb_inst_gt = wandb .Image (label2rgb (inst_targ , bg_label = 0 ))
380
+ wb_inst_pred = wandb .Image (label2rgb (inst_pred , bg_label = 0 ))
381
+ pq_inst = panoptic_quality (inst_targ , inst_pred )["pq" ]
382
+
383
+ row = [
384
+ f"test_batch_{ batch_idx } _{ i } " ,
385
+ wb_inst_gt ,
386
+ wb_inst_pred ,
387
+ pq_inst ,
388
+ ]
389
+
390
+ if type_targets is not None :
391
+ per_class_pq = [
392
+ panoptic_quality (
393
+ remap_label (
394
+ get_type_instances (
395
+ inst_targ , type_targets [i ], j
396
+ )
397
+ ),
398
+ remap_label (
399
+ get_type_instances (inst_pred , pred ["type" ], j )
400
+ ),
401
+ )["pq" ]
402
+ for j in self .type_classes .keys ()
403
+ if j != 0
404
+ ]
405
+
406
+ type_classes_set = wandb .Classes (
407
+ [
408
+ {"name" : name , "id" : id }
409
+ for id , name in self .type_classes .items ()
410
+ if id != 0
411
+ ]
412
+ )
413
+ wb_type = wandb .Image (
414
+ im .transpose (1 , 2 , 0 ),
415
+ classes = type_classes_set ,
416
+ masks = {
417
+ "ground_truth" : {"mask_data" : type_targets [i ]},
418
+ "pred" : {"mask_data" : pred ["type" ]},
419
+ },
420
+ )
421
+
422
+ row += [wb_type , * per_class_pq ]
423
+
424
+ if sem_targets is not None :
425
+ per_class_iou = list (
426
+ iou_multiclass (
427
+ sem_targets [i ], pred ["sem" ], len (self .sem_classes )
428
+ )
429
+ )
430
+
431
+ sem_classes_set = wandb .Classes (
432
+ [
433
+ {"name" : name , "id" : id }
434
+ for id , name in self .sem_classes .items ()
435
+ if id != 0
436
+ ]
437
+ )
438
+ wb_sem = wandb .Image (
439
+ im .transpose (1 , 2 , 0 ),
440
+ classes = sem_classes_set ,
441
+ masks = {
442
+ "ground_truth" : {"mask_data" : sem_targets [i ]},
443
+ "pred" : {"mask_data" : pred ["sem" ]},
444
+ },
445
+ )
446
+ row += [wb_sem , * per_class_iou [1 :]]
447
+
448
+ model_res_table .add_data (* row )
449
+
450
+ test_res_at .add (model_res_table , "model_batch_results" )
451
+ trainer .logger .experiment .log_artifact (test_res_at )
452
+
453
+ def on_test_epoch_end (
454
+ self ,
455
+ trainer : pl .Trainer ,
456
+ pl_module : pl .LightningModule ,
457
+ ) -> None :
458
+ """Call the callback at test time."""
459
+ self .epoch_end (trainer , pl_module )
0 commit comments