Skip to content

Commit f27279f

Browse files
committed
feat: add probs to gson
1 parent 59d323f commit f27279f

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

cellseg_models_pytorch/utils/file_manager.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def get_gson(
168168
inst: np.ndarray,
169169
type: np.ndarray,
170170
classes: Dict[str, int],
171+
soft_type: np.ndarray = None,
171172
x_offset: int = 0,
172173
y_offset: int = 0,
173174
geo_format: str = "qupath",
@@ -182,6 +183,8 @@ def get_gson(
182183
Cell type labelled semantic segmentation mask. Shape: (H, W).
183184
classes : Dict[str, int]
184185
Class dict e.g. {"inflam":1, "epithelial":2, "connec":3}
186+
soft_type : np.ndarray, default=None
187+
Softmax type mask. Shape: (C, H, W). C is the number of classes.
185188
x_offset : int, default=0
186189
x-coordinate offset. (to set geojson to .mrxs wsi coordinates)
187190
y_offset : int, default=0
@@ -211,6 +214,14 @@ def get_gson(
211214

212215
inst_type = [key for key in classes.keys() if classes[key] == inst_type][0]
213216

217+
# type probabilities
218+
if soft_type is not None:
219+
type_probs = soft_type[..., inst_map == inst_id].mean(axis=1)
220+
inst_type_soft = dict(zip(classes.keys(), type_probs))
221+
# convert to float for json serialization
222+
for key in inst_type_soft.keys():
223+
inst_type_soft[key] = float(inst_type_soft[key])
224+
214225
# get the cell contour coordinates
215226
contours, _ = cv2.findContours(inst, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
216227

@@ -230,6 +241,11 @@ def get_gson(
230241
poly.append(poly[0]) # close the polygon
231242
geo_obj["geometry"]["coordinates"] = [poly]
232243
geo_obj["properties"]["classification"]["name"] = inst_type
244+
if soft_type is not None:
245+
geo_obj["properties"]["classification"][
246+
"probabilities"
247+
] = inst_type_soft
248+
233249
geo_objs.append(geo_obj)
234250

235251
return geo_objs
@@ -364,6 +380,7 @@ def write_mat(
364380
sem: np.ndarray = None,
365381
compute_centorids: bool = False,
366382
compute_bboxes: bool = False,
383+
**kwargs,
367384
) -> None:
368385
"""
369386
Write multiple masks to .mat file.
@@ -429,6 +446,7 @@ def write_gson(
429446
inst: np.ndarray,
430447
type: np.ndarray = None,
431448
classes: Dict[str, int] = None,
449+
soft_type: np.ndarray = None,
432450
x_offset: int = 0,
433451
y_offset: int = 0,
434452
geo_format: str = "qupath",
@@ -444,6 +462,8 @@ def write_gson(
444462
type : np.ndarray, optional
445463
Cell type labelled semantic segmentation mask. Shape: (H, W). If None,
446464
the classes of the objects will be set to {background: 0, foreground: 1}
465+
soft_type : np.ndarray, default=None
466+
Softmax type mask. Shape: (C, H, W). C is the number of classes.
447467
classes : Dict[str, int], optional
448468
Class dict e.g. {"inflam":1, "epithelial":2, "connec":3}. Ignored if
449469
`type` is None.
@@ -489,7 +509,7 @@ def write_gson(
489509
)
490510

491511
geo_objs = FileHandler.get_gson(
492-
inst, type, classes, x_offset, y_offset, geo_format
512+
inst, type, classes, soft_type, x_offset, y_offset, geo_format
493513
)
494514

495515
fname = fname.with_suffix(".json")
@@ -564,6 +584,7 @@ def save_masks(
564584
inst=maps["inst"],
565585
type=type_map,
566586
classes=classes_type,
587+
soft_type=maps["soft_type"] if "soft_type" in maps.keys() else None,
567588
geo_format=json_format,
568589
x_offset=offs["x"],
569590
y_offset=offs["y"],
@@ -587,6 +608,7 @@ def save_masks(
587608
inst=label_semantic(maps["sem"]),
588609
type=maps["sem"],
589610
classes=classes_sem,
611+
soft_type=maps["soft_sem"] if "soft_sem" in maps.keys() else None,
590612
geo_format=json_format,
591613
x_offset=offs["x"],
592614
y_offset=offs["y"],

0 commit comments

Comments
 (0)