1
- """Imported most of the stuff from stardist repo. Minor modifications.
1
+ """Copied the polygons to label utilities from stardist repo (with minor modifications) .
2
2
3
3
BSD 3-Clause License
4
4
34
34
from typing import Tuple
35
35
36
36
import numpy as np
37
- import scipy .ndimage as ndi
38
- from skimage import img_as_ubyte
39
37
from skimage .draw import polygon
40
- from skimage .measure import regionprops
38
+ from skimage .morphology import disk , erosion
41
39
42
- from ...utils import bounding_box , remap_label , remove_small_objects
43
- from .drfns import find_maxima , h_minima_reconstruction
40
+ from .nms import get_bboxes , nms_stardist
44
41
45
- __all__ = ["post_proc_stardist" , " post_proc_stardist_orig" , "polygons_to_label" ]
42
+ __all__ = ["post_proc_stardist_orig" , "polygons_to_label" ]
46
43
47
44
48
45
def polygons_to_label_coord (
@@ -191,42 +188,25 @@ def polygons_to_label(
191
188
return polygons_to_label_coord (coord , shape = shape , labels = ind )
192
189
193
190
194
- def _clean_up (inst_map : np .ndarray , size : int = 150 , ** kwargs ) -> np .ndarray :
195
- """Clean up overlapping instances."""
196
- mask = remap_label (inst_map .copy ())
197
- mask_connected = ndi .label (mask )[0 ]
198
-
199
- labels_connected = np .unique (mask_connected )[1 :]
200
- for lab in labels_connected :
201
- inst = np .array (mask_connected == lab , copy = True )
202
- y1 , y2 , x1 , x2 = bounding_box (inst )
203
- y1 = y1 - 2 if y1 - 2 >= 0 else y1
204
- x1 = x1 - 2 if x1 - 2 >= 0 else x1
205
- x2 = x2 + 2 if x2 + 2 <= mask_connected .shape [1 ] - 1 else x2
206
- y2 = y2 + 2 if y2 + 2 <= mask_connected .shape [0 ] - 1 else y2
207
-
208
- box_insts = mask [y1 :y2 , x1 :x2 ]
209
- if len (np .unique (ndi .label (box_insts )[0 ])) <= 2 :
210
- real_labels , counts = np .unique (box_insts , return_counts = True )
211
- real_labels = real_labels [1 :]
212
- counts = counts [1 :]
213
- max_pixels = np .max (counts )
214
- max_label = real_labels [np .argmax (counts )]
215
- for real_lab , count in list (zip (list (real_labels ), list (counts ))):
216
- if count < max_pixels :
217
- if count < size :
218
- mask [mask == real_lab ] = max_label
219
-
220
- return mask
221
-
222
-
223
191
def post_proc_stardist (
224
- dist_map : np .ndarray , stardist_map : np .ndarray , thresh : float = 0.4 , ** kwargs
192
+ dist_map : np .ndarray ,
193
+ stardist_map : np .ndarray ,
194
+ score_thresh : float = 0.5 ,
195
+ iou_thresh : float = 0.5 ,
196
+ trim_bboxes : bool = True ,
197
+ ** kwargs ,
225
198
) -> np .ndarray :
226
- """Run post-processing for stardist.
199
+ """Run post-processing for stardist outputs.
200
+
201
+ NOTE: This is not the original cpp version.
202
+ This is a python re-implementation of the stardidst post-processing
203
+ pipeline that uses non-maximum-suppression. Here, critical parts of the
204
+ nms are accelerated with `numba` and `scipy.spatial.KDtree`.
227
205
228
- NOTE: This is not the original version that uses NMS.
229
- This is rather a workaround that is a little slower.
206
+ NOTE:
207
+ This implementaiton of the stardist post-processing is actually nearly twice
208
+ faster than the original version if `trim_bboxes` is set to True. The resulting
209
+ segmentation is not an exact match but the differences are mostly neglible.
230
210
231
211
Parameters
232
212
----------
@@ -236,37 +216,75 @@ def post_proc_stardist(
236
216
Predicted radial distances. Shape: (n_rays, H, W).
237
217
thresh : float, default=0.4
238
218
Threshold for the regressed distance transform.
219
+ trim_bboxes : bool, default=True
220
+ If True, The non-zero pixels are computed only from the cell contours
221
+ which prunes down the pixel search space drastically.
239
222
240
223
Returns
241
224
-------
242
225
np.ndarray:
243
226
Instance labelled mask. Shape: (H, W).
244
227
"""
245
- stardist_map = stardist_map .transpose (1 , 2 , 0 )
246
- mask = _ind_prob_thresh (dist_map , thresh , b = 2 )
247
-
248
- # invert distmap
249
- inv_dist_map = 255 - img_as_ubyte (dist_map )
250
-
251
- # find markers from minima erosion reconstructed maxima of inv dist map
252
- reconstructed = h_minima_reconstruction (inv_dist_map )
253
- markers = find_maxima (reconstructed , mask = mask )
254
- markers = ndi .label (markers )[0 ]
255
- markers = remove_small_objects (markers , min_size = 5 )
256
- points = np .array (
257
- tuple (np .array (r .centroid ).astype (int ) for r in regionprops (markers ))
258
- )
228
+ if (
229
+ not dist_map .ndim == 2
230
+ and not stardist_map .ndim == 3
231
+ and not dist_map .shape == stardist_map .shape [:2 ]
232
+ ):
233
+ raise ValueError (
234
+ "Illegal input shapes. Make sure that: "
235
+ f"`dist_map` has to have shape: (H, W). Got: { dist_map .shape } "
236
+ f"`stardist_map` has to have shape (H, W, nrays). Got: { stardist_map .shape } "
237
+ )
259
238
260
- if len ( points ) == 0 :
261
- return np .zeros_like ( mask )
239
+ dist = np . asarray ( stardist_map ). transpose ( 1 , 2 , 0 )
240
+ prob = np .asarray ( dist_map )
262
241
263
- dist = stardist_map [ tuple ( points . T )]
264
- scores = dist_map [ tuple ( points . T )]
242
+ # threshold the edt distance transform map
243
+ mask = _ind_prob_thresh ( prob , score_thresh )
265
244
266
- labels = polygons_to_label (
267
- dist , points , prob = scores , shape = mask .shape , scale_dist = (1 , 1 )
245
+ # get only the mask contours to trim down bbox search space
246
+ if trim_bboxes :
247
+ fp = disk (2 )
248
+ mask -= erosion (mask , fp )
249
+
250
+ points = np .stack (np .where (mask ), axis = 1 )
251
+
252
+ # Get only non-zero pixels of the transforms
253
+ dist = dist [mask > 0 ]
254
+ scores = prob [mask > 0 ]
255
+
256
+ # sort descendingly
257
+ ind = np .argsort (scores )[::- 1 ]
258
+ dist = dist [ind ]
259
+ scores = scores [ind ]
260
+ points = points [ind ]
261
+
262
+ # get bounding boxes
263
+ x1 , y1 , x2 , y2 , areas , max_dist = get_bboxes (dist , points )
264
+ boxes = np .stack ([x1 , y1 , x2 , y2 ], axis = 1 )
265
+
266
+ # consider only boxes above score threshold
267
+ score_cond = scores >= score_thresh
268
+ boxes = boxes [score_cond ]
269
+ scores = scores [score_cond ]
270
+ areas = areas [score_cond ]
271
+
272
+ # run nms
273
+ inds = nms_stardist (
274
+ boxes ,
275
+ points ,
276
+ scores ,
277
+ areas ,
278
+ max_dist ,
279
+ score_threshold = score_thresh ,
280
+ iou_threshold = iou_thresh ,
268
281
)
269
- labels = _clean_up (labels , ** kwargs )
282
+
283
+ # get the centroids
284
+ points = points [inds ]
285
+ scores = scores [inds ]
286
+ dist = dist [inds ]
287
+ labels = polygons_to_label (dist , points , prob = scores , shape = dist_map .shape )
270
288
271
289
return labels
272
290
0 commit comments