
Commit 6daf9de

Update compartment segmentation
1 parent 54204cc commit 6daf9de

File tree: 5 files changed, +158 / -38 lines
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+import os
+from glob import glob
+
+
+INPUT_ROOT = ""
+OUTPUT_ROOT = ""
+
+
+def main():
+    pass
+
+
+if __name__ == "__main__":
+    main()
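This first new file is an empty scaffold. As a rough sketch of how such a script is typically completed, with illustrative paths and a hypothetical processing step (none of this is part of the commit):

# Hypothetical sketch: mirror the input layout under OUTPUT_ROOT.
import os
from glob import glob

INPUT_ROOT = "/path/to/tomograms"        # assumption: filled in later
OUTPUT_ROOT = "/path/to/segmentations"   # assumption: filled in later


def main():
    input_paths = sorted(glob(os.path.join(INPUT_ROOT, "**", "*.h5"), recursive=True))
    for path in input_paths:
        out_path = os.path.join(OUTPUT_ROOT, os.path.relpath(path, INPUT_ROOT))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # ... run the segmentation for `path` and write the result to `out_path`.


if __name__ == "__main__":
    main()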
scripts/cooper/full_reconstruction/segment_compartments.py

Whitespace-only changes.

scripts/cooper/full_reconstruction/segment_mitochondria.py

Whitespace-only changes.
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+import os
+from glob import glob
+
+import h5py
+import napari
+
+from synaptic_reconstruction.inference.compartments import _segment_compartments_3d
+
+
+def check_pred(path, pred_path, name):
+    # Load the raw data and the network prediction.
+    with h5py.File(path, "r") as f:
+        raw = f["raw"][:]
+        # seg = f["labels/compartments"][:]
+
+    with h5py.File(pred_path, "r") as f:
+        pred = f["prediction"][:]
+
+    print("Run segmentation ...")
+    seg_new = _segment_compartments_3d(pred)
+    print("done")
+
+    # Inspect the raw data, prediction, and new segmentation in napari.
+    v = napari.Viewer()
+    v.add_image(raw)
+    v.add_image(pred, visible=False)
+    # v.add_labels(seg, visible=False)
+    v.add_labels(seg_new)
+    v.title = name
+    napari.run()
+
+
+def main():
+    seg_paths = sorted(glob("./predictions/segmentation/**/*.h5", recursive=True))
+
+    for seg_path in seg_paths:
+        ds_name, fname = os.path.split(seg_path)
+        ds_name = os.path.split(ds_name)[1]
+
+        # if ds_name in ("20241019_Tomo-eval_MF_Synapse", "20241019_Tomo-eval_PS_Synapse"):
+        #     continue
+
+        name = f"{ds_name}/{fname}"
+        pred_path = os.path.join("./predictions/prediction", ds_name, fname)
+        assert os.path.exists(pred_path), pred_path
+        check_pred(seg_path, pred_path, name)
+
+
+if __name__ == "__main__":
+    main()
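The script assumes that segmentation and prediction files mirror each other under ./predictions, roughly in this layout (the dataset and file names are illustrative):

predictions/
    prediction/
        20241019_Tomo-eval_MF_Synapse/
            tomo1.h5    # contains the "prediction" dataset
    segmentation/
        20241019_Tomo-eval_MF_Synapse/
            tomo1.h5    # contains the "raw" dataset (and labels)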

synaptic_reconstruction/inference/compartments.py

Lines changed: 97 additions & 38 deletions
@@ -8,49 +8,76 @@
 import elf.segmentation as eseg
 import nifty
 from elf.tracking.tracking_utils import compute_edges_from_overlap
-from scipy.ndimage import distance_transform_edt
+from scipy.ndimage import distance_transform_edt, binary_closing
 from skimage.measure import label, regionprops
 from skimage.segmentation import watershed
+from skimage.morphology import remove_small_holes
 
-from synaptic_reconstruction.inference.util import apply_size_filter, get_prediction, _Scaler
-
-
-def _multicut(ws, prediction, beta, n_threads):
-    rag = eseg.features.compute_rag(ws, n_threads=n_threads)
-    edge_features = eseg.features.compute_boundary_mean_and_length(rag, prediction, n_threads=n_threads)
-    edge_probs, edge_sizes = edge_features[:, 0], edge_features[:, 1]
-    edge_costs = eseg.multicut.compute_edge_costs(edge_probs, edge_sizes=edge_sizes, beta=beta)
-    node_labels = eseg.multicut.multicut_kernighan_lin(rag, edge_costs)
-    seg = eseg.features.project_node_labels_to_pixels(rag, node_labels, n_threads)
-    return seg
+from synaptic_reconstruction.inference.util import get_prediction, _Scaler
 
 
 def _segment_compartments_2d(
-    prediction,
-    distances=None,
-    boundary_threshold=0.4,
-    beta=0.6,
-    n_threads=1,
-    run_multicut=True,
-    min_size=500,
+    boundaries,
+    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
+    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
+    distances=None,  # Pre-computed distances to take into account z-context.
 ):
+    # Compute the distances if they were not precomputed.
     if distances is None:
-        distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
+        distances = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
+        distances_z = distances
+    else:
+        # If the distances were pre-computed then compute them again in 2d.
+        # This is needed for inserting small seeds from maxima, otherwise we will get spurious maxima.
+        distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
+
+    # Find the large seeds as connected components in the distances > large_seed_distance.
+    seeds = label(distances > large_seed_distance)
+
+    # Remove large seeds that are too small.
+    min_seed_area = 50
+    ids, sizes = np.unique(seeds, return_counts=True)
+    remove_ids = ids[sizes < min_seed_area]
+    seeds[np.isin(seeds, remove_ids)] = 0
+
+    # Compute the small seeds = local maxima of the in-plane distance map.
+    small_seeds = vigra.analysis.localMaxima(distances_z, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
+    small_seeds = label(np.isnan(small_seeds))
+
+    # We only keep small seeds that don't intersect with a large seed.
+    props = regionprops(small_seeds, seeds)
+    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
+    keep_mask = np.isin(small_seeds, keep_seeds)
+
+    # Add the small seeds we keep to the large seeds.
+    all_seeds = seeds.copy()
+    seed_offset = seeds.max()
+    all_seeds[keep_mask] = (small_seeds[keep_mask] + seed_offset)
+
+    # Run watershed to get the segmentation.
+    hmap = boundaries + (distances.max() - distances) / distances.max()
+    raw_segmentation = watershed(hmap, markers=all_seeds)
+
+    # These are the large seed ids that we will keep.
+    keep_ids = list(range(1, seed_offset + 1))
+
+    # Iterate over the ids, only keep large seeds and remove holes in their respective masks.
+    props = regionprops(raw_segmentation)
+    segmentation = np.zeros_like(raw_segmentation)
+    for prop in props:
+        if prop.label not in keep_ids:
+            continue
 
-    # replace with skimage?
-    maxima = vigra.analysis.localMaxima(distances, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
-    maxima = label(np.isnan(maxima))
+        # Get the bounding box and mask (2d, so the bbox has four entries).
+        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
+        mask = raw_segmentation[bb] == prop.label
 
-    hmap = distances
-    hmap = (hmap.max() - hmap)
-    hmap /= hmap.max()
-    hmap_ws = hmap + prediction
-    ws = watershed(hmap_ws, markers=maxima)
+        # Fill small holes and apply closing.
+        mask = remove_small_holes(mask, area_threshold=500)
+        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
+        segmentation[bb][mask] = prop.label
 
-    hmap_mc = 0.8 * prediction + 0.2 * hmap
-    seg = _multicut(ws, hmap_mc, beta, n_threads)
-    seg = apply_size_filter(seg, min_size)
-    return seg
+    return segmentation
 
 
 def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
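The rewritten _segment_compartments_2d replaces the previous watershed + multicut pipeline with a purely seeded watershed: large seeds from components that lie far from any boundary, small seeds from in-plane distance maxima, and per-component hole filling. A minimal synthetic sketch of the large-seed + watershed core (the thresholds are illustrative; the real function also adds and filters the small maxima seeds):

import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.measure import label
from skimage.segmentation import watershed

# Synthetic boundary map: two compartments separated by a vertical wall.
boundaries = np.zeros((64, 64), dtype="float32")
boundaries[:, 31:33] = 1.0

# Distance of each pixel to the nearest boundary pixel.
distances = distance_transform_edt(boundaries < 0.4).astype("float32")

# Large seeds: connected components that are far away from any boundary.
seeds = label(distances > 10)

# Height map as in the new code: boundaries plus inverted, normalized distances.
hmap = boundaries + (distances.max() - distances) / distances.max()
seg = watershed(hmap, markers=seeds)

print(np.unique(seg))  # -> [1 2], one label per compartment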
@@ -63,13 +90,12 @@ def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
     graph = nifty.graph.undirectedGraph(n_nodes)
     graph.insertEdges(uv_ids)
 
-    costs = eseg.multicut.compute_edge_costs(overlaps)
+    costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
     # set background weights to be maximally repulsive
     bg_edges = (uv_ids == 0).any(axis=1)
     costs[bg_edges] = -8.0
 
-    node_labels = eseg.multicut.multicut_decomposition(graph, -1 * costs, beta=beta)
-
+    node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
     segmentation = nifty.tools.take(node_labels, seg_2d)
 
     if min_z_extent is not None and min_z_extent > 0:
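Why 1.0 - overlaps: multicut cost functions expect the probability that an edge should be *cut*, while the overlaps computed between adjacent slices measure similarity. Inverting them (and dropping the extra -1 * on the costs) makes high overlap map to an attractive cost. A minimal numpy sketch of the standard log-odds transform, to illustrate the convention (this is not elf's exact implementation, which e.g. also supports size weighting):

import numpy as np


def edge_costs_from_probs(probs, beta=0.5, eps=1e-6):
    # Log-odds transform: p < 0.5 gives a positive (attractive) cost,
    # p > 0.5 a negative (repulsive) one.
    p = np.clip(probs, eps, 1.0 - eps)
    return np.log((1.0 - p) / p) + np.log((1.0 - beta) / beta)


overlaps = np.array([0.9, 0.5, 0.1])  # high overlap = probably the same object
print(edge_costs_from_probs(1.0 - overlaps))  # ~[2.2, 0.0, -2.2]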
@@ -86,25 +112,57 @@ def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
     return segmentation
 
 
+def _postprocess_seg_3d(seg):
+    # Structuring element for 2d dilation in 3d.
+    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
+    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
+    structure_3d[0] = structure_element
+
+    props = regionprops(seg)
+    for prop in props:
+        # Get the bounding box and mask (3d, so the bbox has six entries).
+        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
+        mask = seg[bb] == prop.label
+
+        # Fill small holes and apply closing.
+        mask = remove_small_holes(mask, area_threshold=1000)
+        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
+        mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
+        seg[bb][mask] = prop.label
+
+    return seg
+
+
 def _segment_compartments_3d(
     prediction,
     boundary_threshold=0.4,
-    n_slices_exclude=5,
+    n_slices_exclude=0,
     min_z_extent=10,
 ):
     distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
     seg_2d = np.zeros(prediction.shape, dtype="uint32")
 
     offset = 0
+    # Parallelize?
     for z in range(seg_2d.shape[0]):
         if z < n_slices_exclude or z >= seg_2d.shape[0] - n_slices_exclude:
             continue
-        seg_z = _segment_compartments_2d(prediction[z], distances=distances[z], run_multicut=True, min_size=500)
+        seg_z = _segment_compartments_2d(prediction[z], distances=distances[z])
         seg_z[seg_z != 0] += offset
         offset = int(seg_z.max())
         seg_2d[z] = seg_z
 
     seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
+    seg = _postprocess_seg_3d(seg)
+
+    # import napari
+    # v = napari.Viewer()
+    # v.add_image(prediction)
+    # v.add_image(distances)
+    # v.add_labels(seg_2d)
+    # v.add_labels(seg)
+    # napari.run()
+
     return seg
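_postprocess_seg_3d closes holes slice-wise with a (1, 3, 3) structuring element, so the morphology never leaks across z; the union with the original mask presumably guards against the closing shrinking the mask at the bounding-box border. A small self-contained sketch of the anisotropic closing:

import numpy as np
from scipy.ndimage import binary_closing

# A mask with a one-pixel hole in a single z-slice.
mask = np.zeros((3, 7, 7), dtype=bool)
mask[1] = True
mask[1, 3, 3] = False

# Structuring element that only connects pixels within the same xy-plane.
structure_3d = np.zeros((1, 3, 3), dtype=bool)
structure_3d[0] = True

closed = binary_closing(mask, structure=structure_3d, iterations=1)
print(closed[1, 3, 3])  # True: the in-plane hole is filled
print(closed[0].any())  # False: neighboring slices stay empty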
@@ -116,7 +174,7 @@ def segment_compartments(
     verbose: bool = True,
     return_predictions: bool = False,
     scale: Optional[List[float]] = None,
-    mask: Optional[np.ndarray] = None,
+    **kwargs,
 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
     """
     Segment synaptic compartments in an input volume.
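Replacing the explicit mask parameter with **kwargs presumably keeps the signature compatible with callers of the shared inference interface: arguments such as mask, which this function no longer uses, are accepted and silently ignored instead of raising a TypeError.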
@@ -159,6 +217,7 @@ def segment_compartments(
         seg = _segment_compartments_3d(pred)
     if verbose:
         print("Run segmentation in", time.time() - t0, "s")
+
     seg = scaler.rescale_output(seg, is_segmentation=True)
 
     if return_predictions:
