Skip to content

Commit 57b7258

Browse files
Merge pull request #75 from computational-cell-analytics/add-ribbon-inference
Add ribbon model and refactor IO functionality
2 parents 0fe01c4 + 71f9b2c commit 57b7258

File tree

10 files changed

+248
-80
lines changed

10 files changed

+248
-80
lines changed

synaptic_reconstruction/file_utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
2-
from typing import List, Optional, Union
2+
from typing import Dict, List, Optional, Tuple, Union
3+
4+
import mrcfile
5+
import numpy as np
36

47

58
def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]:
@@ -23,3 +26,54 @@ def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, Lis
2326
return tomograms
2427
assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}"
2528
return tomograms[0] if n_tomograms == 1 else tomograms
29+
30+
31+
def _parse_voxel_size(voxel_size):
32+
parsed_voxel_size = None
33+
try:
34+
# The voxel sizes are stored in Angsrrom in the MRC header, but we want them
35+
# in nanometer. Hence we divide by a factor of 10 here.
36+
parsed_voxel_size = {
37+
"x": voxel_size.x / 10,
38+
"y": voxel_size.y / 10,
39+
"z": voxel_size.z / 10,
40+
}
41+
except Exception as e:
42+
print(f"Failed to read voxel size: {e}")
43+
return parsed_voxel_size
44+
45+
46+
def read_voxel_size(path: str) -> Dict[str, float] | None:
47+
"""Read voxel size from mrc/rec file.
48+
49+
The original unit of voxel size is Angstrom and we convert it to nanometers by dividing it by ten.
50+
51+
Args:
52+
path: Path to mrc/rec file.
53+
54+
Returns:
55+
Mapping from the axis name to voxel size. None if the voxel size could not be read.
56+
"""
57+
with mrcfile.open(path, permissive=True) as mrc:
58+
voxel_size = _parse_voxel_size(mrc.voxel_size)
59+
return voxel_size
60+
61+
62+
def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]:
63+
"""Read data and voxel size from mrc/rec file.
64+
65+
Args:
66+
path: Path to mrc/rec file.
67+
68+
Returns:
69+
The data read from the file.
70+
The voxel size read from the file.
71+
"""
72+
with mrcfile.open(path, permissive=True) as mrc:
73+
voxel_size = _parse_voxel_size(mrc.voxel_size)
74+
data = np.asarray(mrc.data[:])
75+
assert data.ndim in (2, 3)
76+
77+
# Transpose the data to match python axis order.
78+
data = np.flip(data, axis=1) if data.ndim == 3 else np.flip(data, axis=0)
79+
return data, voxel_size
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
3+
# TODO
4+
# - merge compartments which share vesicles (based on threshold for merging)
5+
# - filter out compartments with less than some threshold vesicles
6+
def postpocess_compartments():
7+
pass

synaptic_reconstruction/inference/postprocessing/ribbon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def segment_ribbon(
2020
n_slices_exclude: The number of slices to exclude on the top / bottom
2121
in order to avoid segmentation errors due to imaging artifacts in top and bottom.
2222
n_ribbons: The number of ribbons in the tomogram.
23-
max_vesicle_distance: The maximal distance to associate a vesicle with a ribbon.
23+
max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
2424
"""
2525
assert ribbon_prediction.shape == vesicle_segmentation.shape
2626

synaptic_reconstruction/napari.yaml

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
name: synaptic_reconstruction
22
display_name: SynapseNet
3-
# see https://napari.org/stable/plugins/manifest.html for valid categories
3+
4+
# See https://napari.org/stable/plugins/manifest.html for valid categories.
45
categories: ["Image Processing", "Annotation"]
6+
57
contributions:
68
commands:
9+
# Commands for widgets.
710
- id: synaptic_reconstruction.segment
811
python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget
912
title: Segment
@@ -20,6 +23,14 @@ contributions:
2023
python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget
2124
title: Vesicle Pooling
2225

26+
# Commands for sample data.
27+
- id: synaptic_reconstruction.sample_data_tem_2d
28+
python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d
29+
title: Load TEM 2D sample data
30+
- id: synaptic_reconstruction.sample_data_tem_tomo
31+
python_name: synaptic_reconstruction.sample_data:sample_data_tem_tomo
32+
title: Load TEM Tomo sample data
33+
2334
readers:
2435
- command: synaptic_reconstruction.file_reader
2536
filename_patterns:
@@ -37,3 +48,11 @@ contributions:
3748
display_name: Morphology Analysis
3849
- command: synaptic_reconstruction.vesicle_pooling
3950
display_name: Vesicle Pooling
51+
52+
sample_data:
53+
- command: synaptic_reconstruction.sample_data_tem_2d
54+
display_name: TEM 2D Sample Data
55+
key: synapse-net-tem-2d
56+
- command: synaptic_reconstruction.sample_data_tem_tomo
57+
display_name: TEM Tomo Sample Data
58+
key: synapse-net-tem-tomo

synaptic_reconstruction/sample_data.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
import os
22
import pooch
33

4+
from .file_utils import read_mrc
5+
46

57
def get_sample_data(name: str) -> str:
68
"""Get the filepath to SynapseNet sample data, stored as mrc file.
79
810
Args:
9-
name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
11+
name: The name of the sample data. Currently, we only provide 'tem_2d' and 'tem_tomo'.
1012
1113
Returns:
1214
The filepath to the downloaded sample data.
1315
"""
1416
registry = {
1517
"tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
18+
"tem_tomo.mrc": "24af31a10761b59fa6ad9f0e763f8f084304e4f31c59b482dd09dde8cd443ed7",
1619
}
1720
urls = {
1821
"tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
22+
"tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/NeP7gOv76Vj26lm/download",
1923
}
2024
key = f"{name}.mrc"
2125

@@ -32,3 +36,19 @@ def get_sample_data(name: str) -> str:
3236
)
3337
file_path = data_registry.fetch(key)
3438
return file_path
39+
40+
41+
def _sample_data(name):
42+
file_path = get_sample_data(name)
43+
data, voxel_size = read_mrc(file_path)
44+
metadata = {"file_path": file_path, "voxel_size": voxel_size}
45+
add_image_kwargs = {"name": name, "metadata": metadata, "colormap": "gray"}
46+
return [(data, add_image_kwargs)]
47+
48+
49+
def sample_data_tem_2d():
50+
return _sample_data("tem_2d")
51+
52+
53+
def sample_data_tem_tomo():
54+
return _sample_data("tem_tomo")

synaptic_reconstruction/tools/base_widget.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,11 @@ def __init__(self):
2323
self.attribute_dict = {}
2424

2525
def _create_layer_selector(self, selector_name, layer_type="Image"):
26-
"""
27-
Create a layer selector for an image or labels and store it in a dictionary.
26+
"""Create a layer selector for an image or labels and store it in a dictionary.
2827
29-
Parameters:
30-
- selector_name (str): The name of the selector, used as a key in the dictionary.
31-
- layer_type (str): The type of layer to filter for ("Image" or "Labels").
28+
Args:
29+
selector_name (str): The name of the selector, used as a key in the dictionary.
30+
layer_type (str): The type of layer to filter for ("Image" or "Labels").
3231
"""
3332
if not hasattr(self, "layer_selectors"):
3433
self.layer_selectors = {}
@@ -286,17 +285,19 @@ def _get_file_path(self, name, textbox, tooltip=None):
286285
# Handle the case where the selected path is not a file
287286
print("Invalid file selected. Please try again.")
288287

289-
def _handle_resolution(self, metadata, voxel_size_param, ndim):
288+
def _handle_resolution(self, metadata, voxel_size_param, ndim, return_as_list=True):
290289
# Get the resolution / voxel size from the layer metadata if available.
291290
resolution = metadata.get("voxel_size", None)
292-
if resolution is not None:
293-
resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")]
294291

295292
# If user input was given then override resolution from metadata.
293+
axes = "zyx" if ndim == 3 else "yx"
296294
if voxel_size_param.value() != 0.0: # Changed from default.
297-
resolution = ndim * [voxel_size_param.value()]
295+
resolution = {ax: voxel_size_param.value() for ax in axes}
296+
297+
if resolution is not None and return_as_list:
298+
resolution = [resolution[ax] for ax in axes]
299+
assert len(resolution) == ndim
298300

299-
assert len(resolution) == ndim
300301
return resolution
301302

302303
def _save_table(self, save_path, data):

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import copy
2+
13
import napari
4+
import numpy as np
5+
26
from napari.utils.notifications import show_info
37
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
48

59
from .base_widget import BaseWidget
610
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
711
get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
8-
from synaptic_reconstruction.inference.util import get_default_tiling
9-
import copy
12+
from ..inference.util import get_default_tiling
1013

1114

1215
class SegmentationWidget(BaseWidget):
@@ -79,37 +82,41 @@ def on_predict(self):
7982
show_info("INFO: Please choose an image.")
8083
return
8184

82-
# load current tiling
85+
# Get the current tiling.
8386
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)
8487

88+
# Get the voxel size.
8589
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
86-
voxel_size = metadata.get("voxel_size", None)
87-
scale = None
90+
voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
8891

89-
if self.voxel_size_param.value() != 0.0: # changed from default
90-
voxel_size = {}
91-
# override voxel size with user input
92-
if len(image.shape) == 3:
93-
voxel_size["x"] = self.voxel_size_param.value()
94-
voxel_size["y"] = self.voxel_size_param.value()
95-
voxel_size["z"] = self.voxel_size_param.value()
96-
else:
97-
voxel_size["x"] = self.voxel_size_param.value()
98-
voxel_size["y"] = self.voxel_size_param.value()
92+
# Determine the scaling based on the voxel size.
93+
scale = None
9994
if voxel_size:
10095
if model_type == "custom":
10196
show_info("INFO: The image is not rescaled for a custom model.")
10297
else:
10398
# calculate scale so voxel_size is the same as in training
10499
scale = compute_scale_from_voxel_size(voxel_size, model_type)
105-
show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")
106-
100+
scale_info = list(map(lambda x: np.round(x, 2), scale))
101+
show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
102+
103+
# Some models require an additional segmentation for inference or postprocessing.
104+
# For these models we read out the 'Extra Segmentation' widget.
105+
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
106+
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
107+
kwargs = {"extra_segmentation": extra_seg}
108+
else:
109+
kwargs = {}
107110
segmentation = run_segmentation(
108-
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
111+
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
109112
)
110113

111-
# Add the segmentation layer
112-
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
114+
# Add the segmentation layer(s).
115+
if isinstance(segmentation, dict):
116+
for name, seg in segmentation.items():
117+
self.viewer.add_labels(seg, name=name, metadata=metadata)
118+
else:
119+
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
113120
show_info(f"INFO: Segmentation of {model_type} added to layers.")
114121

115122
def _create_settings_widget(self):
@@ -156,5 +163,10 @@ def _create_settings_widget(self):
156163
)
157164
setting_values.layout().addLayout(layout)
158165

166+
# Add selection UI for additional segmentation, which some models require for inference or postproc.
167+
self.extra_seg_selector_name = "Extra Segmentation"
168+
self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
169+
setting_values.layout().addWidget(self.extra_selector_widget)
170+
159171
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
160172
return settings

0 commit comments

Comments
 (0)