Commit 046c7bc

feat(utils): add inferer as inp to bencmarker
1 parent e228a5f commit 046c7bc

File tree

4 files changed: +137, -40 lines changed

cellseg_models_pytorch/utils/latency_benchmark.py
Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ def inference_latency(
     ) -> List[Dict[str, Any]]:
         """Compute the inference-pipeline latency.
 
+        NOTE: computes only inference not post-processing latency.
+
         Parameters
         ----------
         reps : int, defalt=1

cellseg_models_pytorch/utils/seg_benchmark.py
Lines changed: 126 additions & 39 deletions
@@ -5,6 +5,8 @@
 from pathos.multiprocessing import ThreadPool as Pool
 from tqdm import tqdm
 
+from cellseg_models_pytorch.inference import BaseInferer
+
 from ..metrics import (
     accuracy_multiclass,
     aggregated_jaccard_index,
@@ -43,24 +45,60 @@
 
 class BenchMarker:
     def __init__(
-        self, pred_dir: str, true_dir: str, classes: Dict[str, int] = None
+        self,
+        true_dir: str,
+        pred_dir: str = None,
+        inferer: BaseInferer = None,
+        type_classes: Dict[str, int] = None,
+        sem_classes: Dict[str, int] = None,
     ) -> None:
         """Run benchmarking, given prediction and ground truth mask folders.
 
+        NOTE: Can also take in an Inferer object.
+
         Parameters
         ----------
-        pred_dir : str
-            Path to the prediction .mat files. The pred files have to have matching
-            names to the gt filenames.
         true_dir : str
             Path to the ground truth .mat files. The gt files have to have matching
             names to the pred filenames.
-        classes : Dict[str, int], optional
-            Class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2}
+        pred_dir : str, optional
+            Path to the prediction .mat files. The pred files have to have matching
+            names to the gt filenames. If None, the inferer object storing the
+            predictions will be used instead.
+        inferer : BaseInferer, optional
+            Infere object storing predictions of a model. If None, the `pred_dir`
+            will be used to load the predictions instead.
+        type_classes : Dict[str, int], optional
+            Cell type class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2}
+        sem_classes : Dict[str, int], optional
+            Tissue type class dict. E.g. {"bg": 0, "epithel": 1, "stroma": 2}
         """
-        self.pred_dir = Path(pred_dir)
+        if pred_dir is None and inferer is None:
+            raise ValueError(
+                "Both `inferer` and `pred_dir` cannot be set to None at the same time."
+            )
+
         self.true_dir = Path(true_dir)
-        self.classes = classes
+        self.type_classes = type_classes
+        self.sem_classes = sem_classes
+
+        if pred_dir is not None:
+            self.pred_dir = Path(pred_dir)
+        else:
+            self.pred_dir = None
+
+        self.inferer = inferer
+
+        if inferer is not None and pred_dir is None:
+            try:
+                self.inferer.out_masks
+                self.inferer.soft_masks
+            except AttributeError:
+                raise AttributeError(
+                    "Did not find `out_masks` or `soft_masks` attributes. "
+                    "To get these, run inference with `inferer.infer()`. "
+                    "Remember to set `save_intermediate` to True for the inferer.`"
+                )
 
     @staticmethod
     def compute_inst_metrics(
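
For context (not part of the commit): a minimal constructor sketch under the new signature. The paths and class dict below are placeholders, and the inferer-based variant assumes a BaseInferer that was configured with save_intermediate=True and on which inferer.infer() has already been run, as the error message above requires.

    from cellseg_models_pytorch.utils.seg_benchmark import BenchMarker

    type_classes = {"bg": 0, "epithelial": 1, "immune": 2}

    # Variant 1: predictions are read from .mat files on disk (placeholder paths).
    bm = BenchMarker(
        true_dir="/data/gt_masks",
        pred_dir="/data/pred_masks",
        type_classes=type_classes,
    )

    # Variant 2: predictions are taken straight from an inferer object.
    # `inferer` is assumed to be a BaseInferer set up with `save_intermediate=True`
    # and already run via `inferer.infer()`, so `out_masks`/`soft_masks` exist.
    bm = BenchMarker(
        true_dir="/data/gt_masks",
        inferer=inferer,
        type_classes=type_classes,
    )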
@@ -100,16 +138,16 @@ def compute_inst_metrics(
             f"An illegal metric was given. Got: {metrics}, allowed: {allowed}"
         )
 
-        # Skip empty GTs
-        if len(np.unique(true)) > 1:
+        # Do not run metrics computation if there are no instances in neither of masks
+        res = {}
+        if len(np.unique(true)) > 1 or len(np.unique(pred)) > 1:
             true = remap_label(true)
             pred = remap_label(pred)
 
             met = {}
             for m in metrics:
                 met[m] = INST_METRIC_LOOKUP[m]
 
-            res = {}
             for k, m in met.items():
                 score = m(true, pred)
 
@@ -121,8 +159,19 @@ def compute_inst_metrics(
 
             res["name"] = name
             res["type"] = type
+        else:
+            res["name"] = name
+            res["type"] = type
 
-        return res
+            for m in metrics:
+                if m == "pq":
+                    res["pq"] = -1.0
+                    res["sq"] = -1.0
+                    res["dq"] = -1.0
+                else:
+                    res[m] = -1.0
+
+        return res
 
     @staticmethod
     def compute_sem_metrics(
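
As a reading aid (not part of the commit): with the new else branch, a sample where neither mask contains instances is no longer skipped; it now returns sentinel values. A hypothetical return value for metrics=("pq",), with placeholder "name" and "type" values:

    # Hypothetical output of compute_inst_metrics for an empty true/pred pair,
    # metrics=("pq",). The "pq" entry is expanded into pq/sq/dq sentinels.
    expected = {
        "name": "sample1",
        "type": "binary",
        "pq": -1.0,
        "sq": -1.0,
        "dq": -1.0,
    }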
@@ -158,6 +207,9 @@ def compute_sem_metrics(
             A dictionary where metric names are mapped to metric values.
             e.g. {"iou": 0.5, "f1score": 0.55, "name": "sample1"}
         """
+        if not isinstance(metrics, tuple) and not isinstance(metrics, list):
+            raise ValueError("`metrics` must be either a list or tuple of values.")
+
         allowed = list(SEM_METRIC_LOOKUP.keys())
         if not all([m in allowed for m in metrics]):
             raise ValueError(
@@ -227,20 +279,6 @@ def run_metrics(
 
         return metrics
 
-    def _read_files(self) -> List[Tuple[np.ndarray, np.ndarray, str]]:
-        """Read in the files from the input folders."""
-        preds = sorted(self.pred_dir.glob("*"))
-        trues = sorted(self.true_dir.glob("*"))
-
-        masks = []
-        for truef, predf in zip(trues, preds):
-            true = FileHandler.read_mat(truef, return_all=True)
-            pred = FileHandler.read_mat(predf, return_all=True)
-            name = truef.name
-            masks.append((true, pred, name))
-
-        return masks
-
     def run_inst_benchmark(
         self, how: str = "binary", metrics: Tuple[str, ...] = ("pq",)
     ) -> None:
@@ -268,17 +306,32 @@ def run_inst_benchmark(
         if how not in allowed:
             raise ValueError(f"Illegal arg `how`. Got: {how}, Allowed: {allowed}")
 
-        masks = self._read_files()
+        trues = sorted(self.true_dir.glob("*"))
+
+        preds = None
+        if self.pred_dir is not None:
+            preds = sorted(self.pred_dir.glob("*"))
+
+        ik = "inst" if self.pred_dir is None else "inst_map"
+        tk = "type" if self.pred_dir is None else "type_map"
 
         res = []
-        if how == "multi" and self.classes is not None:
-            for c, i in list(self.classes.items())[1:]:
+        if how == "multi" and self.type_classes is not None:
+            for c, i in list(self.type_classes.items())[1:]:
                 args = []
-                for true, pred, name in masks:
+                for j, true_fn in enumerate(trues):
+                    name = true_fn.name
+                    true = FileHandler.read_mat(true_fn, return_all=True)
+
+                    if preds is None:
+                        pred = self.inferer.out_masks[name[:-4]]
+                    else:
+                        pred = FileHandler.read_mat(preds[j], return_all=True)
+
                     true_inst = true["inst_map"]
-                    pred_inst = pred["inst_map"]
                     true_type = true["type_map"]
-                    pred_type = pred["type_map"]
+                    pred_inst = pred[ik]
+                    pred_type = pred[tk]
 
                     pred_type = get_type_instances(pred_inst, pred_type, i)
                     true_type = get_type_instances(true_inst, true_type, i)
@@ -287,9 +340,17 @@ def run_inst_benchmark(
                 res.extend([metric for metric in met if metric])
         else:
             args = []
-            for true, pred, name in masks:
+            for i, true_fn in enumerate(trues):
+                name = true_fn.name
+                true = FileHandler.read_mat(true_fn, return_all=True)
+
+                if preds is None:
+                    pred = self.inferer.out_masks[name[:-4]]
+                else:
+                    pred = FileHandler.read_mat(preds[i], return_all=True)
+
                 true = true["inst_map"]
-                pred = pred["inst_map"]
+                pred = pred[ik]
                 args.append((true, pred, name, metrics))
             met = self.run_metrics(args, "inst", "binary instance seg")
             res.extend([metric for metric in met if metric])
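
Continuing the earlier constructor sketch (bm is the placeholder object from above): the benchmark calls themselves are unchanged; only the prediction source differs. When predictions come from an inferer, the per-sample dicts are keyed "inst"/"type", and when they are read from .mat files they are keyed "inst_map"/"type_map", which is what the ik/tk lookup above selects.

    # Binary instance segmentation benchmark (default metric is "pq").
    res_binary = bm.run_inst_benchmark(how="binary", metrics=("pq",))

    # Per-cell-type benchmark; the multi branch runs only if `type_classes` was given.
    res_multi = bm.run_inst_benchmark(how="multi", metrics=("pq",))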
@@ -310,14 +371,40 @@ def run_sem_benchmark(self, metrics: Tuple[str, ...] = ("iou",)) -> Dict[str, An
         Dict[str, Any]:
             Dictionary mapping the metrics to values + metadata.
         """
-        masks = self._read_files()
+        trues = sorted(self.true_dir.glob("*"))
+
+        preds = None
+        if self.pred_dir is not None:
+            preds = sorted(self.pred_dir.glob("*"))
+
+        sk = "sem" if self.pred_dir is None else "sem_map"
 
         args = []
-        for true, pred, name in masks:
+        for i, true_fn in enumerate(trues):
+            name = true_fn.name
+            true = FileHandler.read_mat(true_fn, return_all=True)
+
+            if preds is None:
+                pred = self.inferer.out_masks[name[:-4]]
+            else:
+                pred = FileHandler.read_mat(preds[i], return_all=True)
             true = true["sem_map"]
-            pred = pred["sem_map"]
-            args.append((true, pred, name, len(self.classes), metrics))
+            pred = pred[sk]
+            args.append((true, pred, name, len(self.sem_classes), metrics))
+
         met = self.run_metrics(args, "sem", "semantic seg")
-        res = [metric for metric in met if metric]
+        ires = [metric for metric in met if metric]
+
+        # re-format
+        res = []
+        for r in ires:
+            for k, val in self.sem_classes.items():
+                cc = {
+                    "name": r["name"],
+                    "type": k,
+                }
+                for m in metrics:
+                    cc[m] = r[m][val]
+                res.append(cc)
 
         return res
cellseg_models_pytorch/utils/tests/test_seg_benchmark.py
Lines changed: 6 additions & 1 deletion

@@ -15,7 +15,12 @@
 
 @pytest.mark.parametrize("how", ["binary", "multi", None])
 def test_sem_seg_bm(mask_patch_dir, how):
-    bm = BenchMarker(pred_dir=mask_patch_dir, true_dir=mask_patch_dir, classes=classes)
+    bm = BenchMarker(
+        pred_dir=mask_patch_dir,
+        true_dir=mask_patch_dir,
+        type_classes=classes,
+        sem_classes=classes,
+    )
 
     if how == "binary":
         res = bm.run_inst_benchmark(how, metrics=("dice2",))

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+## Features
+
+- Use the inferer class as input to segmentation benchmarker class
