@@ -168,6 +168,7 @@ def get_gson(
168
168
inst : np .ndarray ,
169
169
type : np .ndarray ,
170
170
classes : Dict [str , int ],
171
+ soft_type : np .ndarray = None ,
171
172
x_offset : int = 0 ,
172
173
y_offset : int = 0 ,
173
174
geo_format : str = "qupath" ,
@@ -182,6 +183,8 @@ def get_gson(
182
183
Cell type labelled semantic segmentation mask. Shape: (H, W).
183
184
classes : Dict[str, int]
184
185
Class dict e.g. {"inflam":1, "epithelial":2, "connec":3}
186
+ soft_type : np.ndarray, default=None
187
+ Softmax type mask. Shape: (C, H, W). C is the number of classes.
185
188
x_offset : int, default=0
186
189
x-coordinate offset. (to set geojson to .mrxs wsi coordinates)
187
190
y_offset : int, default=0
@@ -211,6 +214,14 @@ def get_gson(
211
214
212
215
inst_type = [key for key in classes .keys () if classes [key ] == inst_type ][0 ]
213
216
217
+ # type probabilities
218
+ if soft_type is not None :
219
+ type_probs = soft_type [..., inst_map == inst_id ].mean (axis = 1 )
220
+ inst_type_soft = dict (zip (classes .keys (), type_probs ))
221
+ # convert to float for json serialization
222
+ for key in inst_type_soft .keys ():
223
+ inst_type_soft [key ] = float (inst_type_soft [key ])
224
+
214
225
# get the cell contour coordinates
215
226
contours , _ = cv2 .findContours (inst , cv2 .RETR_TREE , cv2 .CHAIN_APPROX_SIMPLE )
216
227
@@ -230,6 +241,11 @@ def get_gson(
230
241
poly .append (poly [0 ]) # close the polygon
231
242
geo_obj ["geometry" ]["coordinates" ] = [poly ]
232
243
geo_obj ["properties" ]["classification" ]["name" ] = inst_type
244
+ if soft_type is not None :
245
+ geo_obj ["properties" ]["classification" ][
246
+ "probabilities"
247
+ ] = inst_type_soft
248
+
233
249
geo_objs .append (geo_obj )
234
250
235
251
return geo_objs
@@ -364,6 +380,7 @@ def write_mat(
364
380
sem : np .ndarray = None ,
365
381
compute_centorids : bool = False ,
366
382
compute_bboxes : bool = False ,
383
+ ** kwargs ,
367
384
) -> None :
368
385
"""
369
386
Write multiple masks to .mat file.
@@ -429,6 +446,7 @@ def write_gson(
429
446
inst : np .ndarray ,
430
447
type : np .ndarray = None ,
431
448
classes : Dict [str , int ] = None ,
449
+ soft_type : np .ndarray = None ,
432
450
x_offset : int = 0 ,
433
451
y_offset : int = 0 ,
434
452
geo_format : str = "qupath" ,
@@ -444,6 +462,8 @@ def write_gson(
444
462
type : np.ndarray, optional
445
463
Cell type labelled semantic segmentation mask. Shape: (H, W). If None,
446
464
the classes of the objects will be set to {background: 0, foreground: 1}
465
+ soft_type : np.ndarray, default=None
466
+ Softmax type mask. Shape: (C, H, W). C is the number of classes.
447
467
classes : Dict[str, int], optional
448
468
Class dict e.g. {"inflam":1, "epithelial":2, "connec":3}. Ignored if
449
469
`type` is None.
@@ -489,7 +509,7 @@ def write_gson(
489
509
)
490
510
491
511
geo_objs = FileHandler .get_gson (
492
- inst , type , classes , x_offset , y_offset , geo_format
512
+ inst , type , classes , soft_type , x_offset , y_offset , geo_format
493
513
)
494
514
495
515
fname = fname .with_suffix (".json" )
@@ -564,6 +584,7 @@ def save_masks(
564
584
inst = maps ["inst" ],
565
585
type = type_map ,
566
586
classes = classes_type ,
587
+ soft_type = maps ["soft_type" ] if "soft_type" in maps .keys () else None ,
567
588
geo_format = json_format ,
568
589
x_offset = offs ["x" ],
569
590
y_offset = offs ["y" ],
@@ -587,6 +608,7 @@ def save_masks(
587
608
inst = label_semantic (maps ["sem" ]),
588
609
type = maps ["sem" ],
589
610
classes = classes_sem ,
611
+ soft_type = maps ["soft_sem" ] if "soft_sem" in maps .keys () else None ,
590
612
geo_format = json_format ,
591
613
x_offset = offs ["x" ],
592
614
y_offset = offs ["y" ],
0 commit comments