Skip to content

Commit 7941db0

Browse files
committed
added custom model path with target selector to choose between desired segmentation functions
1 parent c396113 commit 7941db0

File tree

2 files changed

+86
-15
lines changed

2 files changed

+86
-15
lines changed

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .base_widget import BaseWidget
66
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
7-
get_current_tiling, compute_scale_from_voxel_size)
7+
get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
88
from synaptic_reconstruction.inference.util import get_default_tiling
99
import copy
1010

@@ -54,27 +54,34 @@ def load_model_widget(self):
5454
def on_predict(self):
5555
# Get the model and postprocessing settings.
5656
model_type = self.model_selector.currentText()
57-
if model_type == "- choose -":
58-
show_info("Please choose a model.")
57+
custom_model_path = self.checkpoint_param.text()
58+
if model_type == "- choose -" and custom_model_path is None:
59+
show_info("INFO: Please choose a model.")
5960
return
6061

61-
# Load the model.
6262
device = get_device(self.device_dropdown.currentText())
63-
model = get_model(model_type, device)
63+
64+
# Load the model. Override if user chose custom model
65+
if custom_model_path:
66+
model = load_custom_model(custom_model_path, device)
67+
if model:
68+
show_info(f"INFO: Using custom model from path: {custom_model_path}")
69+
model_type = "custom"
70+
else:
71+
show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
72+
return
73+
else:
74+
model = get_model(model_type, device)
6475

6576
# Get the image data.
6677
image = self._get_layer_selector_data(self.image_selector_name)
6778
if image is None:
68-
show_info("Please choose an image.")
79+
show_info("INFO: Please choose an image.")
6980
return
7081

7182
# load current tiling
7283
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)
7384

74-
# TODO: Use scale derived from the image resolution.
75-
# get avg image shape from training of the selected model
76-
# wichmann data avg voxel size = 17.53
77-
7885
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
7986
voxel_size = metadata.get("voxel_size", None)
8087
scale = None
@@ -90,17 +97,24 @@ def on_predict(self):
9097
voxel_size["x"] = self.voxel_size_param.value()
9198
voxel_size["y"] = self.voxel_size_param.value()
9299
if voxel_size:
93-
# calculate scale so voxel_size is the same as in training
94-
scale = compute_scale_from_voxel_size(voxel_size, model_type)
95-
show_info(f"Rescaled the image by {scale} to optimize for the selected model.")
100+
if model_type == "custom":
101+
show_info(f"INFO: Scaling image not avaialable for custom model.")
102+
else:
103+
# calculate scale so voxel_size is the same as in training
104+
scale = compute_scale_from_voxel_size(voxel_size, model_type)
105+
show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")
106+
107+
if model_type == "custom":
108+
# choose appropriate segmentation
109+
model_type = self.model_target_selector.currentText()
96110

97111
segmentation = run_segmentation(
98112
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
99113
)
100114

101115
# Add the segmentation layer
102116
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
103-
show_info(f"Segmentation of {model_type} added to layers.")
117+
show_info(f"INFO: Segmentation of {model_type} added to layers.")
104118

105119
def _create_settings_widget(self):
106120
setting_values = QWidget()
@@ -141,5 +155,18 @@ def _create_settings_widget(self):
141155
)
142156
setting_values.layout().addLayout(layout)
143157

158+
self.checkpoint_param, layout = self._add_string_param(
159+
name="checkpoint", value="", title="Load Custom Model",
160+
placeholder="path/to/checkpoint.pt",
161+
)
162+
setting_values.layout().addLayout(layout)
163+
164+
self.model_target_selector, layout = self._add_choice_param(
165+
name="Model Target", value="vesicles", options=[
166+
"vesicles", "mitochondria", "active_zone", "compartments", "inner_ear_structures"
167+
]
168+
)
169+
setting_values.layout().addLayout(layout)
170+
144171
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
145172
return settings

synaptic_reconstruction/tools/util.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import re
23
from typing import Dict, List, Optional, Union
34

45
import torch
@@ -10,6 +11,24 @@
1011
from ..inference.mitochondria import segment_mitochondria
1112

1213

14+
def load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
15+
model_path = _clean_filepath(model_path)
16+
if device is None:
17+
device = get_device(device)
18+
try:
19+
warnings.filterwarnings(
20+
"ignore",
21+
message="You are using `torch.load` with `weights_only=False`",
22+
category=FutureWarning
23+
)
24+
model = torch.load(model_path, map_location=torch.device(device))
25+
except Exception as e:
26+
print(e)
27+
print("model path", model_path)
28+
return None
29+
return model
30+
31+
1332
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
1433
"""Get the model for the given segmentation type.
1534
@@ -21,7 +40,8 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
2140
Returns:
2241
The model.
2342
"""
24-
device = get_device(device)
43+
if device is None:
44+
device = get_device(device)
2545
model_registry = get_model_registry()
2646
model_path = model_registry.fetch(model_type)
2747
warnings.filterwarnings(
@@ -221,3 +241,27 @@ def compute_scale_from_voxel_size(
221241
voxel_size["z"] / training_voxel_size["z"]
222242
)
223243
return scale
244+
245+
246+
def _clean_filepath(filepath):
247+
"""
248+
Cleans a given filepath by:
249+
- Removing newline characters (\n)
250+
- Removing escape sequences
251+
- Stripping the 'file://' prefix if present
252+
253+
Args:
254+
filepath (str): The original filepath
255+
256+
Returns:
257+
str: The cleaned filepath
258+
"""
259+
# Remove 'file://' prefix if present
260+
if filepath.startswith("file://"):
261+
filepath = filepath[7:]
262+
263+
# Remove escape sequences and newlines
264+
filepath = re.sub(r'\\.', '', filepath)
265+
filepath = filepath.replace('\n', '').replace('\r', '')
266+
267+
return filepath

0 commit comments

Comments
 (0)