Skip to content

Commit 565bb3d

Browse files
committed
perf(postproc): rewrite stardist postproc
1 parent 35886d9 commit 565bb3d

File tree

5 files changed

+291
-62
lines changed

5 files changed

+291
-62
lines changed

cellseg_models_pytorch/postproc/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from .functional.drfns import post_proc_drfns
1414
from .functional.hovernet import post_proc_hovernet
1515
from .functional.omnipose import get_masks_omnipose, post_proc_omnipose
16-
from .functional.stardist import post_proc_stardist, post_proc_stardist_orig
16+
from .functional.stardist.nms import get_bboxes
17+
from .functional.stardist.stardist import post_proc_stardist, post_proc_stardist_orig
1718

1819
POSTPROC_LOOKUP = {
1920
"stardist_orig": post_proc_stardist_orig,
@@ -44,4 +45,5 @@
4445
"post_proc_drfns",
4546
"post_proc_dcan",
4647
"post_proc_dran",
48+
"get_bboxes",
4749
]
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import math
2+
from typing import List, Sequence, Tuple
3+
4+
import numpy as np
5+
from numba import njit, prange
6+
from scipy.spatial import KDTree
7+
8+
from ....utils import intersection
9+
10+
__all__ = ["get_bboxes", "nms_stardist"]
11+
12+
13+
@njit(parallel=True)
def get_bboxes(
    dist: np.ndarray, points: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
    """Get bounding boxes from the non-zero pixels of the radial distance maps.

    Every row of `dist` describes a star-convex polygon around the matching row
    of `points`: ray `k` ends at `point + dist[i, k] * (sin, cos)(2*pi*k/n_rays)`.
    The bounding box of a polygon is the min/max extent of its ray end points.

    This is basically a translation from the stardist repo cpp code to python.

    NOTE: jit compiled and parallelized with numba.

    Parameters
    ----------
    dist : np.ndarray
        The non-zero values of the radial distance maps. Shape: (n_nonzero, n_rays).
    points : np.ndarray
        The yx-coordinates of the non-zero points. Shape (n_nonzero, 2).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
        The x_min, y_min, x_max, y_max bbox coordinates and the bbox areas
        (each of shape (n_nonzero,)), followed by the maximum radial distance
        found in `dist` (0.0 if there are no polygons).
    """
    n_polys = dist.shape[0]
    n_rays = dist.shape[1]

    bbox_x1 = np.zeros(n_polys)
    bbox_x2 = np.zeros(n_polys)
    bbox_y1 = np.zeros(n_polys)
    bbox_y2 = np.zeros(n_polys)
    areas = np.zeros(n_polys)

    # The ray directions are identical for every polygon, so evaluate the
    # trigonometric terms once per ray instead of once per (polygon, ray).
    angle_pi = 2 * math.pi / n_rays
    ray_sin = np.empty(n_rays)
    ray_cos = np.empty(n_rays)
    for k in range(n_rays):
        ray_sin[k] = np.sin(angle_pi * k)
        ray_cos[k] = np.cos(angle_pi * k)

    # Float seeds keep the reduction variable type-stable (distances are floats).
    max_dist = 0.0

    for i in prange(n_polys):
        max_radius_outer = 0.0
        py = points[i, 0]
        px = points[i, 1]

        for k in range(n_rays):
            d = dist[i, k]
            y = py + d * ray_sin[k]
            x = px + d * ray_cos[k]

            if k == 0:
                # The first ray seeds the extents of this polygon.
                bbox_x1[i] = x
                bbox_x2[i] = x
                bbox_y1[i] = y
                bbox_y2[i] = y
            else:
                bbox_x1[i] = min(x, bbox_x1[i])
                bbox_x2[i] = max(x, bbox_x2[i])
                bbox_y1[i] = min(y, bbox_y1[i])
                bbox_y2[i] = max(y, bbox_y2[i])

            max_radius_outer = max(d, max_radius_outer)

        areas[i] = (bbox_x2[i] - bbox_x1[i]) * (bbox_y2[i] - bbox_y1[i])
        # `max` over the parallel loop variable is a reduction across iterations.
        max_dist = max(max_dist, max_radius_outer)

    return bbox_x1, bbox_y1, bbox_x2, bbox_y2, areas, max_dist
75+
76+
77+
@njit
def _suppress_bbox(
    query: Sequence[int],
    current_idx: int,
    boxes: np.ndarray,
    areas: np.ndarray,
    suppressed: List[bool],
    iou_threshold: float = 0.5,
) -> np.ndarray:
    """Inner loop of the stardist nms algorithm where bboxes are suppressed.

    For every not-yet-suppressed index in `query`, the overlap between its box
    and the box at `current_idx` is divided by the smaller of the two box areas;
    the entry in `suppressed` is set to whether that ratio exceeds
    `iou_threshold`. `suppressed` is updated in place and also returned.

    NOTE: Numba compiled only for performance.
    Parallelization had only a negative effect on run-time on a
    12-core hyperthreaded Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz.
    """
    # Values belonging to the current (anchor) box do not change in the loop.
    anchor_box = boxes[current_idx]
    anchor_area = areas[current_idx] + 1e-10

    for neighbour in query:
        if not suppressed[neighbour]:
            shared = intersection(anchor_box, boxes[neighbour])
            # The small epsilon guards against division by a zero area.
            ratio = shared / min(anchor_area, areas[neighbour] + 1e-10)
            suppressed[neighbour] = ratio > iou_threshold

    return suppressed
103+
104+
105+
def nms_stardist(
    boxes: np.ndarray,
    points: np.ndarray,
    scores: np.ndarray,
    areas: np.ndarray,
    max_dist: float,
    score_threshold: float = 0.5,
    iou_threshold: float = 0.5,
) -> np.ndarray:
    """Non maximum suppression for stardist bboxes.

    Boxes are visited in the given order. A visited box that is not yet
    suppressed is kept, and every box whose centre point lies within `max_dist`
    of it is tested for suppression against it. The loop stops at the first
    score below `score_threshold`, so the inputs are expected to be sorted by
    descending score.

    NOTE: This implementation relies on `scipy.spatial` `KDTree`

    NOTE: This version of nms is faster than the original one in stardist repo
    and is fully written in python. The differences in the resulting instance
    segmentation masks are negligible.

    Parameters
    ----------
    boxes : np.ndarray
        An array of bbox coords in pascal VOC format (x0, y0, x1, y1).
        Shape: (n_points, 4). Dtype: float64.
    points : np.ndarray
        The yx-coordinates of the non-zero points. Shape (n_points, 2). Dtype: int64
    scores : np.ndarray
        The probability values at the point coordinates. Shape (n_points,).
        Dtype: float32/float64.
    areas : np.ndarray
        The areas of the bounding boxes at the point coordinates. Shape (n_points,).
        Dtype: float32/float64.
    max_dist : float
        The maximum radial distance of all the radial distances. Used as the
        neighbourhood search radius around each point.
    score_threshold : float, default=0.5
        Threshold for the probability distance map.
    iou_threshold : float, default=0.5
        Threshold for the overlap metric deciding whether to suppress a bbox.

    Returns
    -------
    np.ndarray:
        The indices of the bboxes that are not suppressed. Shape: (n_kept, ).
        Dtype: int64 (also when nothing is kept).
    """
    n_boxes = len(boxes)
    if n_boxes == 0:
        return np.zeros(0, dtype=np.int64)

    # Only points that own a box may show up as neighbours: the suppression
    # kernel indexes `boxes`/`suppressed` with the returned neighbour ids and is
    # not bounds-checked. With equally long inputs this slice is a no-op.
    kdtree = KDTree(points[:n_boxes], leafsize=16)

    keep = []
    suppressed = np.full(n_boxes, False)
    for current_idx in range(len(scores)):
        # If already visited or discarded
        if suppressed[current_idx]:
            continue

        # If score is already below threshold then break
        if scores[current_idx] < score_threshold:
            break

        # Query the points within the search radius of the current point
        query = kdtree.query_ball_point(points[current_idx], max_dist)
        suppressed = _suppress_bbox(
            # explicit integer dtype: an empty neighbour list would otherwise
            # be converted to a float64 array
            np.asarray(query, dtype=np.int64),
            current_idx,
            boxes,
            areas,
            suppressed,
            iou_threshold,
        )

        # Add the current box
        keep.append(current_idx)

    # explicit dtype so that an empty result is still a valid index array
    return np.array(keep, dtype=np.int64)

cellseg_models_pytorch/postproc/functional/stardist.py renamed to cellseg_models_pytorch/postproc/functional/stardist/stardist.py

Lines changed: 79 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Imported most of the stuff from stardist repo. Minor modifications.
1+
"""Copied the polygons to label utilities from stardist repo (with minor modifications).
22
33
BSD 3-Clause License
44
@@ -34,15 +34,12 @@
3434
from typing import Tuple
3535

3636
import numpy as np
37-
import scipy.ndimage as ndi
38-
from skimage import img_as_ubyte
3937
from skimage.draw import polygon
40-
from skimage.measure import regionprops
38+
from skimage.morphology import disk, erosion
4139

42-
from ...utils import bounding_box, remap_label, remove_small_objects
43-
from .drfns import find_maxima, h_minima_reconstruction
40+
from .nms import get_bboxes, nms_stardist
4441

45-
__all__ = ["post_proc_stardist", "post_proc_stardist_orig", "polygons_to_label"]
42+
# `post_proc_stardist` is still a public entry point of this module (the package
# `__init__` imports it), so keep it in the declared public API.
__all__ = ["post_proc_stardist", "post_proc_stardist_orig", "polygons_to_label"]
4643

4744

4845
def polygons_to_label_coord(
@@ -191,42 +188,25 @@ def polygons_to_label(
191188
return polygons_to_label_coord(coord, shape=shape, labels=ind)
192189

193190

194-
def _clean_up(inst_map: np.ndarray, size: int = 150, **kwargs) -> np.ndarray:
195-
"""Clean up overlapping instances."""
196-
mask = remap_label(inst_map.copy())
197-
mask_connected = ndi.label(mask)[0]
198-
199-
labels_connected = np.unique(mask_connected)[1:]
200-
for lab in labels_connected:
201-
inst = np.array(mask_connected == lab, copy=True)
202-
y1, y2, x1, x2 = bounding_box(inst)
203-
y1 = y1 - 2 if y1 - 2 >= 0 else y1
204-
x1 = x1 - 2 if x1 - 2 >= 0 else x1
205-
x2 = x2 + 2 if x2 + 2 <= mask_connected.shape[1] - 1 else x2
206-
y2 = y2 + 2 if y2 + 2 <= mask_connected.shape[0] - 1 else y2
207-
208-
box_insts = mask[y1:y2, x1:x2]
209-
if len(np.unique(ndi.label(box_insts)[0])) <= 2:
210-
real_labels, counts = np.unique(box_insts, return_counts=True)
211-
real_labels = real_labels[1:]
212-
counts = counts[1:]
213-
max_pixels = np.max(counts)
214-
max_label = real_labels[np.argmax(counts)]
215-
for real_lab, count in list(zip(list(real_labels), list(counts))):
216-
if count < max_pixels:
217-
if count < size:
218-
mask[mask == real_lab] = max_label
219-
220-
return mask
221-
222-
223191
def post_proc_stardist(
    dist_map: np.ndarray,
    stardist_map: np.ndarray,
    score_thresh: float = 0.5,
    iou_thresh: float = 0.5,
    trim_bboxes: bool = True,
    **kwargs,
) -> np.ndarray:
    """Run post-processing for stardist outputs.

    NOTE: This is not the original cpp version.
    This is a python re-implementation of the stardist post-processing
    pipeline that uses non-maximum-suppression. Here, critical parts of the
    nms are accelerated with `numba` and `scipy.spatial.KDTree`.

    NOTE:
    According to the author's measurements this implementation of the stardist
    post-processing is nearly twice as fast as the original version if
    `trim_bboxes` is set to True. The resulting segmentation is not an exact
    match but the differences are mostly negligible.

    Parameters
    ----------
    dist_map : np.ndarray
        Per-pixel score map. It is thresholded to find the candidate pixels and
        its values are used as the nms scores. Shape: (H, W).
    stardist_map : np.ndarray
        Predicted radial distances. Shape: (n_rays, H, W).
    score_thresh : float, default=0.5
        Threshold applied to `dist_map`; candidates scoring below it are dropped.
    iou_thresh : float, default=0.5
        Overlap threshold above which a bounding box is suppressed in the nms.
    trim_bboxes : bool, default=True
        If True, the candidate pixels are taken only from a thin band along the
        contours of the thresholded mask (the mask minus its erosion with a
        radius-2 disk), which prunes down the pixel search space drastically.
    **kwargs
        Accepted for call-compatibility; not used.

    Returns
    -------
    np.ndarray:
        Instance labelled mask. Shape: (H, W).

    Raises
    ------
    ValueError:
        If `dist_map` is not 2D, `stardist_map` is not 3D or their spatial
        dimensions differ.
    """
    prob = np.asarray(dist_map)
    rays = np.asarray(stardist_map)

    # Any single violation makes the input unusable, hence `or`.
    if prob.ndim != 2 or rays.ndim != 3 or prob.shape != rays.shape[1:]:
        raise ValueError(
            "Illegal input shapes. Make sure that: "
            f"`dist_map` has to have shape: (H, W). Got: {prob.shape} "
            f"`stardist_map` has to have shape (nrays, H, W). Got: {rays.shape}"
        )

    # (n_rays, H, W) -> (H, W, n_rays) so that a pixel mask selects ray vectors
    dist = rays.transpose(1, 2, 0)

    # threshold the edt distance transform map
    mask = _ind_prob_thresh(prob, score_thresh)

    # get only the mask contours to trim down bbox search space
    if trim_bboxes:
        fp = disk(2)
        # "mask minus eroded mask" written with logical ops: numpy does not
        # support `-` on boolean arrays, and this form gives the same contour
        # band for 0/1 integer masks as well.
        mask = np.logical_and(mask, np.logical_not(erosion(mask, fp)))

    # Get only non-zero pixels of the transforms
    fg = np.asarray(mask) > 0
    points = np.stack(np.where(fg), axis=1)
    dist = dist[fg]
    scores = prob[fg]

    # sort descendingly
    ind = np.argsort(scores)[::-1]
    dist = dist[ind]
    scores = scores[ind]
    points = points[ind]

    # get bounding boxes
    x1, y1, x2, y2, areas, max_dist = get_bboxes(dist, points)
    boxes = np.stack([x1, y1, x2, y2], axis=1)

    # consider only boxes above score threshold; every per-candidate array is
    # filtered with the same condition so that the indices returned by the nms
    # refer to the same candidate in all of them
    score_cond = scores >= score_thresh
    boxes = boxes[score_cond]
    scores = scores[score_cond]
    areas = areas[score_cond]
    points = points[score_cond]
    dist = dist[score_cond]

    # run nms
    inds = nms_stardist(
        boxes,
        points,
        scores,
        areas,
        max_dist,
        score_threshold=score_thresh,
        iou_threshold=iou_thresh,
    )

    # keep only the surviving candidates and rasterize their polygons
    points = points[inds]
    scores = scores[inds]
    dist = dist[inds]
    labels = polygons_to_label(dist, points, prob=scores, shape=prob.shape)

    return labels
272290

cellseg_models_pytorch/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_inst_centroid,
1818
get_inst_types,
1919
get_type_instances,
20+
intersection,
2021
label_semantic,
2122
majority_vote_parallel,
2223
majority_vote_sequential,
@@ -132,4 +133,5 @@
132133
"majority_vote_parallel",
133134
"med_filt_parallel",
134135
"med_filt_sequential",
136+
"intersection",
135137
]

0 commit comments

Comments
 (0)