4
4
from qtpy .QtWidgets import QWidget , QVBoxLayout , QPushButton , QLabel , QComboBox
5
5
6
6
from .base_widget import BaseWidget
7
- from .util import run_segmentation , get_model_registry , _available_devices
7
+ from .util import run_segmentation , get_model , get_model_registry , _available_devices
8
8
9
9
10
10
class SegmentationWidget (BaseWidget ):
11
11
def __init__ (self ):
12
12
super ().__init__ ()
13
13
14
- self .model = None
15
14
self .viewer = napari .current_viewer ()
16
15
layout = QVBoxLayout ()
17
16
18
- # Create the image selection dropdown
17
+ # Create the image selection dropdown.
19
18
self .image_selector_name = "Image data"
20
19
self .image_selector_widget = self ._create_layer_selector (self .image_selector_name , layer_type = "Image" )
21
20
22
- # create buttons
21
+ # Create buttons and widgets.
23
22
self .predict_button = QPushButton ("Run Segmentation" )
24
-
25
- # Connect buttons to functions
26
23
self .predict_button .clicked .connect (self .on_predict )
27
- # self.load_model_button.clicked.connect(self.on_load_model)
28
-
29
- # create model selector
30
24
self .model_selector_widget = self .load_model_widget ()
31
-
32
- # create advanced settings
33
25
self .settings = self ._create_settings_widget ()
34
26
35
- # Add the widgets to the layout
27
+ # Add the widgets to the layout.
36
28
layout .addWidget (self .image_selector_widget )
37
29
layout .addWidget (self .model_selector_widget )
38
30
layout .addWidget (self .settings )
@@ -44,8 +36,7 @@ def load_model_widget(self):
44
36
model_widget = QWidget ()
45
37
title_label = QLabel ("Select Model:" )
46
38
47
- models = list (get_model_registry ().urls .keys ())
48
- self .model = None # set default model
39
+ models = ["- choose -" ] + list (get_model_registry ().urls .keys ())
49
40
self .model_selector = QComboBox ()
50
41
self .model_selector .addItems (models )
51
42
# Create a layout and add the title label and combo box
@@ -59,23 +50,20 @@ def load_model_widget(self):
59
50
60
51
def on_predict (self ):
61
52
# Get the model and postprocessing settings.
62
- model_key = self .model_selector .currentText ()
63
- if model_key == "- choose -" :
53
+ model_type = self .model_selector .currentText ()
54
+ if model_type == "- choose -" :
64
55
show_info ("Please choose a model." )
65
56
return
66
57
67
- # loading model
68
- model_registry = get_model_registry ()
69
- model_key = self .model_selector .currentText ()
70
- # model_path = "/home/freckmann15/.cache/synapse-net/models/vesicles" #
71
- model_path = model_registry .fetch (model_key )
72
- # get image data
58
+ # Load the model.
59
+ model = get_model (model_type , self .device )
60
+
61
+ # Get the image data.
73
62
image = self ._get_layer_selector_data (self .image_selector_name )
74
63
if image is None :
75
64
show_info ("Please choose an image." )
76
65
return
77
66
78
- # FIXME: don't hard-code tiling here, but figure it out centrally in the prediction function.
79
67
# get tile shape and halo from the viewer
80
68
tiling = {
81
69
"tile" : {
@@ -89,26 +77,16 @@ def on_predict(self):
89
77
"z" : 1
90
78
}
91
79
}
92
- tile_shape = (self .tile_x_param .value (), self .tile_y_param .value ())
93
- halo = (self .halo_x_param .value (), self .halo_y_param .value ())
94
- use_custom_tiling = False
95
- for ts , h in zip (tile_shape , halo ):
96
- if ts != 0 or h != 0 : # if anything changed from default
97
- use_custom_tiling = True
98
- if use_custom_tiling :
99
- segmentation = run_segmentation (
100
- image , model_path = model_path , model_key = model_key ,
101
- tiling = tiling , scale = self .scale_param .value ()
102
- )
103
- else :
104
- segmentation = run_segmentation (
105
- image , model_path = model_path , model_key = model_key ,
106
- scale = self .scale_param .value ()
107
- )
80
+
81
+ # TODO: Use scale derived from the image resolution.
82
+ scale = [self .scale_param .value ()]
83
+ segmentation = run_segmentation (
84
+ image , model = model , model_type = model_type , tiling = tiling , scale = scale
85
+ )
108
86
109
87
# Add the segmentation layer
110
- self .viewer .add_labels (segmentation , name = f"{ model_key } -segmentation" )
111
- show_info (f"Segmentation of { model_key } added to layers." )
88
+ self .viewer .add_labels (segmentation , name = f"{ model_type } -segmentation" )
89
+ show_info (f"Segmentation of { model_type } added to layers." )
112
90
113
91
def _create_settings_widget (self ):
114
92
setting_values = QWidget ()
@@ -123,6 +101,7 @@ def _create_settings_widget(self):
123
101
setting_values .layout ().addLayout (layout )
124
102
125
103
# Create UI for the tile shape.
104
+ # TODO: make the tiling 3d and get the default values from 'inference'
126
105
self .tile_x , self .tile_y = 512 , 512 # defaults
127
106
self .tile_x_param , self .tile_y_param , layout = self ._add_shape_param (
128
107
("tile_x" , "tile_y" ), (self .tile_x , self .tile_y ), min_val = 0 , max_val = 2048 , step = 16 ,
0 commit comments