@@ -5,6 +5,8 @@
 from pathos.multiprocessing import ThreadPool as Pool
 from tqdm import tqdm
 
+from cellseg_models_pytorch.inference import BaseInferer
+
 from ..metrics import (
     accuracy_multiclass,
     aggregated_jaccard_index,
@@ -43,24 +45,60 @@
 
 class BenchMarker:
     def __init__(
-        self, pred_dir: str, true_dir: str, classes: Dict[str, int] = None
+        self,
+        true_dir: str,
+        pred_dir: str = None,
+        inferer: BaseInferer = None,
+        type_classes: Dict[str, int] = None,
+        sem_classes: Dict[str, int] = None,
     ) -> None:
         """Run benchmarking, given prediction and ground truth mask folders.
 
+        NOTE: Can also take in an Inferer object.
+
         Parameters
         ----------
-        pred_dir : str
-            Path to the prediction .mat files. The pred files have to have matching
-            names to the gt filenames.
         true_dir : str
            Path to the ground truth .mat files. The gt files have to have matching
            names to the pred filenames.
-        classes : Dict[str, int], optional
-            Class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2}
+        pred_dir : str, optional
+            Path to the prediction .mat files. The pred files have to have matching
+            names to the gt filenames. If None, the inferer object storing the
+            predictions will be used instead.
+        inferer : BaseInferer, optional
+            Inferer object storing the predictions of a model. If None, `pred_dir`
+            will be used to load the predictions instead.
+        type_classes : Dict[str, int], optional
+            Cell type class dict. E.g. {"bg": 0, "epithelial": 1, "immune": 2}
+        sem_classes : Dict[str, int], optional
+            Tissue type class dict. E.g. {"bg": 0, "epithel": 1, "stroma": 2}
         """
-        self.pred_dir = Path(pred_dir)
+        if pred_dir is None and inferer is None:
+            raise ValueError(
+                "Both `inferer` and `pred_dir` cannot be set to None at the same time."
+            )
+
         self.true_dir = Path(true_dir)
-        self.classes = classes
+        self.type_classes = type_classes
+        self.sem_classes = sem_classes
+
+        if pred_dir is not None:
+            self.pred_dir = Path(pred_dir)
+        else:
+            self.pred_dir = None
+
+        self.inferer = inferer
+
+        if inferer is not None and pred_dir is None:
+            try:
+                self.inferer.out_masks
+                self.inferer.soft_masks
+            except AttributeError:
+                raise AttributeError(
+                    "Did not find `out_masks` or `soft_masks` attributes. "
+                    "To get these, run inference with `inferer.infer()`. "
+                    "Remember to set `save_intermediate` to True for the inferer."
                )
 
     @staticmethod
     def compute_inst_metrics(
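For orientation, here is a minimal usage sketch of the reworked constructor. The import path, file paths, and class dicts below are hypothetical; only the signature comes from the diff above.

```python
# Hypothetical usage of the reworked BenchMarker constructor.
from cellseg_models_pytorch.utils import BenchMarker  # import path assumed

# Option 1: benchmark saved .mat predictions against ground-truth masks.
bm = BenchMarker(
    true_dir="/data/gt_masks",
    pred_dir="/data/pred_masks",
    type_classes={"bg": 0, "epithelial": 1, "immune": 2},
    sem_classes={"bg": 0, "epithel": 1, "stroma": 2},
)

# Option 2: benchmark in-memory predictions held by an inferer. The inferer
# must have been run via `inferer.infer()` with `save_intermediate=True`,
# otherwise the constructor raises AttributeError.
# bm = BenchMarker(true_dir="/data/gt_masks", inferer=inferer)
```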
@@ -100,16 +138,16 @@ def compute_inst_metrics(
                 f"An illegal metric was given. Got: {metrics}, allowed: {allowed}"
             )
 
-        # Skip empty GTs
-        if len(np.unique(true)) > 1:
+        # Compute metrics only if at least one of the masks contains instances
+        res = {}
+        if len(np.unique(true)) > 1 or len(np.unique(pred)) > 1:
             true = remap_label(true)
             pred = remap_label(pred)
 
             met = {}
             for m in metrics:
                 met[m] = INST_METRIC_LOOKUP[m]
 
-            res = {}
             for k, m in met.items():
                 score = m(true, pred)
 
@@ -121,8 +159,19 @@ def compute_inst_metrics(
 
             res["name"] = name
             res["type"] = type
+        else:
+            res["name"] = name
+            res["type"] = type
 
-        return res
+            for m in metrics:
+                if m == "pq":
+                    res["pq"] = -1.0
+                    res["sq"] = -1.0
+                    res["dq"] = -1.0
+                else:
+                    res[m] = -1.0
+
+        return res
 
     @staticmethod
     def compute_sem_metrics(
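With this change, a sample where both masks are empty no longer yields an empty dict; it yields a row carrying -1.0 placeholders for every requested metric (with "pq" expanded to pq/sq/dq). A hedged sketch of how a caller might drop those sentinel rows before averaging; pandas is an assumption, not part of this diff:

```python
# Filtering the -1.0 sentinel rows produced for empty masks (sketch).
import pandas as pd

rows = [
    {"name": "s1", "type": "binary", "pq": 0.71, "sq": 0.80, "dq": 0.89},
    {"name": "s2", "type": "binary", "pq": -1.0, "sq": -1.0, "dq": -1.0},  # both masks empty
]
df = pd.DataFrame(rows)
valid = df[df["pq"] >= 0]  # keep only samples where metrics were computed
print(valid[["pq", "sq", "dq"]].mean())
```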
@@ -158,6 +207,9 @@ def compute_sem_metrics(
             A dictionary where metric names are mapped to metric values.
             e.g. {"iou": 0.5, "f1score": 0.55, "name": "sample1"}
         """
+        if not isinstance(metrics, (tuple, list)):
+            raise ValueError("`metrics` must be either a list or tuple of values.")
+
         allowed = list(SEM_METRIC_LOOKUP.keys())
         if not all([m in allowed for m in metrics]):
             raise ValueError(
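The new type check matters because a bare string such as "iou" would otherwise reach the membership check below and be iterated character by character, raising a misleading "illegal metric" error. A standalone illustration; the `allowed` names are made up:

```python
# Why a bare string must be rejected before the membership check.
allowed = ["iou", "f1score", "accuracy"]  # hypothetical metric names

metrics = "iou"  # common user mistake: a string instead of ("iou",)
print(all(m in allowed for m in metrics))  # False: iterates 'i', 'o', 'u'

metrics = ("iou",)  # correct: a tuple of metric names
print(all(m in allowed for m in metrics))  # True
```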
@@ -227,20 +279,6 @@ def run_metrics(
 
         return metrics
 
-    def _read_files(self) -> List[Tuple[np.ndarray, np.ndarray, str]]:
-        """Read in the files from the input folders."""
-        preds = sorted(self.pred_dir.glob("*"))
-        trues = sorted(self.true_dir.glob("*"))
-
-        masks = []
-        for truef, predf in zip(trues, preds):
-            true = FileHandler.read_mat(truef, return_all=True)
-            pred = FileHandler.read_mat(predf, return_all=True)
-            name = truef.name
-            masks.append((true, pred, name))
-
-        return masks
-
     def run_inst_benchmark(
         self, how: str = "binary", metrics: Tuple[str, ...] = ("pq",)
     ) -> None:
@@ -268,17 +306,32 @@ def run_inst_benchmark(
         if how not in allowed:
             raise ValueError(f"Illegal arg `how`. Got: {how}, Allowed: {allowed}")
 
-        masks = self._read_files()
+        trues = sorted(self.true_dir.glob("*"))
+
+        preds = None
+        if self.pred_dir is not None:
+            preds = sorted(self.pred_dir.glob("*"))
+
+        ik = "inst" if self.pred_dir is None else "inst_map"
+        tk = "type" if self.pred_dir is None else "type_map"
 
         res = []
-        if how == "multi" and self.classes is not None:
-            for c, i in list(self.classes.items())[1:]:
+        if how == "multi" and self.type_classes is not None:
+            for c, i in list(self.type_classes.items())[1:]:
                 args = []
-                for true, pred, name in masks:
+                for j, true_fn in enumerate(trues):
+                    name = true_fn.name
+                    true = FileHandler.read_mat(true_fn, return_all=True)
+
+                    if preds is None:
+                        pred = self.inferer.out_masks[name[:-4]]
+                    else:
+                        pred = FileHandler.read_mat(preds[j], return_all=True)
+
                     true_inst = true["inst_map"]
-                    pred_inst = pred["inst_map"]
                     true_type = true["type_map"]
-                    pred_type = pred["type_map"]
+                    pred_inst = pred[ik]
+                    pred_type = pred[tk]
 
                     pred_type = get_type_instances(pred_inst, pred_type, i)
                     true_type = get_type_instances(true_inst, true_type, i)
@@ -287,9 +340,17 @@
                 res.extend([metric for metric in met if metric])
         else:
             args = []
-            for true, pred, name in masks:
+            for i, true_fn in enumerate(trues):
+                name = true_fn.name
+                true = FileHandler.read_mat(true_fn, return_all=True)
+
+                if preds is None:
+                    pred = self.inferer.out_masks[name[:-4]]
+                else:
+                    pred = FileHandler.read_mat(preds[i], return_all=True)
+
                 true = true["inst_map"]
-                pred = pred["inst_map"]
+                pred = pred[ik]
                 args.append((true, pred, name, metrics))
             met = self.run_metrics(args, "inst", "binary instance seg")
             res.extend([metric for metric in met if metric])
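The `ik`/`tk` keys (and `sk` in the semantic benchmark below) bridge two naming schemes: masks loaded from .mat files carry `inst_map`/`type_map`/`sem_map` keys, while entries in the inferer's `out_masks` use `inst`/`type`/`sem`. A toy illustration with placeholder contents:

```python
# Placeholder dicts showing the two key naming schemes being bridged.
mat_pred = {"inst_map": ..., "type_map": ...}  # FileHandler.read_mat output
inferer_pred = {"inst": ..., "type": ...}      # an inferer.out_masks entry

pred_dir = None  # None means predictions come from the inferer
ik = "inst" if pred_dir is None else "inst_map"
tk = "type" if pred_dir is None else "type_map"

pred = inferer_pred if pred_dir is None else mat_pred
pred_inst, pred_type = pred[ik], pred[tk]
```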
@@ -310,14 +371,40 @@ def run_sem_benchmark(self, metrics: Tuple[str, ...] = ("iou",)) -> Dict[str, An
         Dict[str, Any]:
             Dictionary mapping the metrics to values + metadata.
         """
-        masks = self._read_files()
+        trues = sorted(self.true_dir.glob("*"))
+
+        preds = None
+        if self.pred_dir is not None:
+            preds = sorted(self.pred_dir.glob("*"))
+
+        sk = "sem" if self.pred_dir is None else "sem_map"
 
         args = []
-        for true, pred, name in masks:
+        for i, true_fn in enumerate(trues):
+            name = true_fn.name
+            true = FileHandler.read_mat(true_fn, return_all=True)
+
+            if preds is None:
+                pred = self.inferer.out_masks[name[:-4]]
+            else:
+                pred = FileHandler.read_mat(preds[i], return_all=True)
             true = true["sem_map"]
-            pred = pred["sem_map"]
-            args.append((true, pred, name, len(self.classes), metrics))
+            pred = pred[sk]
+            args.append((true, pred, name, len(self.sem_classes), metrics))
+
         met = self.run_metrics(args, "sem", "semantic seg")
-        res = [metric for metric in met if metric]
+        ires = [metric for metric in met if metric]
+
+        # re-format: one row per (image, tissue class)
+        res = []
+        for r in ires:
+            for k, val in self.sem_classes.items():
+                cc = {
+                    "name": r["name"],
+                    "type": k,
+                }
+                for m in metrics:
+                    cc[m] = r[m][val]
+                res.append(cc)
 
         return res
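The new re-format step flattens the per-image results, where each metric value is a class-indexed vector, into one row per (image, tissue class). A worked example mirroring the loop above; the scores are made up:

```python
# Worked example of the re-format step with fabricated scores.
sem_classes = {"bg": 0, "epithel": 1, "stroma": 2}
metrics = ("iou",)
ires = [{"name": "sample1", "iou": [0.98, 0.71, 0.64]}]  # one score per class

res = []
for r in ires:
    for k, val in sem_classes.items():
        cc = {"name": r["name"], "type": k}
        for m in metrics:
            cc[m] = r[m][val]
        res.append(cc)

print(res)
# [{'name': 'sample1', 'type': 'bg', 'iou': 0.98},
#  {'name': 'sample1', 'type': 'epithel', 'iou': 0.71},
#  {'name': 'sample1', 'type': 'stroma', 'iou': 0.64}]
```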