
Commit 1cc595e

Add ribbon model and refactor IO functionality
Parent: 0fe01c4

9 files changed: +189 −79 lines changed


synaptic_reconstruction/file_utils.py

Lines changed: 59 additions & 1 deletion
@@ -1,5 +1,8 @@
 import os
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import mrcfile
+import numpy as np


 def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]:
@@ -23,3 +26,58 @@ def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, Lis
         return tomograms
     assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}"
     return tomograms[0] if n_tomograms == 1 else tomograms
+
+
+def _parse_voxel_size(voxel_size):
+    parsed_voxel_size = None
+    try:
+        # The voxel sizes are stored in Angstrom in the MRC header, but we want them
+        # in nanometer. Hence we divide by a factor of 10 here.
+        parsed_voxel_size = {
+            "x": voxel_size.x / 10,
+            "y": voxel_size.y / 10,
+            "z": voxel_size.z / 10,
+        }
+    except Exception as e:
+        print(f"Failed to read voxel size: {e}")
+    return parsed_voxel_size
+
+
+def read_voxel_size(path: str) -> Dict[str, float] | None:
+    """Read the voxel size from an mrc/rec file.
+
+    The original unit of the voxel size is Angstrom and we convert it to nanometers by dividing by ten.
+
+    Args:
+        path: Path to the mrc/rec file.
+
+    Returns:
+        Mapping from the axis name to the voxel size, or None if the voxel size could not be read.
+    """
+    with mrcfile.open(path, permissive=True) as mrc:
+        voxel_size = _parse_voxel_size(mrc.voxel_size)
+    return voxel_size
+
+
+# TODO: double check axis ordering with elf
+def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]:
+    """Read the data and voxel size from an mrc/rec file.
+
+    Args:
+        path: Path to the mrc/rec file.
+
+    Returns:
+        The data read from the file.
+        The voxel size read from the file.
+    """
+    with mrcfile.open(path, permissive=True) as mrc:
+        voxel_size = _parse_voxel_size(mrc.voxel_size)
+        data = np.asarray(mrc.data[:])
+
+    # Flip the data along the y-axis to match the python axis order.
+    if data.ndim == 3:
+        data = np.flip(data, axis=1)
+    else:
+        data = np.flip(data, axis=0)
+
+    return data, voxel_size
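For context, a minimal usage sketch of the new IO helpers; the tomogram path is a hypothetical example:

from synaptic_reconstruction.file_utils import read_mrc, read_voxel_size

tomogram_path = "data/tomogram.rec"  # hypothetical example file

# Read only the voxel size (converted from Angstrom to nanometers) from the header.
voxel_size = read_voxel_size(tomogram_path)
print(voxel_size)  # e.g. {"x": 1.35, "y": 1.35, "z": 1.35}, or None if the header could not be parsed

# Read the full volume together with its voxel size.
data, voxel_size = read_mrc(tomogram_path)
print(data.shape, voxel_size)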
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+
+# TODO
+# - merge compartments which share vesicles (based on a threshold for merging)
+# - filter out compartments with fewer than some threshold number of vesicles
+def postprocess_compartments():
+    pass

synaptic_reconstruction/inference/postprocessing/ribbon.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def segment_ribbon(
         n_slices_exclude: The number of slices to exclude on the top / bottom
             in order to avoid segmentation errors due to imaging artifacts in top and bottom.
         n_ribbons: The number of ribbons in the tomogram.
-        max_vesicle_distance: The maximal distance to associate a vesicle with a ribbon.
+        max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
     """
     assert ribbon_prediction.shape == vesicle_segmentation.shape
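For illustration, a hedged sketch of calling segment_ribbon with the distance threshold given in pixels; the arrays are zero-filled placeholders and the import path follows the postprocessing module used elsewhere in this commit:

import numpy as np
from synaptic_reconstruction.inference.postprocessing import segment_ribbon

# Placeholder arrays standing in for a ribbon prediction and a vesicle segmentation of the same shape.
ribbon_prediction = np.zeros((64, 256, 256), dtype="float32")
vesicle_segmentation = np.zeros((64, 256, 256), dtype="uint32")

ribbon = segment_ribbon(
    ribbon_prediction, vesicle_segmentation,
    n_slices_exclude=20, n_ribbons=1,
    max_vesicle_distance=40,  # distance threshold in pixels
)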

synaptic_reconstruction/napari.yaml

Lines changed: 14 additions & 1 deletion
@@ -1,9 +1,12 @@
 name: synaptic_reconstruction
 display_name: SynapseNet
-# see https://napari.org/stable/plugins/manifest.html for valid categories
+
+# See https://napari.org/stable/plugins/manifest.html for valid categories.
 categories: ["Image Processing", "Annotation"]
+
 contributions:
   commands:
+    # Commands for widgets.
     - id: synaptic_reconstruction.segment
       python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget
       title: Segment
@@ -20,6 +23,11 @@ contributions:
       python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget
       title: Vesicle Pooling

+    # Commands for sample data.
+    - id: synaptic_reconstruction.sample_data_tem_2d
+      python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d
+      title: Load TEM 2D sample data
+
   readers:
     - command: synaptic_reconstruction.file_reader
       filename_patterns:
@@ -37,3 +45,8 @@ contributions:
       display_name: Morphology Analysis
     - command: synaptic_reconstruction.vesicle_pooling
       display_name: Vesicle Pooling
+
+  sample_data:
+    - command: synaptic_reconstruction.sample_data_tem_2d
+      display_name: TEM 2D Sample Data
+      key: synapse-net-tem-2d
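With the sample_data contribution registered above, the dataset appears in napari's sample menu; a minimal sketch of opening it programmatically, assuming napari and the plugin are installed:

import napari

viewer = napari.Viewer()
# Plugin name and sample key as declared in the manifest above.
viewer.open_sample("synaptic_reconstruction", "synapse-net-tem-2d")
napari.run()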

synaptic_reconstruction/sample_data.py

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,8 @@
 import os
 import pooch

+from .file_utils import read_mrc
+
 
 def get_sample_data(name: str) -> str:
     """Get the filepath to SynapseNet sample data, stored as mrc file.
@@ -32,3 +34,11 @@ def get_sample_data(name: str) -> str:
     )
     file_path = data_registry.fetch(key)
     return file_path
+
+
+def sample_data_tem_2d():
+    file_path = get_sample_data("tem_2d")
+    data, voxel_size = read_mrc(file_path)
+    metadata = {"file_path": file_path, "voxel_size": voxel_size}
+    add_image_kwargs = {"name": "tem_2d", "metadata": metadata, "colormap": "gray"}
+    return [(data, add_image_kwargs)]
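The function returns napari layer-data tuples; a sketch of using it directly, outside of napari's plugin machinery:

import napari
from synaptic_reconstruction.sample_data import sample_data_tem_2d

viewer = napari.Viewer()
for data, add_image_kwargs in sample_data_tem_2d():
    # Each tuple contains the image data and the keyword arguments for viewer.add_image.
    viewer.add_image(data, **add_image_kwargs)
napari.run()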

synaptic_reconstruction/tools/base_widget.py

Lines changed: 11 additions & 10 deletions
@@ -23,12 +23,11 @@ def __init__(self):
         self.attribute_dict = {}

     def _create_layer_selector(self, selector_name, layer_type="Image"):
-        """
-        Create a layer selector for an image or labels and store it in a dictionary.
+        """Create a layer selector for an image or labels and store it in a dictionary.

-        Parameters:
-        - selector_name (str): The name of the selector, used as a key in the dictionary.
-        - layer_type (str): The type of layer to filter for ("Image" or "Labels").
+        Args:
+            selector_name (str): The name of the selector, used as a key in the dictionary.
+            layer_type (str): The type of layer to filter for ("Image" or "Labels").
         """
         if not hasattr(self, "layer_selectors"):
             self.layer_selectors = {}
@@ -286,17 +285,19 @@ def _get_file_path(self, name, textbox, tooltip=None):
            # Handle the case where the selected path is not a file
            print("Invalid file selected. Please try again.")

-    def _handle_resolution(self, metadata, voxel_size_param, ndim):
+    def _handle_resolution(self, metadata, voxel_size_param, ndim, return_as_list=True):
         # Get the resolution / voxel size from the layer metadata if available.
         resolution = metadata.get("voxel_size", None)
-        if resolution is not None:
-            resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")]

         # If user input was given then override resolution from metadata.
+        axes = "zyx" if ndim == 3 else "yx"
         if voxel_size_param.value() != 0.0:  # Changed from default.
-            resolution = ndim * [voxel_size_param.value()]
+            resolution = {ax: voxel_size_param.value() for ax in axes}
+
+        if resolution is not None and return_as_list:
+            resolution = [resolution[ax] for ax in axes]
+            assert len(resolution) == ndim

-        assert len(resolution) == ndim
         return resolution

     def _save_table(self, save_path, data):
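To illustrate the new return_as_list switch, a standalone sketch of the same logic (a hypothetical helper, not part of the widget class):

def handle_resolution(metadata, user_value, ndim, return_as_list=True):
    # Voxel size from layer metadata, overridden by a non-default user value.
    resolution = metadata.get("voxel_size", None)
    axes = "zyx" if ndim == 3 else "yx"
    if user_value != 0.0:
        resolution = {ax: user_value for ax in axes}
    # Either return the axis mapping or an axis-ordered list.
    if resolution is not None and return_as_list:
        resolution = [resolution[ax] for ax in axes]
    return resolution

print(handle_resolution({"voxel_size": {"x": 1.35, "y": 1.35, "z": 1.35}}, 0.0, ndim=3))  # [1.35, 1.35, 1.35]
print(handle_resolution({}, 2.0, ndim=2, return_as_list=False))  # {'y': 2.0, 'x': 2.0}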

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 32 additions & 20 deletions
@@ -1,12 +1,15 @@
+import copy
+
 import napari
+import numpy as np
+
 from napari.utils.notifications import show_info
 from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

 from .base_widget import BaseWidget
 from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
                    get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
-from synaptic_reconstruction.inference.util import get_default_tiling
-import copy
+from ..inference.util import get_default_tiling


 class SegmentationWidget(BaseWidget):
@@ -79,37 +82,41 @@ def on_predict(self):
             show_info("INFO: Please choose an image.")
             return

-        # load current tiling
+        # Get the current tiling.
         self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)

+        # Get the voxel size.
         metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
-        voxel_size = metadata.get("voxel_size", None)
-        scale = None
+        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

-        if self.voxel_size_param.value() != 0.0:  # changed from default
-            voxel_size = {}
-            # override voxel size with user input
-            if len(image.shape) == 3:
-                voxel_size["x"] = self.voxel_size_param.value()
-                voxel_size["y"] = self.voxel_size_param.value()
-                voxel_size["z"] = self.voxel_size_param.value()
-            else:
-                voxel_size["x"] = self.voxel_size_param.value()
-                voxel_size["y"] = self.voxel_size_param.value()
+        # Determine the scaling based on the voxel size.
+        scale = None
         if voxel_size:
             if model_type == "custom":
                 show_info("INFO: The image is not rescaled for a custom model.")
             else:
                 # calculate scale so voxel_size is the same as in training
                 scale = compute_scale_from_voxel_size(voxel_size, model_type)
-                show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")
-
+                scale_info = list(map(lambda x: np.round(x, 2), scale))
+                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
+
+        # Some models require an additional segmentation for inference or postprocessing.
+        # For these models we read out the 'Extra Segmentation' widget.
+        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
+            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
+            kwargs = {"extra_segmentation": extra_seg}
+        else:
+            kwargs = {}
         segmentation = run_segmentation(
-            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
+            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
         )

-        # Add the segmentation layer
-        self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
+        # Add the segmentation layer(s).
+        if isinstance(segmentation, dict):
+            for name, seg in segmentation.items():
+                self.viewer.add_labels(seg, name=name, metadata=metadata)
+        else:
+            self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
         show_info(f"INFO: Segmentation of {model_type} added to layers.")

     def _create_settings_widget(self):
@@ -156,5 +163,10 @@ def _create_settings_widget(self):
         )
         setting_values.layout().addLayout(layout)

+        # Add selection UI for an additional segmentation, which some models require for inference or postprocessing.
+        self.extra_seg_selector_name = "Extra Segmentation"
+        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
+        setting_values.layout().addWidget(self.extra_selector_widget)
+
         settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
         return settings

synaptic_reconstruction/tools/util.py

Lines changed: 50 additions & 9 deletions
@@ -9,6 +9,7 @@
 from ..inference.active_zone import segment_active_zone
 from ..inference.compartments import segment_compartments
 from ..inference.mitochondria import segment_mitochondria
+from ..inference.ribbon_synapse import segment_ribbon_synapse_structures
 from ..inference.vesicles import segment_vesicles


@@ -43,8 +44,8 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
     """Get the model for the given segmentation type.

     Args:
-        model_type: The model type.
-            One of 'vesicles', 'mitochondria', 'active_zone', 'compartments' or 'inner_ear_structures'.
+        model_type: The model type. You can choose one of:
+            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
         device: The device to use.

     Returns:
@@ -58,6 +59,44 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
     return model


+def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs):
+    # Parse additional keyword arguments from the kwargs.
+    vesicles = kwargs.pop("extra_segmentation", None)
+    threshold = kwargs.pop("threshold", 0.5)
+    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
+    n_ribbons = kwargs.pop("n_ribbons", 1)
+
+    predictions = segment_ribbon_synapse_structures(
+        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
+    )
+
+    # If the vesicles were passed then run additional post-processing.
+    if vesicles is not None:
+        from synaptic_reconstruction.inference.postprocessing import (
+            segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
+        )
+
+        ribbon = segment_ribbon(
+            predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
+            max_vesicle_distance=40,
+        )
+        PD = segment_presynaptic_density(
+            predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
+        )
+        ref_segmentation = PD if PD.sum() > 0 else ribbon
+        membrane = segment_membrane_distance_based(
+            predictions["membrane"], ref_segmentation, n_slices_exclude=n_slices_exclude, max_distance=500
+        )
+
+        segmentation = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
+
+    # Otherwise, just return the predictions.
+    else:
+        segmentation = predictions
+
+    return segmentation
+
+
 def run_segmentation(
     image: np.ndarray,
     model: torch.nn.Module,
@@ -66,22 +105,21 @@ def run_segmentation(
     scale: Optional[List[float]] = None,
     verbose: bool = False,
     **kwargs,
-) -> np.ndarray:
+) -> np.ndarray | Dict[str, np.ndarray]:
     """Run synaptic structure segmentation.

     Args:
         image: The input image or image volume.
         model: The segmentation model.
-        model_type: The model type. This will determine which segmentation
-            post-processing is used.
+        model_type: The model type. This will determine which segmentation post-processing is used.
         tiling: The tiling settings for inference.
         scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
         verbose: Whether to print detailed information about the prediction and segmentation.
-        kwargs: Optional parameter for the segmentation function.
+        kwargs: Optional parameters for the segmentation function.

     Returns:
-        The segmentation.
+        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
     """
     if model_type.startswith("vesicles"):
         segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
@@ -91,8 +129,8 @@ def run_segmentation(
         segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
     elif model_type == "compartments":
         segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
-    elif model_type == "ribbon_synapse_structures":
-        raise NotImplementedError
+    elif model_type == "ribbon":
+        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
     else:
         raise ValueError(f"Unknown model type: {model_type}")
     return segmentation
@@ -108,6 +146,7 @@ def get_model_training_resolution(model_type):
         "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
         "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
         "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
+        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
         "vesicles_2d": {"x": 1.35, "y": 1.35},
         "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
         "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
@@ -120,6 +159,7 @@ def get_model_registry():
         "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
         "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
         "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
+        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
         "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
         "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
         "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
@@ -128,6 +168,7 @@
         "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
         "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
         "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
+        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
         "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
         "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
         "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
