1
- from typing import TYPE_CHECKING
2
- import h5py
3
- from magicgui import magic_factory , widgets
4
1
import napari
5
2
import napari .layers
6
3
from napari .utils .notifications import show_info
7
- from napari import Viewer
8
- import numpy as np
9
- from qtpy .QtWidgets import QWidget , QVBoxLayout , QHBoxLayout , QPushButton , QFileDialog , QLabel , QSpinBox , QLineEdit , QGroupBox , QFormLayout , QFrame , QComboBox
10
- from superqt import QCollapsible
11
- from elf .io import open_file
4
+ from qtpy .QtWidgets import QWidget , QVBoxLayout , QPushButton , QLabel , QComboBox
5
+
12
6
from .base_widget import BaseWidget
13
- import os
14
- from synaptic_reconstruction .inference .vesicles import segment_vesicles
7
+ from synaptic_reconstruction .training .supervised_training import get_2d_model
15
8
16
9
# Custom imports for model and prediction utilities
17
- from ..util import get_device , get_model_registry , run_prediction , _available_devices
10
+ from ..util import run_segmentation , get_model_registry , _available_devices
18
11
19
12
# if TYPE_CHECKING:
20
13
# import napari
@@ -28,6 +21,7 @@ def __init__(self):
28
21
super ().__init__ ()
29
22
30
23
self .model = None
24
+ self .image = None
31
25
self .viewer = napari .current_viewer ()
32
26
layout = QVBoxLayout ()
33
27
@@ -137,18 +131,19 @@ def on_predict(self):
137
131
if model_key == "- choose -" :
138
132
show_info ("Please choose a model." )
139
133
return
134
+ # loading model
140
135
141
-
142
136
model_registry = get_model_registry ()
143
- model_path = model_registry .fetch (model_key )
137
+ model_key = self .model_selector .currentText ()
138
+ model_path = "/home/freckmann15/.cache/synapse-net/models/vesicles" # model_registry.fetch(model_key)
139
+ # model = get_2d_model(out_channels=2)
140
+ # model = load_model_weights(model=model, model_path=model_path)
144
141
145
142
if self .image is None :
146
143
show_info ("Please choose an image." )
147
144
return
148
145
149
146
# get tile shape and halo from the viewer
150
- tile_shape = (self .tile_x_param .value (), self .tile_y_param .value ())
151
- halo = (self .halo_x_param .value (), self .halo_y_param .value ())
152
147
tiling = {
153
148
"tile" : {
154
149
"x" : self .tile_x_param .value (),
@@ -161,13 +156,24 @@ def on_predict(self):
161
156
"z" : 1
162
157
}
163
158
}
164
- segmentation = segment_vesicles (self .image , model_path = model_path ) #tiling=tiling
159
+ tile_shape = (self .tile_x_param .value (), self .tile_y_param .value ())
160
+ halo = (self .halo_x_param .value (), self .halo_y_param .value ())
161
+ use_custom_tiling = False
162
+ for ts , h in zip (tile_shape , halo ):
163
+ if ts != 0 or h != 0 : # if anything is changed from default
164
+ use_custom_tiling = True
165
+ if use_custom_tiling :
166
+ segmentation = run_segmentation (self .image , model_path = model_path , model_key = model_key , tiling = tiling )
167
+ else :
168
+ segmentation = run_segmentation (self .image , model_path = model_path , model_key = model_key )
169
+ # segmentation = np.random.randint(0, 256, size=self.image.shape, dtype=np.uint8)
170
+ self .viewer .add_image (segmentation , name = "Segmentation" , colormap = "inferno" , blending = "additive" )
165
171
# Add predictions to Napari as separate layers
166
172
# for i, pred in enumerate(segmentation):
167
173
# layer_name = f"Prediction {i+1}"
168
174
# self.viewer.add_image(pred, name=layer_name, colormap="inferno", blending="additive")
169
- layer_kwargs = {"colormap" : "inferno" , "blending" : "additive" }
170
- return segmentation , layer_kwargs
175
+ # layer_kwargs = {"colormap": "inferno", "blending": "additive"}
176
+ # return segmentation, layer_kwargs
171
177
172
178
def _create_settings_widget (self ):
173
179
setting_values = QWidget ()
@@ -183,15 +189,15 @@ def _create_settings_widget(self):
183
189
setting_values .layout ().addLayout (layout )
184
190
185
191
# Create UI for the tile shape.
186
- self .tile_x , self .tile_y = 256 , 256 # defaults
192
+ self .tile_x , self .tile_y = 0 , 0 # defaults
187
193
self .tile_x_param , self .tile_y_param , layout = self ._add_shape_param (
188
194
("tile_x" , "tile_y" ), (self .tile_x , self .tile_y ), min_val = 0 , max_val = 2048 , step = 16 ,
189
195
# tooltip=get_tooltip("embedding", "tiling")
190
196
)
191
197
setting_values .layout ().addLayout (layout )
192
198
193
199
# Create UI for the halo.
194
- self .halo_x , self .halo_y = 32 , 32 # defaults
200
+ self .halo_x , self .halo_y = 0 , 0 # defaults
195
201
self .halo_x_param , self .halo_y_param , layout = self ._add_shape_param (
196
202
("halo_x" , "halo_y" ), (self .halo_x , self .halo_y ), min_val = 0 , max_val = 512 ,
197
203
# tooltip=get_tooltip("embedding", "halo")
0 commit comments