Skip to content

Commit d5faa99

Browse files
Updates to napari plugins and distance functionality
1 parent 9581b2b commit d5faa99

File tree

8 files changed

+272
-133
lines changed

8 files changed

+272
-133
lines changed

scripts/inner_ear/check_results.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import pandas
1010

11-
from synaptic_reconstruction.distance_measurements import create_object_distance_lines
11+
from synaptic_reconstruction.distance_measurements import create_object_distance_lines, load_distances
1212
from synaptic_reconstruction.file_utils import get_data_path
1313
from synaptic_reconstruction.tools.distance_measurement import _downsample
1414

@@ -21,11 +21,12 @@
2121
def get_distance_visualization(
2222
tomo, segmentations, distance_paths, vesicle_ids, scale, return_mem_props=False
2323
):
24-
ribbon_lines, _ = create_object_distance_lines(distance_paths["ribbon"], seg_ids=vesicle_ids, scale=scale)
25-
pd_lines, _ = create_object_distance_lines(distance_paths["PD"], seg_ids=vesicle_ids, scale=scale)
26-
membrane_lines, mem_props = create_object_distance_lines(
27-
distance_paths["membrane"], seg_ids=vesicle_ids, scale=scale
28-
)
24+
d, e1, e2, ids = load_distances(distance_paths["ribbon"])
25+
ribbon_lines, _ = create_object_distance_lines(d, e1, e2, ids, filter_seg_ids=vesicle_ids, scale=scale)
26+
d, e1, e2, ids = load_distances(distance_paths["PD"])
27+
pd_lines, _ = create_object_distance_lines(d, e1, e2, ids, filter_seg_ids=vesicle_ids, scale=scale)
28+
d, e1, e2, ids = load_distances(distance_paths["membrane"])
29+
membrane_lines, mem_props = create_object_distance_lines(d, e1, e2, ids, filter_seg_ids=vesicle_ids, scale=scale)
2930

3031
distance_lines = {
3132
"ribbon_distances": ribbon_lines,

synaptic_reconstruction/distance_measurements.py

Lines changed: 149 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
import multiprocessing as mp
3+
from typing import Dict, Optional, Tuple
24

35
import numpy as np
46

@@ -42,7 +44,7 @@ def compute_geodesic_distances(segmentation, distance_to, resolution=None, unsig
4244

4345

4446
# TODO update this
45-
def compute_centroid_distances(segmentation, resolution, n_neighbors):
47+
def _compute_centroid_distances(segmentation, resolution, n_neighbors):
4648
# TODO enable eccentricity centers instead
4749
props = regionprops(segmentation)
4850
centroids = np.array([prop.centroid for prop in props])
@@ -53,7 +55,7 @@ def compute_centroid_distances(segmentation, resolution, n_neighbors):
5355
return pair_distances
5456

5557

56-
def compute_boundary_distances(segmentation, resolution, n_threads):
58+
def _compute_boundary_distances(segmentation, resolution, n_threads):
5759

5860
seg_ids = np.unique(segmentation)[1:]
5961
n = len(seg_ids)
@@ -105,20 +107,37 @@ def compute_distances_for_object(i):
105107

106108

107109
def measure_pairwise_object_distances(
108-
segmentation,
109-
distance_type="boundary",
110-
resolution=None,
111-
n_threads=None,
112-
save_path=None,
113-
):
110+
segmentation: np.ndarray,
111+
distance_type: str = "boundary",
112+
resolution: Optional[Tuple[int, int, int]] = None,
113+
n_threads: Optional[int] = None,
114+
save_path: Optional[os.PathLike] = None,
115+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
116+
"""Compute the pairwise distances between all objects within a segmentation.
117+
118+
Args:
119+
segmentation: The input segmentation.
120+
distance_type: The type of distance to compute, can either be 'boundary' to
121+
compute the distance between the boundary / surface of the objects or 'centroid'
122+
to compute the distance between centroids.
123+
resolution: The resolution / pixel size of the data.
124+
n_threads: The number of threads for parallelizing the distance computation.
125+
save_path: Path for saving the measurement results in numpy zipped format.
126+
127+
Returns:
128+
The pairwise object distances.
129+
The 'left' endpoint coordinates of the distances.
130+
The 'right' endpoint coordinates of the distances.
131+
The segmentation id pairs of the distances.
132+
"""
114133
supported_distances = ("boundary", "centroid")
115134
assert distance_type in supported_distances
116135
if distance_type == "boundary":
117-
distances, endpoints1, endpoints2, seg_ids = compute_boundary_distances(segmentation, resolution, n_threads)
136+
distances, endpoints1, endpoints2, seg_ids = _compute_boundary_distances(segmentation, resolution, n_threads)
118137
elif distance_type == "centroid":
119138
raise NotImplementedError
120139
# TODO has to be adapted
121-
# distances, neighbors = compute_centroid_distances(segmentation, resolution)
140+
# distances, neighbors = _compute_centroid_distances(segmentation, resolution)
122141

123142
if save_path is not None:
124143
np.savez(
@@ -132,7 +151,7 @@ def measure_pairwise_object_distances(
132151
return distances, endpoints1, endpoints2, seg_ids
133152

134153

135-
def compute_seg_object_distances(segmentation, segmented_object, resolution, verbose):
154+
def _compute_seg_object_distances(segmentation, segmented_object, resolution, verbose):
136155
distance_map, indices = distance_transform_edt(segmented_object == 0, return_indices=True, sampling=resolution)
137156

138157
seg_ids = np.unique(segmentation)[1:].tolist()
@@ -177,15 +196,33 @@ def compute_seg_object_distances(segmentation, segmented_object, resolution, ver
177196

178197

179198
def measure_segmentation_to_object_distances(
180-
segmentation,
181-
segmented_object,
182-
distance_type="boundary",
183-
resolution=None,
184-
save_path=None,
185-
verbose=False,
186-
):
199+
segmentation: np.ndarray,
200+
segmented_object: np.ndarray,
201+
distance_type: str = "boundary",
202+
resolution: Optional[Tuple[int, int, int]] = None,
203+
save_path: Optional[os.PathLike] = None,
204+
verbose: bool = False,
205+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
206+
"""Compute the distance betwen all objects in a segmentation and another object.
207+
208+
Args:
209+
segmentation: The input segmentation.
210+
segmented_object: The segmented object.
211+
distance_type: The type of distance to compute, can either be 'boundary' to
212+
compute the distance between the boundary / surface of the objects or 'centroid'
213+
to compute the distance between centroids.
214+
resolution: The resolution / pixel size of the data.
215+
save_path: Path for saving the measurement results in numpy zipped format.
216+
verbose: Whether to print the progress of the distance computation.
217+
218+
Returns:
219+
The segmentation to object distances.
220+
The 'left' endpoint coordinates of the distances.
221+
The 'right' endpoint coordinates of the distances.
222+
The segmentation ids corresponding to the distances.
223+
"""
187224
if distance_type == "boundary":
188-
distances, endpoints1, endpoints2, seg_ids, object_ids = compute_seg_object_distances(
225+
distances, endpoints1, endpoints2, seg_ids, object_ids = _compute_seg_object_distances(
189226
segmentation, segmented_object, resolution, verbose
190227
)
191228
assert len(distances) == len(endpoints1) == len(endpoints2) == len(seg_ids) == len(object_ids)
@@ -204,7 +241,7 @@ def measure_segmentation_to_object_distances(
204241
return distances, endpoints1, endpoints2, seg_ids
205242

206243

207-
def extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_duplicates=True):
244+
def _extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_duplicates=True):
208245
distance_matrix = pairwise_distances.copy()
209246

210247
# Set the diagonal (distance to self) to infinity.
@@ -230,15 +267,56 @@ def extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_d
230267
return pairs
231268

232269

233-
# TODO update this for extracting only up to a max distance
234-
def create_distance_lines(measurement_path, n_neighbors=None, pairs=None, bb=None, scale=None, remove_duplicates=True):
270+
def load_distances(
271+
measurement_path: os.PathLike
272+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
273+
"""Load the saved distacnes from a zipped numpy file.
235274
275+
Args:
276+
measurement_path: The path where the distances where saved.
277+
278+
Returns:
279+
The segmentation to object distances.
280+
The 'left' endpoint coordinates of the distances.
281+
The 'right' endpoint coordinates of the distances.
282+
The segmentation ids corresponding to the distances.
283+
"""
236284
auto_dists = np.load(measurement_path)
237285
distances, seg_ids = auto_dists["distances"], list(auto_dists["seg_ids"])
238-
start_points, end_points = auto_dists["endpoints1"], auto_dists["endpoints2"]
286+
endpoints1, endpoints2 = auto_dists["endpoints1"], auto_dists["endpoints2"]
287+
return distances, endpoints1, endpoints2, seg_ids
239288

289+
290+
def create_pairwise_distance_lines(
291+
distances: np.ndarray,
292+
endpoints1: np.ndarray,
293+
endpoints2: np.ndarray,
294+
seg_ids: np.ndarray,
295+
n_neighbors: Optional[int] = None,
296+
pairs: Optional[np.ndarray] = None,
297+
bb: Optional[Tuple[slice]] = None,
298+
scale: Optional[float] = None,
299+
remove_duplicates: bool = True
300+
) -> Tuple[np.ndarray, Dict]:
301+
"""Create a line representation of pair-wise object distances for display in napari.
302+
303+
Args:
304+
distances: The pairwise distances.
305+
endpoints1: One set of distance end points.
306+
endpoints2: The other set of distance end points.
307+
seg_ids: The segmentation pair corresponding to each distance.
308+
n_neighbors: ...
309+
pairs: ...
310+
bb: ....
311+
scale: ...
312+
remove_duplicates: ...
313+
314+
Returns:
315+
The lines for plotting in napari.
316+
Additional attributes for the line layer in napari.
317+
"""
240318
if pairs is None and n_neighbors is not None:
241-
pairs = extract_nearest_neighbors(distances, seg_ids, n_neighbors, remove_duplicates=remove_duplicates)
319+
pairs = _extract_nearest_neighbors(distances, seg_ids, n_neighbors, remove_duplicates=remove_duplicates)
242320
elif pairs is None:
243321
pairs = [
244322
[id1, id2] for id1 in seg_ids for id2 in seg_ids if id1 < id2
@@ -252,27 +330,27 @@ def create_distance_lines(measurement_path, n_neighbors=None, pairs=None, bb=Non
252330

253331
pairs = np.array(pairs)
254332
distances = distances[pair_indices]
255-
start_points = start_points[pair_indices]
256-
end_points = end_points[pair_indices]
333+
endpoints1 = endpoints1[pair_indices]
334+
endpoints2 = endpoints2[pair_indices]
257335

258336
if bb is not None:
259337
in_bb = np.where(
260-
(start_points[:, 0] > bb[0].start) & (start_points[:, 0] < bb[0].stop) &
261-
(start_points[:, 1] > bb[1].start) & (start_points[:, 1] < bb[1].stop) &
262-
(start_points[:, 2] > bb[2].start) & (start_points[:, 2] < bb[2].stop) &
263-
(end_points[:, 0] > bb[0].start) & (end_points[:, 0] < bb[0].stop) &
264-
(end_points[:, 1] > bb[1].start) & (end_points[:, 1] < bb[1].stop) &
265-
(end_points[:, 2] > bb[2].start) & (end_points[:, 2] < bb[2].stop)
338+
(endpoints1[:, 0] > bb[0].start) & (endpoints1[:, 0] < bb[0].stop) &
339+
(endpoints1[:, 1] > bb[1].start) & (endpoints1[:, 1] < bb[1].stop) &
340+
(endpoints1[:, 2] > bb[2].start) & (endpoints1[:, 2] < bb[2].stop) &
341+
(endpoints2[:, 0] > bb[0].start) & (endpoints2[:, 0] < bb[0].stop) &
342+
(endpoints2[:, 1] > bb[1].start) & (endpoints2[:, 1] < bb[1].stop) &
343+
(endpoints2[:, 2] > bb[2].start) & (endpoints2[:, 2] < bb[2].stop)
266344
)
267345

268346
pairs = pairs[in_bb]
269-
distances, start_points, end_points = distances[in_bb], start_points[in_bb], end_points[in_bb]
347+
distances, endpoints1, endpoints2 = distances[in_bb], endpoints1[in_bb], endpoints2[in_bb]
270348

271349
offset = np.array([b.start for b in bb])[None]
272-
start_points -= offset
273-
end_points -= offset
350+
endpoints1 -= offset
351+
endpoints2 -= offset
274352

275-
lines = np.array([[start, end] for start, end in zip(start_points, end_points)])
353+
lines = np.array([[start, end] for start, end in zip(endpoints1, endpoints2)])
276354

277355
if scale is not None:
278356
scale_factor = np.array(3 * [scale])[None, None]
@@ -286,25 +364,43 @@ def create_distance_lines(measurement_path, n_neighbors=None, pairs=None, bb=Non
286364
return lines, properties
287365

288366

289-
def create_object_distance_lines(measurement_path, max_distance=None, seg_ids=None, scale=None):
290-
auto_dists = np.load(measurement_path)
291-
distances, all_seg_ids = auto_dists["distances"], auto_dists["seg_ids"]
292-
start_points, end_points = auto_dists["endpoints1"], auto_dists["endpoints2"]
367+
def create_object_distance_lines(
368+
distances: np.ndarray,
369+
endpoints1: np.ndarray,
370+
endpoints2: np.ndarray,
371+
seg_ids: np.ndarray,
372+
max_distance: Optional[float] = None,
373+
filter_seg_ids: Optional[np.ndarray] = None,
374+
scale: Optional[float] = None,
375+
) -> Tuple[np.ndarray, Dict]:
376+
"""Create a line representation of object distances for display in napari.
377+
378+
Args:
379+
distances: The measurd distances.
380+
endpoints1: One set of distance end points.
381+
endpoints2: The other set of distance end points.
382+
seg_ids: The segmentation ids corresponding to each distance.
383+
max_distance: ...
384+
scale: ...
385+
386+
Returns:
387+
The lines for plotting in napari.
388+
Additional attributes for the line layer in napari.
389+
"""
293390

294-
if seg_ids is None:
295-
seg_ids = all_seg_ids
296-
else:
297-
id_mask = np.isin(all_seg_ids, seg_ids)
391+
if filter_seg_ids is not None:
392+
id_mask = np.isin(seg_ids, filter_seg_ids)
298393
distances = distances[id_mask]
299-
start_points, end_points = start_points[id_mask], end_points[id_mask]
394+
endpoints1, endpoints2 = endpoints1[id_mask], endpoints2[id_mask]
395+
seg_ids = filter_seg_ids
300396

301397
if max_distance is not None:
302398
distance_mask = distances <= max_distance
303399
distances, seg_ids = distances[distance_mask], seg_ids[distance_mask]
304-
start_points, end_points = start_points[distance_mask], end_points[distance_mask]
400+
endpoints1, endpoints2 = endpoints1[distance_mask], endpoints2[distance_mask]
305401

306-
assert len(distances) == len(seg_ids) == len(start_points) == len(end_points)
307-
lines = np.array([[start, end] for start, end in zip(start_points, end_points)])
402+
assert len(distances) == len(seg_ids) == len(endpoints1) == len(endpoints2)
403+
lines = np.array([[start, end] for start, end in zip(endpoints1, endpoints2)])
308404

309405
if scale is not None and len(lines > 0):
310406
scale_factor = np.array(3 * [scale])[None, None]
@@ -318,7 +414,9 @@ def keep_direct_distances(segmentation, measurement_path, line_dilation=0, scale
318414
"""Filter out all distances that are not direct.
319415
I.e. distances that cross another segmented object.
320416
"""
321-
distance_lines, properties = create_distance_lines(measurement_path, scale=scale)
417+
418+
distances, ep1, ep2, seg_ids = load_distances(measurement_path)
419+
distance_lines, properties = create_object_distance_lines(distances, ep1, ep2, seg_ids, scale=scale)
322420

323421
ids_a, ids_b = properties["id_a"], properties["id_b"]
324422
filtered_ids_a, filtered_ids_b = [], []
@@ -357,7 +455,8 @@ def keep_direct_distances(segmentation, measurement_path, line_dilation=0, scale
357455
def filter_blocked_segmentation_to_object_distances(
358456
segmentation, measurement_path, line_dilation=0, scale=None, seg_ids=None, verbose=False,
359457
):
360-
distance_lines, properties = create_object_distance_lines(measurement_path, seg_ids=seg_ids, scale=scale)
458+
distances, ep1, ep2, seg_ids = load_distances(measurement_path)
459+
distance_lines, properties = create_object_distance_lines(distances, ep1, ep2, seg_ids, scale=scale)
361460
all_seg_ids = properties["id"]
362461

363462
filtered_ids = []

synaptic_reconstruction/napari.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@ contributions:
1111
python_name: synaptic_reconstruction.tools.synaptic_plugin.distance_measure_widget:get_distance_measure_widget
1212
title: Distance Measurement
1313
- id: synaptic_reconstruction.file_reader
14-
title: Read ".mrc, .rec" files
15-
python_name: synaptic_reconstruction.tools.file_reader_plugin.elf_reader:get_reader
14+
title: Read volumetric data
15+
python_name: synaptic_reconstruction.tools.volume_reader:get_reader
1616

1717
readers:
1818
- command: synaptic_reconstruction.file_reader
1919
filename_patterns:
2020
- '*.mrc'
21+
- '*.rec'
22+
- '*.h5'
2123
accepts_directories: false
2224

2325
widgets:
2426
- command: synaptic_reconstruction.segment
25-
display_name: Segmentation
27+
display_name: Segmentation
2628
- command: synaptic_reconstruction.distance_measure
2729
display_name: Distance Measurement

synaptic_reconstruction/tools/distance_measurement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from skimage.transform import rescale, resize
1717

1818
from ..distance_measurements import (
19-
create_distance_lines,
19+
create_object_distance_lines,
2020
measure_pairwise_object_distances,
2121
keep_direct_distances,
2222
)
@@ -68,7 +68,7 @@ def measurement_widget(
6868
n_neighbors = compute_neighbor_distances
6969
pairs = None
7070

71-
lines, properties = create_distance_lines(
71+
lines, properties = create_object_distance_lines(
7272
DISTANCE_MEASUREMENT_PATH, n_neighbors=n_neighbors, scale=VIEW_SCALE, pairs=pairs,
7373
)
7474
if "line" in viewer.layers: # TODO update the line layer

0 commit comments

Comments
 (0)