Skip to content

Commit e95dd45

Browse files
Merge pull request #60 from computational-cell-analytics/56-add-support-for-custom-model-paths-in-segmentation-plugin
added custom model path with target selector to choose between desire…
2 parents a974e65 + 3e11446 commit e95dd45

File tree

3 files changed

+82
-29
lines changed

3 files changed

+82
-29
lines changed

synaptic_reconstruction/tools/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
# TODO: handle kwargs
9+
# TODO: add custom model path
910
def segmentation_cli():
1011
parser = argparse.ArgumentParser(description="Run segmentation.")
1112
parser.add_argument(
@@ -16,6 +17,7 @@ def segmentation_cli():
1617
"--output_path", "-o", required=True,
1718
help="The filepath to directory where the segmentations will be saved."
1819
)
20+
# TODO: list the availabel models here by parsing the keys of the model registry
1921
parser.add_argument(
2022
"--model", "-m", required=True, help="The model type."
2123
)

synaptic_reconstruction/tools/segmentation_widget.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .base_widget import BaseWidget
66
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
7-
get_current_tiling, compute_scale_from_voxel_size)
7+
get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
88
from synaptic_reconstruction.inference.util import get_default_tiling
99
import copy
1010

@@ -54,27 +54,34 @@ def load_model_widget(self):
5454
def on_predict(self):
5555
# Get the model and postprocessing settings.
5656
model_type = self.model_selector.currentText()
57-
if model_type == "- choose -":
58-
show_info("Please choose a model.")
57+
custom_model_path = self.checkpoint_param.text()
58+
if model_type == "- choose -" and custom_model_path is None:
59+
show_info("INFO: Please choose a model.")
5960
return
6061

61-
# Load the model.
6262
device = get_device(self.device_dropdown.currentText())
63-
model = get_model(model_type, device)
63+
64+
# Load the model. Override if user chose custom model
65+
if custom_model_path:
66+
model = load_custom_model(custom_model_path, device)
67+
if model:
68+
show_info(f"INFO: Using custom model from path: {custom_model_path}")
69+
model_type = "custom"
70+
else:
71+
show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
72+
return
73+
else:
74+
model = get_model(model_type, device)
6475

6576
# Get the image data.
6677
image = self._get_layer_selector_data(self.image_selector_name)
6778
if image is None:
68-
show_info("Please choose an image.")
79+
show_info("INFO: Please choose an image.")
6980
return
7081

7182
# load current tiling
7283
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)
7384

74-
# TODO: Use scale derived from the image resolution.
75-
# get avg image shape from training of the selected model
76-
# wichmann data avg voxel size = 17.53
77-
7885
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
7986
voxel_size = metadata.get("voxel_size", None)
8087
scale = None
@@ -90,17 +97,20 @@ def on_predict(self):
9097
voxel_size["x"] = self.voxel_size_param.value()
9198
voxel_size["y"] = self.voxel_size_param.value()
9299
if voxel_size:
93-
# calculate scale so voxel_size is the same as in training
94-
scale = compute_scale_from_voxel_size(voxel_size, model_type)
95-
show_info(f"Rescaled the image by {scale} to optimize for the selected model.")
100+
if model_type == "custom":
101+
show_info("INFO: The image is not rescaled for a custom model.")
102+
else:
103+
# calculate scale so voxel_size is the same as in training
104+
scale = compute_scale_from_voxel_size(voxel_size, model_type)
105+
show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")
96106

97107
segmentation = run_segmentation(
98108
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
99109
)
100110

101111
# Add the segmentation layer
102112
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
103-
show_info(f"Segmentation of {model_type} added to layers.")
113+
show_info(f"INFO: Segmentation of {model_type} added to layers.")
104114

105115
def _create_settings_widget(self):
106116
setting_values = QWidget()
@@ -141,5 +151,11 @@ def _create_settings_widget(self):
141151
)
142152
setting_values.layout().addLayout(layout)
143153

154+
self.checkpoint_param, layout = self._add_string_param(
155+
name="checkpoint", value="", title="Load Custom Model",
156+
placeholder="path/to/checkpoint.pt",
157+
)
158+
setting_values.layout().addLayout(layout)
159+
144160
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
145161
return settings

synaptic_reconstruction/tools/util.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
11
import os
2+
import re
23
from typing import Dict, List, Optional, Union
34

45
import torch
56
import numpy as np
67
import pooch
7-
import warnings
88

99
from ..inference.vesicles import segment_vesicles
1010
from ..inference.mitochondria import segment_mitochondria
1111

1212

13+
def load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
14+
model_path = _clean_filepath(model_path)
15+
if device is None:
16+
device = get_device(device)
17+
try:
18+
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
19+
except Exception as e:
20+
print(e)
21+
print("model path", model_path)
22+
return None
23+
return model
24+
25+
1326
def get_model_path(model_type: str) -> str:
1427
"""Get the local path to a given model.
1528
@@ -35,19 +48,14 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
3548
Returns:
3649
The model.
3750
"""
38-
device = get_device(device)
51+
if device is None:
52+
device = get_device(device)
3953
model_path = get_model_path(model_type)
40-
warnings.filterwarnings(
41-
"ignore",
42-
message="You are using `torch.load` with `weights_only=False`",
43-
category=FutureWarning
44-
)
45-
model = torch.load(model_path)
54+
model = torch.load(model_path, weights_only=False)
4655
model.to(device)
4756
return model
4857

4958

50-
# TODO: distinguish between 2d and 3d vesicle model segmentation
5159
def run_segmentation(
5260
image: np.ndarray,
5361
model: torch.nn.Module,
@@ -60,12 +68,15 @@ def run_segmentation(
6068
"""Run synaptic structure segmentation.
6169
6270
Args:
63-
image: ...
64-
model: ...
65-
model_type: ...
66-
tiling: ...
67-
scale: ...
68-
verbose: ...
71+
image: The input image or image volume.
72+
model: The segmentation model.
73+
model_type: The model type. This will determine which segmentation
74+
post-processing is used.
75+
tiling: The tiling settings for inference.
76+
scale: A scale factor for resizing the input before applying the model.
77+
The output will be scaled back to the initial size.
78+
verbose: Whether to print detailed information about the prediction and segmentation.
79+
kwargs: Optional parameter for the segmentation function.
6980
7081
Returns:
7182
The segmentation.
@@ -234,3 +245,27 @@ def compute_scale_from_voxel_size(
234245
voxel_size["z"] / training_voxel_size["z"]
235246
)
236247
return scale
248+
249+
250+
def _clean_filepath(filepath):
251+
"""
252+
Cleans a given filepath by:
253+
- Removing newline characters (\n)
254+
- Removing escape sequences
255+
- Stripping the 'file://' prefix if present
256+
257+
Args:
258+
filepath (str): The original filepath
259+
260+
Returns:
261+
str: The cleaned filepath
262+
"""
263+
# Remove 'file://' prefix if present
264+
if filepath.startswith("file://"):
265+
filepath = filepath[7:]
266+
267+
# Remove escape sequences and newlines
268+
filepath = re.sub(r'\\.', '', filepath)
269+
filepath = filepath.replace('\n', '').replace('\r', '')
270+
271+
return filepath

0 commit comments

Comments
 (0)