Skip to content

Commit ec1f050

Browse files
committed
added segmentation widget
1 parent 74c3e69 commit ec1f050

File tree

8 files changed

+561
-1
lines changed

8 files changed

+561
-1
lines changed

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ dependencies:
88
- pip
99
- pyqt
1010
- magicgui
11+
- pytorch
1112
- pip:
1213
- napari-skimage-regionprops

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"console_scripts": [
1616
"sr_tools.correct_segmentation = synaptic_reconstruction.tools.segmentation_correction:main",
1717
"sr_tools.measure_distances = synaptic_reconstruction.tools.distance_measurement:main",
18-
]
18+
],
19+
"napari.manifest": [
20+
"synaptic_reconstruction = synaptic_reconstruction:napari.yaml",
21+
],
1922
},
2023
)

synaptic_reconstruction/napari.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
name: synaptic_reconstruction
2+
display_name: Synaptic Reconstruction
3+
# see https://napari.org/stable/plugins/manifest.html for valid categories
4+
categories: ["Image Processing", "Annotation"]
5+
contributions:
6+
commands:
7+
- id: synaptic_reconstruction.segment
8+
python_name: synaptic_reconstruction.tools.synaptic_plugin.segmentation:segmentation_widget
9+
title: Segment
10+
- id: synaptic_reconstruction.file_reader
11+
title: Read ".mrc, .rec" files
12+
python_name: synaptic_reconstruction.tools.file_reader_plugin.elf_reader:get_reader
13+
14+
readers:
15+
- command: synaptic_reconstruction.file_reader
16+
filename_patterns:
17+
- '*.mrc'
18+
accepts_directories: false
19+
20+
widgets:
21+
- command: synaptic_reconstruction.segment
22+
display_name: Segmentation
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Callable, List, Optional, Sequence, Union
2+
from napari.types import LayerData
3+
from elf.io import open_file
4+
5+
PathLike = str
6+
PathOrPaths = Union[PathLike, Sequence[PathLike]]
7+
ReaderFunction = Callable[[PathOrPaths], List[LayerData]]
8+
9+
10+
def get_reader(path: PathOrPaths) -> Optional[ReaderFunction]:
11+
# If we recognize the format, we return the actual reader function
12+
if isinstance(path, str) and path.endswith(".mrc"):
13+
return elf_read_file
14+
# otherwise we return None.
15+
return None
16+
17+
18+
def elf_read_file(path: PathOrPaths) -> List[LayerData]:
19+
try:
20+
with open_file(path, mode="r") as f:
21+
data = f["data"][:]
22+
layer_attributes = {
23+
"name": "Raw",
24+
"colormap": "gray",
25+
"blending": "additive"
26+
}
27+
return [(data, layer_attributes)]
28+
except Exception as e:
29+
print(f"Failed to read file: {e}")
30+
return

synaptic_reconstruction/tools/synaptic_plugin/__init__.py

Whitespace-only changes.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import napari
2+
from qtpy.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QSpinBox, QLineEdit, QGroupBox, QFormLayout, QFrame, QComboBox
3+
from superqt import QCollapsible
4+
5+
6+
class BaseWidget(QWidget):
7+
def __init__(self):
8+
super().__init__()
9+
self.viewer = napari.current_viewer()
10+
11+
def _add_int_param(self, name, value, min_val, max_val, title=None, step=1, layout=None, tooltip=None):
12+
if layout is None:
13+
layout = QHBoxLayout()
14+
label = QLabel(title or name)
15+
if tooltip:
16+
label.setToolTip(tooltip)
17+
layout.addWidget(label)
18+
param = QSpinBox()
19+
param.setRange(min_val, max_val)
20+
param.setValue(value)
21+
param.setSingleStep(step)
22+
param.valueChanged.connect(lambda val: setattr(self, name, val))
23+
if tooltip:
24+
param.setToolTip(tooltip)
25+
layout.addWidget(param)
26+
return param, layout
27+
28+
def _add_choice_param(self, name, value, options, title=None, layout=None, update=None, tooltip=None):
29+
if layout is None:
30+
layout = QHBoxLayout()
31+
label = QLabel(title or name)
32+
if tooltip:
33+
label.setToolTip(tooltip)
34+
layout.addWidget(label)
35+
36+
# Create the dropdown menu via QComboBox, set the available values.
37+
dropdown = QComboBox()
38+
dropdown.addItems(options)
39+
if update is None:
40+
dropdown.currentIndexChanged.connect(lambda index: setattr(self, name, options[index]))
41+
else:
42+
dropdown.currentIndexChanged.connect(update)
43+
44+
# Set the correct value for the value.
45+
dropdown.setCurrentIndex(dropdown.findText(value))
46+
47+
if tooltip:
48+
dropdown.setToolTip(tooltip)
49+
50+
layout.addWidget(dropdown)
51+
return dropdown, layout
52+
53+
def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None, tooltip=None):
54+
layout = QHBoxLayout()
55+
56+
x_layout = QVBoxLayout()
57+
x_param, _ = self._add_int_param(
58+
names[0], values[0], min_val=min_val, max_val=max_val, layout=x_layout, step=step,
59+
title=title[0] if title is not None else title, tooltip=tooltip
60+
)
61+
layout.addLayout(x_layout)
62+
63+
y_layout = QVBoxLayout()
64+
y_param, _ = self._add_int_param(
65+
names[1], values[1], min_val=min_val, max_val=max_val, layout=y_layout, step=step,
66+
title=title[1] if title is not None else title, tooltip=tooltip
67+
)
68+
layout.addLayout(y_layout)
69+
70+
return x_param, y_param, layout
71+
72+
def _make_collapsible(self, widget, title):
73+
parent_widget = QWidget()
74+
parent_widget.setLayout(QVBoxLayout())
75+
collapsible = QCollapsible(title, parent_widget)
76+
collapsible.addWidget(widget)
77+
parent_widget.layout().addWidget(collapsible)
78+
return parent_widget
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from typing import TYPE_CHECKING
2+
import h5py
3+
from magicgui import magic_factory, widgets
4+
import napari
5+
import napari.layers
6+
from napari.utils.notifications import show_info
7+
from napari import Viewer
8+
import numpy as np
9+
from qtpy.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QSpinBox, QLineEdit, QGroupBox, QFormLayout, QFrame, QComboBox
10+
from superqt import QCollapsible
11+
from elf.io import open_file
12+
from .base_widget import BaseWidget
13+
import os
14+
from synaptic_reconstruction.inference.vesicles import segment_vesicles
15+
16+
# Custom imports for model and prediction utilities
17+
from ..util import get_device, get_model_registry, run_prediction, _available_devices
18+
19+
# if TYPE_CHECKING:
20+
# import napari
21+
22+
23+
# def _make_collapsible(widget, title):
24+
# parent_widget = QWidget()
25+
# parent_widget.setLayout(QVBoxLayout())model_path
26+
class SegmentationWidget(BaseWidget):
27+
def __init__(self):
28+
super().__init__()
29+
30+
self.model = None
31+
self.viewer = napari.current_viewer()
32+
layout = QVBoxLayout()
33+
34+
# Create the image selection dropdown
35+
self.image_selector_widget = self.create_image_selector()
36+
37+
# Add your buttons here
38+
# self.load_model_button = QPushButton('Load Model')
39+
self.predict_button = QPushButton('Run Prediction')
40+
41+
# Connect buttons to functions
42+
self.predict_button.clicked.connect(self.on_predict)
43+
# self.load_model_button.clicked.connect(self.on_load_model)
44+
45+
# create model selector
46+
self.model_selector_widget = self.load_model_widget()
47+
48+
# create advanced settings
49+
self.settings = self._create_settings_widget()
50+
51+
# Add the widgets to the layout
52+
layout.addWidget(self.image_selector_widget)
53+
layout.addWidget(self.model_selector_widget)
54+
layout.addWidget(self.settings)
55+
layout.addWidget(self.predict_button)
56+
57+
self.setLayout(layout)
58+
59+
def create_image_selector(self):
60+
selector_widget = QWidget()
61+
self.image_selector = QComboBox()
62+
63+
title_label = QLabel("Select Image Layer:")
64+
65+
# Populate initial options
66+
self.update_image_selector()
67+
68+
# Connect selection change to update self.image
69+
self.image_selector.currentIndexChanged.connect(self.update_image_data)
70+
71+
# Connect to Napari layer events to update the list
72+
self.viewer.layers.events.inserted.connect(self.update_image_selector)
73+
self.viewer.layers.events.removed.connect(self.update_image_selector)
74+
75+
layout = QVBoxLayout()
76+
layout.addWidget(title_label)
77+
layout.addWidget(self.image_selector)
78+
selector_widget.setLayout(layout)
79+
return selector_widget
80+
81+
def update_image_selector(self, event=None):
82+
"""Update dropdown options with current image layers in the viewer."""
83+
self.image_selector.clear()
84+
85+
# Add each image layer's name to the dropdown
86+
image_layers = [layer.name for layer in self.viewer.layers if isinstance(layer, napari.layers.Image)]
87+
self.image_selector.addItems(image_layers)
88+
89+
def update_image_data(self):
90+
"""Update the self.image attribute with data from the selected layer."""
91+
selected_layer_name = self.image_selector.currentText()
92+
if selected_layer_name in self.viewer.layers:
93+
self.image = self.viewer.layers[selected_layer_name].data
94+
else:
95+
self.image = None # Reset if no valid selection
96+
97+
def load_model_widget(self):
98+
model_widget = QWidget()
99+
title_label = QLabel("Select Model:")
100+
101+
models = list(get_model_registry().urls.keys())
102+
self.model = None # set default model
103+
self.model_selector = QComboBox()
104+
self.model_selector.addItems(models)
105+
# Create a layout and add the title label and combo box
106+
layout = QVBoxLayout()
107+
layout.addWidget(title_label)
108+
layout.addWidget(self.model_selector)
109+
110+
# Set layout on the model widget
111+
model_widget.setLayout(layout)
112+
return model_widget
113+
114+
# def on_load_model(self):
115+
# # Open file dialog to select a model
116+
# file_dialog = QFileDialog(self)
117+
# file_dialog.setFileMode(QFileDialog.ExistingFiles)
118+
# file_dialog.setNameFilter("Model (*.pt)")
119+
# file_dialog.setViewMode(QFileDialog.List)
120+
121+
# if file_dialog.exec_():
122+
# file_paths = file_dialog.selectedFiles()
123+
# if file_paths:
124+
# # Assuming you load a single model path here
125+
# model_path = file_paths[0]
126+
# self.load_model(model_path)
127+
128+
# def load_model(self, model_path):
129+
# print("model path type and value", type(model_path), model_path)
130+
# # Load the model from the selected path
131+
# model = get_model(model_path)
132+
# self.model = model
133+
134+
def on_predict(self):
135+
# Get the model and postprocessing settings.
136+
model_key = self.model_selector.currentText()
137+
if model_key == "- choose -":
138+
show_info("Please choose a model.")
139+
return
140+
141+
142+
model_registry = get_model_registry()
143+
model_path = model_registry.fetch(model_key)
144+
145+
if self.image is None:
146+
show_info("Please choose an image.")
147+
return
148+
149+
# get tile shape and halo from the viewer
150+
tile_shape = (self.tile_x_param.value(), self.tile_y_param.value())
151+
halo = (self.halo_x_param.value(), self.halo_y_param.value())
152+
tiling = {
153+
"tile": {
154+
"x": self.tile_x_param.value(),
155+
"y": self.tile_y_param.value(),
156+
"z": 1
157+
},
158+
"halo": {
159+
"x": self.halo_x_param.value(),
160+
"y": self.halo_y_param.value(),
161+
"z": 1
162+
}
163+
}
164+
segmentation = segment_vesicles(self.image, model_path=model_path) #tiling=tiling
165+
# Add predictions to Napari as separate layers
166+
# for i, pred in enumerate(segmentation):
167+
# layer_name = f"Prediction {i+1}"
168+
# self.viewer.add_image(pred, name=layer_name, colormap="inferno", blending="additive")
169+
layer_kwargs = {"colormap": "inferno", "blending": "additive"}
170+
return segmentation, layer_kwargs
171+
172+
def _create_settings_widget(self):
173+
setting_values = QWidget()
174+
# setting_values.setToolTip(get_tooltip("embedding", "settings"))
175+
setting_values.setLayout(QVBoxLayout())
176+
177+
# Create UI for the device.
178+
self.device = "auto"
179+
device_options = ["auto"] + _available_devices()
180+
181+
self.device_dropdown, layout = self._add_choice_param("device", self.device, device_options)
182+
# tooltip=get_tooltip("embedding", "device"))
183+
setting_values.layout().addLayout(layout)
184+
185+
# Create UI for the tile shape.
186+
self.tile_x, self.tile_y = 256, 256 # defaults
187+
self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
188+
("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
189+
# tooltip=get_tooltip("embedding", "tiling")
190+
)
191+
setting_values.layout().addLayout(layout)
192+
193+
# Create UI for the halo.
194+
self.halo_x, self.halo_y = 32, 32 # defaults
195+
self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
196+
("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
197+
# tooltip=get_tooltip("embedding", "halo")
198+
)
199+
setting_values.layout().addLayout(layout)
200+
201+
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
202+
return settings
203+
204+
205+
def segmentation_widget():
206+
return SegmentationWidget()

0 commit comments

Comments
 (0)