import elf.segmentation as eseg
import nifty
from elf.tracking.tracking_utils import compute_edges_from_overlap
- from scipy.ndimage import distance_transform_edt
+ from scipy.ndimage import distance_transform_edt, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import watershed
+ from skimage.morphology import remove_small_holes

- from synaptic_reconstruction.inference.util import apply_size_filter, get_prediction, _Scaler
-
-
- def _multicut(ws, prediction, beta, n_threads):
-     rag = eseg.features.compute_rag(ws, n_threads=n_threads)
-     edge_features = eseg.features.compute_boundary_mean_and_length(rag, prediction, n_threads=n_threads)
-     edge_probs, edge_sizes = edge_features[:, 0], edge_features[:, 1]
-     edge_costs = eseg.multicut.compute_edge_costs(edge_probs, edge_sizes=edge_sizes, beta=beta)
-     node_labels = eseg.multicut.multicut_kernighan_lin(rag, edge_costs)
-     seg = eseg.features.project_node_labels_to_pixels(rag, node_labels, n_threads)
-     return seg
+ from synaptic_reconstruction.inference.util import get_prediction, _Scaler


def _segment_compartments_2d(
-     prediction,
-     distances=None,
-     boundary_threshold=0.4,
-     beta=0.6,
-     n_threads=1,
-     run_multicut=True,
-     min_size=500,
+     boundaries,
+     boundary_threshold=0.4,  # Threshold for the boundary distance computation.
+     large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
+     distances=None,  # Pre-computed distances to take into account z-context.
):
+     # Compute the distances if they were not precomputed.
    if distances is None:
-         distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
+         distances = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
+         distances_z = distances
+     else:
+         # If the distances were pre-computed, compute them again in 2d.
+         # This is needed for inserting small seeds from maxima, otherwise we will get spurious maxima.
+         distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
+
+     # Find the large seeds as connected components in the regions where distances > large_seed_distance.
+     seeds = label(distances > large_seed_distance)
+
+     # Remove large seeds that are too small.
+     min_seed_area = 50
+     ids, sizes = np.unique(seeds, return_counts=True)
+     remove_ids = ids[sizes < min_seed_area]
+     seeds[np.isin(seeds, remove_ids)] = 0
+
+     # Compute the small seeds = local maxima of the in-plane distance map.
+     small_seeds = vigra.analysis.localMaxima(distances_z, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
+     small_seeds = label(np.isnan(small_seeds))
+
+     # We only keep small seeds that don't intersect with a large seed.
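+     # The large seed image is passed as the intensity image, so max_intensity == 0 means a small seed has no overlap with any large seed.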
+     props = regionprops(small_seeds, seeds)
+     keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
+     keep_mask = np.isin(small_seeds, keep_seeds)
+
+     # Add up the small seeds we keep with the large seeds.
+     all_seeds = seeds.copy()
+     seed_offset = seeds.max()
+     all_seeds[keep_mask] = (small_seeds[keep_mask] + seed_offset)
+
+     # Run watershed to get the segmentation.
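+     # The heightmap adds the inverted, normalized distances to the boundary predictions, so basins grow from the seeds and stop at boundaries.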
+     hmap = boundaries + (distances.max() - distances) / distances.max()
+     raw_segmentation = watershed(hmap, markers=all_seeds)
+
+     # These are the large seed ids that we will keep.
+     keep_ids = list(range(1, seed_offset + 1))
+
+     # Iterate over the ids, only keep large seeds and remove holes in their respective masks.
+     props = regionprops(raw_segmentation)
+     segmentation = np.zeros_like(raw_segmentation)
+     for prop in props:
+         if prop.label not in keep_ids:
+             continue

-     # replace with skimage?
-     maxima = vigra.analysis.localMaxima(distances, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
-     maxima = label(np.isnan(maxima))
+         # Get bounding box and mask.
+         bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
+         mask = raw_segmentation[bb] == prop.label

-     hmap = distances
-     hmap = (hmap.max() - hmap)
-     hmap /= hmap.max()
-     hmap_ws = hmap + prediction
-     ws = watershed(hmap_ws, markers=maxima)
+         # Fill small holes and apply closing.
+         mask = remove_small_holes(mask, area_threshold=500)
+         mask = np.logical_or(binary_closing(mask, iterations=4), mask)
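+         # The OR with the original mask ensures that the closing can only add pixels, never remove them.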
+         segmentation[bb][mask] = prop.label

-     hmap_mc = 0.8 * prediction + 0.2 * hmap
-     seg = _multicut(ws, hmap_mc, beta, n_threads)
-     seg = apply_size_filter(seg, min_size)
-     return seg
+     return segmentation


def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
@@ -63,13 +90,12 @@ def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    graph = nifty.graph.undirectedGraph(n_nodes)
    graph.insertEdges(uv_ids)

-     costs = eseg.multicut.compute_edge_costs(overlaps)
+     costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
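+     # The overlaps measure slice-to-slice similarity (high = likely the same compartment), so they are inverted to obtain boundary probabilities for the cost computation.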
    # set background weights to be maximally repulsive
    bg_edges = (uv_ids == 0).any(axis=1)
    costs[bg_edges] = -8.0

-     node_labels = eseg.multicut.multicut_decomposition(graph, -1 * costs, beta=beta)
-
+     node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
    segmentation = nifty.tools.take(node_labels, seg_2d)

    if min_z_extent is not None and min_z_extent > 0:
@@ -86,25 +112,57 @@ def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    return segmentation


+ def _postprocess_seg_3d(seg):
+     # Structuring element for 2d closing in 3d.
+     structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
+     structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
+     structure_3d[0] = structure_element
+
+     props = regionprops(seg)
+     for prop in props:
+         # Get bounding box and mask. (The bounding box of a 3d region has six entries.)
+         bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
+         mask = seg[bb] == prop.label
+
+         # Fill small holes and apply closing.
+         mask = remove_small_holes(mask, area_threshold=1000)
+         mask = np.logical_or(binary_closing(mask, iterations=4), mask)
+         mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
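+         # The (1, 3, 3) structuring element restricts the larger closing to each XY plane, so separate objects are not merged across z.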
+         seg[bb][mask] = prop.label
+
+     return seg
+
+
def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
-     n_slices_exclude=5,
+     n_slices_exclude=0,
    min_z_extent=10,
):
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    offset = 0
+     # Parallelize?
    for z in range(seg_2d.shape[0]):
        if z < n_slices_exclude or z >= seg_2d.shape[0] - n_slices_exclude:
            continue
-         seg_z = _segment_compartments_2d(prediction[z], distances=distances[z], run_multicut=True, min_size=500)
+         seg_z = _segment_compartments_2d(prediction[z], distances=distances[z])
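+         # The slices of the precomputed 3d distance map are passed so the large seeds reflect z-context; the in-plane distances for the small seeds are recomputed inside _segment_compartments_2d.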
        seg_z[seg_z != 0] += offset
        offset = int(seg_z.max())
        seg_2d[z] = seg_z

    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
+     seg = _postprocess_seg_3d(seg)
+
+     # import napari
+     # v = napari.Viewer()
+     # v.add_image(prediction)
+     # v.add_image(distances)
+     # v.add_labels(seg_2d)
+     # v.add_labels(seg)
+     # napari.run()
+
    return seg


@@ -116,7 +174,7 @@ def segment_compartments(
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
-     mask: Optional[np.ndarray] = None,
+     **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """
    Segment synaptic compartments in an input volume.
@@ -159,6 +217,7 @@ def segment_compartments(
    seg = _segment_compartments_3d(pred)
    if verbose:
        print("Run segmentation in", time.time() - t0, "s")
+
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions: