4
4
5
5
from .base_widget import BaseWidget
6
6
from .util import (run_segmentation , get_model , get_model_registry , _available_devices , get_device ,
7
- get_current_tiling , compute_scale_from_voxel_size )
7
+ get_current_tiling , compute_scale_from_voxel_size , load_custom_model )
8
8
from synaptic_reconstruction .inference .util import get_default_tiling
9
9
import copy
10
10
@@ -54,27 +54,34 @@ def load_model_widget(self):
54
54
def on_predict (self ):
55
55
# Get the model and postprocessing settings.
56
56
model_type = self .model_selector .currentText ()
57
- if model_type == "- choose -" :
58
- show_info ("Please choose a model." )
57
+ custom_model_path = self .checkpoint_param .text ()
58
+ if model_type == "- choose -" and custom_model_path is None :
59
+ show_info ("INFO: Please choose a model." )
59
60
return
60
61
61
- # Load the model.
62
62
device = get_device (self .device_dropdown .currentText ())
63
- model = get_model (model_type , device )
63
+
64
+ # Load the model. Override if user chose custom model
65
+ if custom_model_path :
66
+ model = load_custom_model (custom_model_path , device )
67
+ if model :
68
+ show_info (f"INFO: Using custom model from path: { custom_model_path } " )
69
+ model_type = "custom"
70
+ else :
71
+ show_info (f"ERROR: Failed to load custom model from path: { custom_model_path } " )
72
+ return
73
+ else :
74
+ model = get_model (model_type , device )
64
75
65
76
# Get the image data.
66
77
image = self ._get_layer_selector_data (self .image_selector_name )
67
78
if image is None :
68
- show_info ("Please choose an image." )
79
+ show_info ("INFO: Please choose an image." )
69
80
return
70
81
71
82
# load current tiling
72
83
self .tiling = get_current_tiling (self .tiling , self .default_tiling , model_type )
73
84
74
- # TODO: Use scale derived from the image resolution.
75
- # get avg image shape from training of the selected model
76
- # wichmann data avg voxel size = 17.53
77
-
78
85
metadata = self ._get_layer_selector_data (self .image_selector_name , return_metadata = True )
79
86
voxel_size = metadata .get ("voxel_size" , None )
80
87
scale = None
@@ -90,17 +97,20 @@ def on_predict(self):
90
97
voxel_size ["x" ] = self .voxel_size_param .value ()
91
98
voxel_size ["y" ] = self .voxel_size_param .value ()
92
99
if voxel_size :
93
- # calculate scale so voxel_size is the same as in training
94
- scale = compute_scale_from_voxel_size (voxel_size , model_type )
95
- show_info (f"Rescaled the image by { scale } to optimize for the selected model." )
100
+ if model_type == "custom" :
101
+ show_info ("INFO: The image is not rescaled for a custom model." )
102
+ else :
103
+ # calculate scale so voxel_size is the same as in training
104
+ scale = compute_scale_from_voxel_size (voxel_size , model_type )
105
+ show_info (f"INFO: Rescaled the image by { scale } to optimize for the selected model." )
96
106
97
107
segmentation = run_segmentation (
98
108
image , model = model , model_type = model_type , tiling = self .tiling , scale = scale
99
109
)
100
110
101
111
# Add the segmentation layer
102
112
self .viewer .add_labels (segmentation , name = f"{ model_type } -segmentation" , metadata = metadata )
103
- show_info (f"Segmentation of { model_type } added to layers." )
113
+ show_info (f"INFO: Segmentation of { model_type } added to layers." )
104
114
105
115
def _create_settings_widget (self ):
106
116
setting_values = QWidget ()
@@ -141,5 +151,11 @@ def _create_settings_widget(self):
141
151
)
142
152
setting_values .layout ().addLayout (layout )
143
153
154
+ self .checkpoint_param , layout = self ._add_string_param (
155
+ name = "checkpoint" , value = "" , title = "Load Custom Model" ,
156
+ placeholder = "path/to/checkpoint.pt" ,
157
+ )
158
+ setting_values .layout ().addLayout (layout )
159
+
144
160
settings = self ._make_collapsible (widget = setting_values , title = "Advanced Settings" )
145
161
return settings
0 commit comments