Skip to content

Commit c396113

Browse files
Merge pull request #55 from computational-cell-analytics/53-next-steps-plugin-implementation
53 next steps plugin implementation
2 parents e3d3798 + 27c4d9d commit c396113

File tree

5 files changed

+185
-42
lines changed

5 files changed

+185
-42
lines changed

synaptic_reconstruction/tools/base_widget.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _update_selector(self, selector, layer_filter):
6161
image_layers = [layer.name for layer in self.viewer.layers if isinstance(layer, layer_filter)] # if isinstance(layer, napari.layers.Image)
6262
selector.addItems(image_layers)
6363

64-
def _get_layer_selector_data(self, selector_name):
64+
def _get_layer_selector_data(self, selector_name, return_metadata=False):
6565
"""Return the data for the layer currently selected in a given selector."""
6666
if selector_name in self.layer_selectors:
6767
selector_widget = self.layer_selectors[selector_name]
@@ -72,7 +72,10 @@ def _get_layer_selector_data(self, selector_name):
7272
if isinstance(image_selector, QComboBox):
7373
selected_layer_name = image_selector.currentText()
7474
if selected_layer_name in self.viewer.layers:
75-
return self.viewer.layers[selected_layer_name].data
75+
if return_metadata:
76+
return self.viewer.layers[selected_layer_name].metadata
77+
else:
78+
return self.viewer.layers[selected_layer_name].data
7679
return None # Return None if layer not found
7780

7881
def _add_string_param(self, name, value, title=None, placeholder=None, layout=None, tooltip=None):
@@ -169,6 +172,15 @@ def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None,
169172
title=title[1] if title is not None else title, tooltip=tooltip
170173
)
171174
layout.addLayout(y_layout)
175+
176+
if len(names) == 3:
177+
z_layout = QVBoxLayout()
178+
z_param, _ = self._add_int_param(
179+
names[2], values[2], min_val=min_val, max_val=max_val, layout=z_layout, step=step,
180+
title=title[2] if title is not None else title, tooltip=tooltip
181+
)
182+
layout.addLayout(z_layout)
183+
return x_param, y_param, z_param, layout
172184

173185
return x_param, y_param, layout
174186

synaptic_reconstruction/tools/distance_measure_widget.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,33 +101,53 @@ def _add_lines_and_table(self, lines, properties, table_data, name):
101101
def on_measure_seg_to_object(self):
102102
segmentation = self._get_layer_selector_data(self.image_selector_name1)
103103
object_data = self._get_layer_selector_data(self.image_selector_name2)
104+
# get metadata from layer if available
105+
metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
106+
resolution = metadata.get("voxel_size", None)
107+
if resolution is not None:
108+
resolution = [v for v in resolution.values()]
109+
# if user input is present override metadata
110+
if self.voxel_size_param.value() != 0.0: # changed from default
111+
resolution = segmentation.ndim * [self.voxel_size_param.value()]
104112

105113
(distances,
106114
endpoints1,
107115
endpoints2,
108116
seg_ids) = distance_measurements.measure_segmentation_to_object_distances(
109117
segmentation=segmentation, segmented_object=object_data, distance_type="boundary",
118+
resolution=resolution
110119
)
111120
lines, properties = distance_measurements.create_object_distance_lines(
112121
distances=distances,
113122
endpoints1=endpoints1,
114123
endpoints2=endpoints2,
115-
seg_ids=seg_ids
124+
seg_ids=seg_ids,
116125
)
117126
table_data = self._to_table_data(distances, seg_ids, endpoints1, endpoints2)
118127
self._add_lines_and_table(lines, properties, table_data, name="distances")
119128

120129
def on_measure_pairwise(self):
121130
segmentation = self._get_layer_selector_data(self.image_selector_name1)
131+
if segmentation is None:
132+
show_info("Please choose a segmentation.")
133+
return
134+
# get metadata from layer if available
135+
metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
136+
resolution = metadata.get("voxel_size", None)
137+
if resolution is not None:
138+
resolution = [v for v in resolution.values()]
139+
# if user input is present override metadata
140+
if self.voxel_size_param.value() != 0.0: # changed from default
141+
resolution = segmentation.ndim * [self.voxel_size_param.value()]
122142

123143
(distances,
124144
endpoints1,
125145
endpoints2,
126146
seg_ids) = distance_measurements.measure_pairwise_object_distances(
127-
segmentation=segmentation, distance_type="boundary"
147+
segmentation=segmentation, distance_type="boundary", resolution=resolution
128148
)
129149
lines, properties = distance_measurements.create_pairwise_distance_lines(
130-
distances=distances, endpoints1=endpoints1, endpoints2=endpoints2, seg_ids=seg_ids.tolist()
150+
distances=distances, endpoints1=endpoints1, endpoints2=endpoints2, seg_ids=seg_ids.tolist(),
131151
)
132152
table_data = self._to_table_data(
133153
distances=properties["distance"],
@@ -142,5 +162,10 @@ def _create_settings_widget(self):
142162
self.save_path, layout = self._add_path_param(name="Save Table", select_type="file", value="")
143163
setting_values.layout().addLayout(layout)
144164

165+
self.voxel_size_param, layout = self._add_float_param(
166+
"voxel_size", 0.0, min_val=0.0, max_val=100.0,
167+
)
168+
setting_values.layout().addLayout(layout)
169+
145170
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
146171
return settings

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import napari
2-
import napari.layers
32
from napari.utils.notifications import show_info
43
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
54

65
from .base_widget import BaseWidget
7-
from .util import run_segmentation, get_model, get_model_registry, _available_devices
6+
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
7+
get_current_tiling, compute_scale_from_voxel_size)
8+
from synaptic_reconstruction.inference.util import get_default_tiling
9+
import copy
810

911

1012
class SegmentationWidget(BaseWidget):
@@ -13,6 +15,7 @@ def __init__(self):
1315

1416
self.viewer = napari.current_viewer()
1517
layout = QVBoxLayout()
18+
self.tiling = {}
1619

1720
# Create the image selection dropdown.
1821
self.image_selector_name = "Image data"
@@ -56,36 +59,47 @@ def on_predict(self):
5659
return
5760

5861
# Load the model.
59-
model = get_model(model_type, self.device)
62+
device = get_device(self.device_dropdown.currentText())
63+
model = get_model(model_type, device)
6064

6165
# Get the image data.
6266
image = self._get_layer_selector_data(self.image_selector_name)
6367
if image is None:
6468
show_info("Please choose an image.")
6569
return
6670

67-
# get tile shape and halo from the viewer
68-
tiling = {
69-
"tile": {
70-
"x": self.tile_x_param.value(),
71-
"y": self.tile_y_param.value(),
72-
"z": 1
73-
},
74-
"halo": {
75-
"x": self.halo_x_param.value(),
76-
"y": self.halo_y_param.value(),
77-
"z": 1
78-
}
79-
}
71+
# load current tiling
72+
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)
8073

8174
# TODO: Use scale derived from the image resolution.
82-
scale = [self.scale_param.value()]
75+
# get avg image shape from training of the selected model
76+
# wichmann data avg voxel size = 17.53
77+
78+
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
79+
voxel_size = metadata.get("voxel_size", None)
80+
scale = None
81+
82+
if self.voxel_size_param.value() != 0.0: # changed from default
83+
voxel_size = {}
84+
# override voxel size with user input
85+
if len(image.shape) == 3:
86+
voxel_size["x"] = self.voxel_size_param.value()
87+
voxel_size["y"] = self.voxel_size_param.value()
88+
voxel_size["z"] = self.voxel_size_param.value()
89+
else:
90+
voxel_size["x"] = self.voxel_size_param.value()
91+
voxel_size["y"] = self.voxel_size_param.value()
92+
if voxel_size:
93+
# calculate scale so voxel_size is the same as in training
94+
scale = compute_scale_from_voxel_size(voxel_size, model_type)
95+
show_info(f"Rescaled the image by {scale} to optimize for the selected model.")
96+
8397
segmentation = run_segmentation(
84-
image, model=model, model_type=model_type, tiling=tiling, scale=scale
98+
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
8599
)
86100

87101
# Add the segmentation layer
88-
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation")
102+
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
89103
show_info(f"Segmentation of {model_type} added to layers.")
90104

91105
def _create_settings_widget(self):
@@ -94,31 +108,36 @@ def _create_settings_widget(self):
94108
setting_values.setLayout(QVBoxLayout())
95109

96110
# Create UI for the device.
97-
self.device = "auto"
111+
device = "auto"
98112
device_options = ["auto"] + _available_devices()
99113

100-
self.device_dropdown, layout = self._add_choice_param("device", self.device, device_options)
114+
self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
101115
setting_values.layout().addLayout(layout)
102116

103117
# Create UI for the tile shape.
104-
# TODO: make the tiling 3d and get the default values from 'inference'
105-
self.tile_x, self.tile_y = 512, 512 # defaults
106-
self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
107-
("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
118+
self.default_tiling = get_default_tiling()
119+
self.tiling = copy.deepcopy(self.default_tiling)
120+
self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
121+
("tile_x", "tile_y", "tile_z"),
122+
(self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
123+
min_val=0, max_val=2048, step=16,
108124
# tooltip=get_tooltip("embedding", "tiling")
109125
)
110126
setting_values.layout().addLayout(layout)
111127

112128
# Create UI for the halo.
113-
self.halo_x, self.halo_y = 64, 64 # defaults
114-
self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
115-
("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
129+
130+
self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
131+
("halo_x", "halo_y", "halo_z"),
132+
(self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
133+
min_val=0, max_val=512,
116134
# tooltip=get_tooltip("embedding", "halo")
117135
)
118136
setting_values.layout().addLayout(layout)
119137

120-
self.scale_param, layout = self._add_float_param(
121-
"scale", 0.5, min_val=0.0, max_val=8.0,
138+
# read voxel size from layer metadata
139+
self.voxel_size_param, layout = self._add_float_param(
140+
"voxel_size", 0.0, min_val=0.0, max_val=100.0,
122141
)
123142
setting_values.layout().addLayout(layout)
124143

synaptic_reconstruction/tools/util.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
import numpy as np
66
import pooch
7+
import warnings
78

89
from ..inference.vesicles import segment_vesicles
910
from ..inference.mitochondria import segment_mitochondria
@@ -23,6 +24,11 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
2324
device = get_device(device)
2425
model_registry = get_model_registry()
2526
model_path = model_registry.fetch(model_type)
27+
warnings.filterwarnings(
28+
"ignore",
29+
message="You are using `torch.load` with `weights_only=False`",
30+
category=FutureWarning
31+
)
2632
model = torch.load(model_path)
2733
model.to(device)
2834
return model
@@ -73,12 +79,12 @@ def get_cache_dir():
7379

7480
def get_model_training_resolution(model_type):
7581
resolutions = {
76-
"active_zone": 1.44,
77-
"compartments": 3.47,
78-
"mitochondria": 1.0, # FIXME: this is a dummy value, we need to determine the real one
79-
"vesicles_2d": 1.35,
80-
"vesicles_3d": 1.35,
81-
"vesicles_cryo": 0.88,
82+
"active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
83+
"compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
84+
"mitochondria": {"x": 1.0, "y": 1.0, "z": 1.0}, # FIXME: this is a dummy value, we need to determine the real one
85+
"vesicles_2d": {"x": 1.35, "y": 1.35},
86+
"vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
87+
"vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
8288
}
8389
return resolutions[model_type]
8490

@@ -168,3 +174,50 @@ def _available_devices():
168174
else:
169175
available_devices.append(device)
170176
return available_devices
177+
178+
179+
def get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
180+
# get tiling values from qt objects
181+
for k, v in tiling.items():
182+
for k2, v2 in v.items():
183+
if isinstance(v2, int):
184+
continue
185+
tiling[k][k2] = v2.value()
186+
# check if user inputs tiling/halo or not
187+
if default_tiling == tiling:
188+
if "2d" in model_type:
189+
# if its a 2d model expand x,y and set z to 1
190+
tiling = {
191+
"tile": {
192+
"x": 512,
193+
"y": 512,
194+
"z": 1
195+
},
196+
"halo": {
197+
"x": 64,
198+
"y": 64,
199+
"z": 1
200+
}
201+
}
202+
elif "2d" in model_type:
203+
# if its a 2d model set z to 1
204+
tiling["tile"]["z"] = 1
205+
tiling["halo"]["z"] = 1
206+
207+
return tiling
208+
209+
210+
def compute_scale_from_voxel_size(
211+
voxel_size: dict,
212+
model_type: str
213+
) -> List[float]:
214+
training_voxel_size = get_model_training_resolution(model_type)
215+
scale = [
216+
voxel_size["x"] / training_voxel_size["x"],
217+
voxel_size["y"] / training_voxel_size["y"],
218+
]
219+
if len(voxel_size) == 3 and len(training_voxel_size) == 3:
220+
scale.append(
221+
voxel_size["z"] / training_voxel_size["z"]
222+
)
223+
return scale

synaptic_reconstruction/tools/volume_reader.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
from typing import Callable, List, Optional, Sequence, Union
4+
import mrcfile
45
from napari.types import LayerData
56

67
from elf.io import open_file, is_dataset
@@ -22,7 +23,17 @@ def get_reader(path: PathOrPaths) -> Optional[ReaderFunction]:
2223
def _read_mrc(path, fname):
2324
with open_file(path, mode="r") as f:
2425
data = f["data"][:]
25-
layer_attributes = {"name": fname, "colormap": "gray"}
26+
voxel_size = read_voxel_size(path)
27+
metadata = {
28+
"file_path": path,
29+
"voxel_size": voxel_size
30+
}
31+
layer_attributes = {
32+
"name": fname,
33+
"colormap": "gray",
34+
"metadata": metadata
35+
}
36+
2637
return [(data, layer_attributes)]
2738

2839

@@ -61,3 +72,26 @@ def read_image_volume(path: PathOrPaths) -> List[LayerData]:
6172
except Exception as e:
6273
print(f"Failed to read file: {e}")
6374
return
75+
76+
77+
def read_voxel_size(input_path: str) -> dict | None:
78+
"""Read voxel size from mrc/rec file and store it in layer_attributes.
79+
The original unit of voxel size is Angstrom and we convert it to nanometers
80+
by dividing it by ten.
81+
82+
Args:
83+
input_path (str): path to mrc/rec file
84+
layer_attributes (dict): napari layer attributes to store voxel size to
85+
"""
86+
new_voxel_size = None
87+
with mrcfile.open(input_path, permissive=True) as mrc:
88+
try:
89+
voxel_size = mrc.voxel_size
90+
new_voxel_size = {
91+
"x": voxel_size.x / 10,
92+
"y": voxel_size.y / 10,
93+
"z": voxel_size.z / 10,
94+
}
95+
except Exception as e:
96+
print(f"Failed to read voxel size: {e}")
97+
return new_voxel_size

0 commit comments

Comments
 (0)