
Commit a051d85

Update CLI and model logic
1 parent 55074db commit a051d85

File tree: 5 files changed (+90, -183 lines)


setup.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
     license="MIT",
     entry_points={
         "console_scripts": [
-            "synapse_net.run_segmentation = synaptic_reconstruction:tools.cli:segmentation_cli"
+            "synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli"
         ],
         "napari.manifest": [
             "synaptic_reconstruction = synaptic_reconstruction:napari.yaml",

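For reference, setuptools console-script entries use the form "name = package.module:callable", with a single colon separating the import path from the callable; the extra colon in the old entry made the script unresolvable. A minimal sketch of the relevant setup() fragment (the name and packages arguments are placeholders, since the full setup.py is not shown in this commit):

from setuptools import setup, find_packages

setup(
    name="synaptic_reconstruction",  # placeholder; the real setup.py is not fully shown here
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            # One colon only: module path on the left, callable on the right.
            "synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli",
        ],
        "napari.manifest": [
            "synaptic_reconstruction = synaptic_reconstruction:napari.yaml",
        ],
    },
)
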
synaptic_reconstruction/inference/util.py

Lines changed: 16 additions & 14 deletions
@@ -255,21 +255,23 @@ def _get_file_paths(input_path, ext=".mrc"):
 
 
 def _load_input(img_path, extra_files, i):
-    # Load the input data data
-    with open_file(img_path, "r") as f:
-
-        # Try to automatically derive the key with the raw data.
-        keys = list(f.keys())
-        if len(keys) == 1:
-            key = keys[0]
-        elif "data" in keys:
-            key = "data"
-        elif "raw" in keys:
-            key = "raw"
-
-        input_volume = f[key][:]
-    assert input_volume.ndim == 3
+    # Load the input data.
+    if os.path.splitext(img_path)[-1] == ".tif":
+        input_volume = imageio.imread(img_path)
 
+    else:
+        with open_file(img_path, "r") as f:
+            # Try to automatically derive the key with the raw data.
+            keys = list(f.keys())
+            if len(keys) == 1:
+                key = keys[0]
+            elif "data" in keys:
+                key = "data"
+            elif "raw" in keys:
+                key = "raw"
+            input_volume = f[key][:]
+
+    assert input_volume.ndim in (2, 3)
     # For now we assume this is always tif.
     if extra_files is not None:
         extra_input = imageio.imread(extra_files[i])
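
In short, .tif inputs are now read directly with imageio, everything else still goes through open_file with the dataset key derived automatically, and 2d inputs are accepted alongside 3d. A standalone sketch of the same loading pattern, assuming open_file comes from elf.io as in the rest of the module (the helper name load_volume is illustrative only):

import os

import imageio
from elf.io import open_file  # assumption: the same open_file used by this module


def load_volume(path):
    # Illustrative loader mirroring the updated _load_input logic.
    if os.path.splitext(path)[-1] == ".tif":
        return imageio.imread(path)
    with open_file(path, "r") as f:
        keys = list(f.keys())
        # Use the only dataset if there is exactly one, otherwise fall back to common names.
        key = keys[0] if len(keys) == 1 else ("data" if "data" in keys else "raw")
        volume = f[key][:]
    assert volume.ndim in (2, 3)
    return volume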

synaptic_reconstruction/tools/cli.py

Lines changed: 6 additions & 4 deletions
@@ -1,10 +1,11 @@
 import argparse
 from functools import partial
 
-from .util import run_segmentation
+from .util import run_segmentation, get_model
 from ..inference.util import inference_helper, parse_tiling
 
 
+# TODO: handle kwargs
 def segmentation_cli():
     parser = argparse.ArgumentParser(description="Run segmentation.")
     parser.add_argument(
@@ -16,7 +17,7 @@ def segmentation_cli():
         help="The filepath to directory where the segmentations will be saved."
     )
     parser.add_argument(
-        "--model_path", "-m", required=True, help="The filepath to the vesicle model."
+        "--model", "-m", required=True, help="The model type."
     )
     parser.add_argument(
         "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
@@ -40,10 +41,11 @@ def segmentation_cli():
     )
     args = parser.parse_args()
 
-    # TODO: preload the model!
+    model = get_model(args.model)
     tiling = parse_tiling(args.tile_shape, args.halo)
+
     segmentation_function = partial(
-        run_segmentation, model_path=args.model_path, verbose=False, tiling=tiling,
+        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
     )
     inference_helper(
         args.input_path, args.output_path, segmentation_function,
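
The CLI now resolves the -m value to a loaded model via get_model and forwards both model and model_type to run_segmentation. A hypothetical sketch of equivalent programmatic usage (the model name "vesicles", the paths, and the assumption that parse_tiling accepts None and falls back to a default tiling are illustrative, not taken from this commit):

from functools import partial

from synaptic_reconstruction.tools.util import run_segmentation, get_model
from synaptic_reconstruction.inference.util import inference_helper, parse_tiling

model_type = "vesicles"  # assumption: one of the keys in the model registry
model = get_model(model_type)
tiling = parse_tiling(None, None)  # assumption: None falls back to the default tiling

segmentation_function = partial(
    run_segmentation, model=model, model_type=model_type, verbose=False, tiling=tiling,
)
inference_helper("/path/to/tomograms", "/path/to/output", segmentation_function)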

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 20 additions & 41 deletions
@@ -4,35 +4,27 @@
 from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
 
 from .base_widget import BaseWidget
-from .util import run_segmentation, get_model_registry, _available_devices
+from .util import run_segmentation, get_model, get_model_registry, _available_devices
 
 
 class SegmentationWidget(BaseWidget):
     def __init__(self):
         super().__init__()
 
-        self.model = None
         self.viewer = napari.current_viewer()
         layout = QVBoxLayout()
 
-        # Create the image selection dropdown
+        # Create the image selection dropdown.
         self.image_selector_name = "Image data"
         self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
 
-        # create buttons
+        # Create buttons and widgets.
         self.predict_button = QPushButton("Run Segmentation")
-
-        # Connect buttons to functions
         self.predict_button.clicked.connect(self.on_predict)
-        # self.load_model_button.clicked.connect(self.on_load_model)
-
-        # create model selector
         self.model_selector_widget = self.load_model_widget()
-
-        # create advanced settings
         self.settings = self._create_settings_widget()
 
-        # Add the widgets to the layout
+        # Add the widgets to the layout.
         layout.addWidget(self.image_selector_widget)
         layout.addWidget(self.model_selector_widget)
         layout.addWidget(self.settings)
@@ -44,8 +36,7 @@ def load_model_widget(self):
         model_widget = QWidget()
         title_label = QLabel("Select Model:")
 
-        models = list(get_model_registry().urls.keys())
-        self.model = None  # set default model
+        models = ["- choose -"] + list(get_model_registry().urls.keys())
         self.model_selector = QComboBox()
         self.model_selector.addItems(models)
         # Create a layout and add the title label and combo box
@@ -59,23 +50,20 @@ def load_model_widget(self):
 
     def on_predict(self):
         # Get the model and postprocessing settings.
-        model_key = self.model_selector.currentText()
-        if model_key == "- choose -":
+        model_type = self.model_selector.currentText()
+        if model_type == "- choose -":
             show_info("Please choose a model.")
             return
 
-        # loading model
-        model_registry = get_model_registry()
-        model_key = self.model_selector.currentText()
-        # model_path = "/home/freckmann15/.cache/synapse-net/models/vesicles" #
-        model_path = model_registry.fetch(model_key)
-        # get image data
+        # Load the model.
+        model = get_model(model_type, self.device)
+
+        # Get the image data.
         image = self._get_layer_selector_data(self.image_selector_name)
         if image is None:
             show_info("Please choose an image.")
             return
 
-        # FIXME: don't hard-code tiling here, but figure it out centrally in the prediction function.
         # get tile shape and halo from the viewer
         tiling = {
             "tile": {
@@ -89,26 +77,16 @@ def on_predict(self):
                 "z": 1
             }
         }
-        tile_shape = (self.tile_x_param.value(), self.tile_y_param.value())
-        halo = (self.halo_x_param.value(), self.halo_y_param.value())
-        use_custom_tiling = False
-        for ts, h in zip(tile_shape, halo):
-            if ts != 0 or h != 0:  # if anything changed from default
-                use_custom_tiling = True
-        if use_custom_tiling:
-            segmentation = run_segmentation(
-                image, model_path=model_path, model_key=model_key,
-                tiling=tiling, scale=self.scale_param.value()
-            )
-        else:
-            segmentation = run_segmentation(
-                image, model_path=model_path, model_key=model_key,
-                scale=self.scale_param.value()
-            )
+
+        # TODO: Use scale derived from the image resolution.
+        scale = [self.scale_param.value()]
+        segmentation = run_segmentation(
+            image, model=model, model_type=model_type, tiling=tiling, scale=scale
+        )
 
         # Add the segmentation layer
-        self.viewer.add_labels(segmentation, name=f"{model_key}-segmentation")
-        show_info(f"Segmentation of {model_key} added to layers.")
+        self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation")
+        show_info(f"Segmentation of {model_type} added to layers.")
 
     def _create_settings_widget(self):
         setting_values = QWidget()
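
The tiling handed to run_segmentation (and built by parse_tiling on the CLI side) is a nested dict with per-axis "tile" and "halo" entries; only the trailing "z": 1 is visible in these hunks, so the remaining values below are illustrative assumptions based on the widget's 512-pixel tile default:

# Illustrative tiling layout; all values except the visible "z": 1 are assumptions.
tiling = {
    "tile": {"x": 512, "y": 512, "z": 1},
    "halo": {"x": 64, "y": 64, "z": 1},
}
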
@@ -123,6 +101,7 @@ def _create_settings_widget(self):
         setting_values.layout().addLayout(layout)
 
         # Create UI for the tile shape.
+        # TODO: make the tiling 3d and get the default values from 'inference'
         self.tile_x, self.tile_y = 512, 512  # defaults
         self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
             ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
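
get_model(model_type, device) replaces the manual registry fetch removed in on_predict above; its implementation is not part of this commit view. A rough, hypothetical sketch of what such a helper could do with the registry (the torch.load call and the assumption that checkpoints store a full serialized model are illustrative, not the project's actual code):

import torch

from synaptic_reconstruction.tools.util import get_model_registry  # registry exposing .urls and .fetch, as used above


def get_model_sketch(model_type, device=None):
    # Hypothetical stand-in for get_model: fetch the checkpoint registered
    # under `model_type` and load it onto the requested device.
    registry = get_model_registry()
    checkpoint_path = registry.fetch(model_type)  # downloads to the local cache if missing
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.load(checkpoint_path, map_location=device)  # assumption: full model checkpoint
    model.eval()
    return model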
