12
12
13
13
from ..functional import iou
14
14
15
- __all__ = ["WandbImageCallback" , "WandbClassMetricCallback " ]
15
+ __all__ = ["WandbImageCallback" , "WandbClassBarCallback" , "WandbClassLineCallback " ]
16
16
17
17
18
18
class WandbImageCallback (pl .Callback ):
@@ -104,26 +104,22 @@ def on_validation_batch_end(
104
104
trainer .logger .experiment .log (log_dict )
105
105
106
106
107
- class WandbClassMetricCallback (pl .Callback ):
107
+ class WandbIoUCallback (pl .Callback ):
108
108
def __init__ (
109
109
self ,
110
110
type_classes : Dict [str , int ],
111
111
sem_classes : Optional [Dict [str , int ]],
112
112
freq : int = 100 ,
113
- return_series : bool = True ,
114
- return_bar : bool = True ,
115
- return_table : bool = False ,
116
113
) -> None :
117
- """Call back to compute per- class ious and log them to wandb."""
114
+ """Create a base class for IoU wandb callbacks ."""
118
115
super ().__init__ ()
119
116
self .type_classes = type_classes
120
117
self .sem_classes = sem_classes
121
118
self .freq = freq
122
- self .return_series = return_series
123
- self .return_bar = return_bar
124
- self .return_table = return_table
125
- self .cell_ious = np .empty (0 )
126
- self .sem_ious = np .empty (0 )
119
+
120
+ def batch_end (self ) -> None :
121
+ """Abstract batch end method."""
122
+ raise NotImplementedError
127
123
128
124
def compute (
129
125
self ,
@@ -139,36 +135,47 @@ def compute(
139
135
met = iou (pred , target ).mean (dim = 0 )
140
136
return met .to ("cpu" ).numpy ()
141
137
142
- def get_table (
143
- self , ious : np .ndarray , x : np .ndarray , classes : Dict [int , str ]
144
- ) -> wandb .Table :
145
- """Return a wandb Table with step, iou and label values for every step."""
146
- batch_data = [
147
- [xi * self .freq , c , np .round (ious [xi , i ], 4 )]
148
- for i , c , in classes .items ()
149
- for xi in x
150
- ]
138
+ def on_train_batch_end (
139
+ self ,
140
+ trainer : pl .Trainer ,
141
+ pl_module : pl .LightningModule ,
142
+ outputs : Dict [str , torch .Tensor ],
143
+ batch : Dict [str , torch .Tensor ],
144
+ batch_idx : int ,
145
+ dataloader_idx : int ,
146
+ ) -> None :
147
+ """Log the inputs and outputs of the model to wandb."""
148
+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "train" )
151
149
152
- return wandb .Table (data = batch_data , columns = ["step" , "label" , "value" ])
150
+ def on_validation_batch_end (
151
+ self ,
152
+ trainer : pl .Trainer ,
153
+ pl_module : pl .LightningModule ,
154
+ outputs : Dict [str , torch .Tensor ],
155
+ batch : Dict [str , torch .Tensor ],
156
+ batch_idx : int ,
157
+ dataloader_idx : int ,
158
+ ) -> None :
159
+ """Log the inputs and outputs of the model to wandb."""
160
+ self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
161
+
162
+
163
+ class WandbClassBarCallback (WandbIoUCallback ):
164
+ def __init__ (
165
+ self ,
166
+ type_classes : Dict [str , int ],
167
+ sem_classes : Optional [Dict [str , int ]],
168
+ freq : int = 100 ,
169
+ ) -> None :
170
+ """Create a wandb callback that logs per-class mIoU at batch ends."""
171
+ super ().__init__ (type_classes , sem_classes , freq )
153
172
154
173
def get_bar (self , iou : np .ndarray , classes : Dict [int , str ], title : str ) -> Any :
155
174
"""Return a wandb bar plot object of the current per class iou values."""
156
175
batch_data = [[lab , val ] for lab , val in zip (list (classes .values ()), iou )]
157
176
table = wandb .Table (data = batch_data , columns = ["label" , "value" ])
158
177
return wandb .plot .bar (table , "label" , "value" , title = title )
159
178
160
- def get_series (
161
- self , ious : np .ndarray , x : np .ndarray , classes : Dict [int , str ], title : str
162
- ) -> Any :
163
- """Return a wandb series plot obj of the per class iou values over timesteps."""
164
- return wandb .plot .line_series (
165
- xs = x .tolist (),
166
- ys = [ious [:, c ].tolist () for c in classes .keys ()],
167
- keys = list (classes .values ()),
168
- title = title ,
169
- xname = "step" ,
170
- )
171
-
172
179
def batch_end (
173
180
self ,
174
181
trainer : pl .Trainer ,
@@ -182,69 +189,54 @@ def batch_end(
182
189
log_dict = {}
183
190
if "type" in list (batch .keys ()):
184
191
iou = self .compute ("type" , outputs , batch )
185
- self .cell_ious = np .append (self .cell_ious , iou )
186
- cell_ious = self .cell_ious .reshape (- 1 , len (self .type_classes ))
187
- x = np .arange (cell_ious .shape [0 ])
188
-
189
- if self .return_table :
190
- log_dict [f"{ phase } /type_ious_table" ] = self .get_table (
191
- cell_ious , x , self .type_classes
192
- )
193
-
194
- if self .return_series :
195
- log_dict [f"{ phase } /type_ious_per_class" ] = self .get_series (
196
- cell_ious , x , self .type_classes , title = "Per type class mIoU"
197
- )
198
-
199
- if self .return_bar :
200
- log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
201
- list (iou ), self .type_classes , title = "Cell class mIoUs"
202
- )
192
+ log_dict [f"{ phase } /type_ious_bar" ] = self .get_bar (
193
+ list (iou ), self .type_classes , title = "Cell class mIoUs"
194
+ )
203
195
204
196
if "sem" in list (batch .keys ()):
205
197
iou = self .compute ("sem" , outputs , batch )
206
-
207
- self .sem_ious = np .append (self .sem_ious , iou )
208
- sem_ious = self .sem_ious .reshape (- 1 , len (self .sem_classes ))
209
- x = np .arange (sem_ious .shape [0 ])
210
-
211
- if self .return_table :
212
- log_dict [f"{ phase } /sem_ious_table" ] = self .get_table (
213
- cell_ious , x , self .type_classes
214
- )
215
-
216
- if self .return_series :
217
- log_dict [f"{ phase } /sem_ious_per_class" ] = self .get_series (
218
- cell_ious , x , self .type_classes , title = "Per sem class mIoU"
219
- )
220
-
221
- if self .return_bar :
222
- log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
223
- list (iou ), self .type_classes , title = "Sem class mIoUs"
224
- )
198
+ log_dict [f"{ phase } /sem_ious_bar" ] = self .get_bar (
199
+ list (iou ), self .sem_classes , title = "Sem class mIoUs"
200
+ )
225
201
226
202
trainer .logger .experiment .log (log_dict )
227
203
228
- def on_train_batch_end (
204
+
205
+ class WandbClassLineCallback (WandbIoUCallback ):
206
+ def __init__ (
229
207
self ,
230
- trainer : pl .Trainer ,
231
- pl_module : pl .LightningModule ,
232
- outputs : Dict [str , torch .Tensor ],
233
- batch : Dict [str , torch .Tensor ],
234
- batch_idx : int ,
235
- dataloader_idx : int ,
208
+ type_classes : Dict [str , int ],
209
+ sem_classes : Optional [Dict [str , int ]],
210
+ freq : int = 100 ,
236
211
) -> None :
237
- """Log the inputs and outputs of the model to wandb ."""
238
- self . batch_end ( trainer , outputs [ "soft_masks" ], batch , batch_idx , phase = "train" )
212
+ """Create a wandb callback that logs per-class mIoU at batch ends ."""
213
+ super (). __init__ ( type_classes , sem_classes , freq )
239
214
240
- def on_validation_batch_end (
215
+ def get_points (self , iou : np .ndarray , classes : Dict [int , str ]) -> Any :
216
+ """Return a wandb bar plot object of the current per class iou values."""
217
+ return {lab : val for lab , val in zip (list (classes .values ()), iou )}
218
+
219
+ def batch_end (
241
220
self ,
242
221
trainer : pl .Trainer ,
243
- pl_module : pl .LightningModule ,
244
222
outputs : Dict [str , torch .Tensor ],
245
223
batch : Dict [str , torch .Tensor ],
246
224
batch_idx : int ,
247
- dataloader_idx : int ,
225
+ phase : str ,
248
226
) -> None :
249
- """Log the inputs and outputs of the model to wandb."""
250
- self .batch_end (trainer , outputs ["soft_masks" ], batch , batch_idx , phase = "val" )
227
+ """Log metrics at every 100th step to wandb."""
228
+ if batch_idx % self .freq == 0 :
229
+ log_dict = {}
230
+ if "type" in list (batch .keys ()):
231
+ iou = self .compute ("type" , outputs , batch )
232
+ log_dict [f"{ phase } /type_ious_points" ] = self .get_points (
233
+ list (iou ), self .type_classes
234
+ )
235
+
236
+ if "sem" in list (batch .keys ()):
237
+ iou = self .compute ("sem" , outputs , batch )
238
+ log_dict [f"{ phase } /sem_ious_points" ] = self .get_points (
239
+ list (iou ), self .sem_classes
240
+ )
241
+
242
+ trainer .logger .experiment .log (log_dict )
0 commit comments