Skip to content

Commit 7af1511

Browse files
committed
added dynamic switch between segmentation functions absed on the model used
1 parent ec1f050 commit 7af1511

File tree

3 files changed

+69
-37
lines changed

3 files changed

+69
-37
lines changed

environment.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ dependencies:
99
- pyqt
1010
- magicgui
1111
- pytorch
12+
- bioimageio.core
13+
- kornia
14+
- tensorboard
1215
- pip:
1316
- napari-skimage-regionprops

synaptic_reconstruction/tools/synaptic_plugin/segmentation.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
1-
from typing import TYPE_CHECKING
2-
import h5py
3-
from magicgui import magic_factory, widgets
41
import napari
52
import napari.layers
63
from napari.utils.notifications import show_info
7-
from napari import Viewer
8-
import numpy as np
9-
from qtpy.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QSpinBox, QLineEdit, QGroupBox, QFormLayout, QFrame, QComboBox
10-
from superqt import QCollapsible
11-
from elf.io import open_file
4+
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
5+
126
from .base_widget import BaseWidget
13-
import os
14-
from synaptic_reconstruction.inference.vesicles import segment_vesicles
7+
from synaptic_reconstruction.training.supervised_training import get_2d_model
158

169
# Custom imports for model and prediction utilities
17-
from ..util import get_device, get_model_registry, run_prediction, _available_devices
10+
from ..util import run_segmentation, get_model_registry, _available_devices
1811

1912
# if TYPE_CHECKING:
2013
# import napari
@@ -28,6 +21,7 @@ def __init__(self):
2821
super().__init__()
2922

3023
self.model = None
24+
self.image = None
3125
self.viewer = napari.current_viewer()
3226
layout = QVBoxLayout()
3327

@@ -137,18 +131,19 @@ def on_predict(self):
137131
if model_key == "- choose -":
138132
show_info("Please choose a model.")
139133
return
134+
# loading model
140135

141-
142136
model_registry = get_model_registry()
143-
model_path = model_registry.fetch(model_key)
137+
model_key = self.model_selector.currentText()
138+
model_path = "/home/freckmann15/.cache/synapse-net/models/vesicles" # model_registry.fetch(model_key)
139+
# model = get_2d_model(out_channels=2)
140+
# model = load_model_weights(model=model, model_path=model_path)
144141

145142
if self.image is None:
146143
show_info("Please choose an image.")
147144
return
148145

149146
# get tile shape and halo from the viewer
150-
tile_shape = (self.tile_x_param.value(), self.tile_y_param.value())
151-
halo = (self.halo_x_param.value(), self.halo_y_param.value())
152147
tiling = {
153148
"tile": {
154149
"x": self.tile_x_param.value(),
@@ -161,13 +156,24 @@ def on_predict(self):
161156
"z": 1
162157
}
163158
}
164-
segmentation = segment_vesicles(self.image, model_path=model_path) #tiling=tiling
159+
tile_shape = (self.tile_x_param.value(), self.tile_y_param.value())
160+
halo = (self.halo_x_param.value(), self.halo_y_param.value())
161+
use_custom_tiling = False
162+
for ts, h in zip(tile_shape, halo):
163+
if ts != 0 or h != 0: # if anything is changed from default
164+
use_custom_tiling = True
165+
if use_custom_tiling:
166+
segmentation = run_segmentation(self.image, model_path=model_path, model_key=model_key, tiling=tiling)
167+
else:
168+
segmentation = run_segmentation(self.image, model_path=model_path, model_key=model_key)
169+
# segmentation = np.random.randint(0, 256, size=self.image.shape, dtype=np.uint8)
170+
self.viewer.add_image(segmentation, name="Segmentation", colormap="inferno", blending="additive")
165171
# Add predictions to Napari as separate layers
166172
# for i, pred in enumerate(segmentation):
167173
# layer_name = f"Prediction {i+1}"
168174
# self.viewer.add_image(pred, name=layer_name, colormap="inferno", blending="additive")
169-
layer_kwargs = {"colormap": "inferno", "blending": "additive"}
170-
return segmentation, layer_kwargs
175+
# layer_kwargs = {"colormap": "inferno", "blending": "additive"}
176+
# return segmentation, layer_kwargs
171177

172178
def _create_settings_widget(self):
173179
setting_values = QWidget()
@@ -183,15 +189,15 @@ def _create_settings_widget(self):
183189
setting_values.layout().addLayout(layout)
184190

185191
# Create UI for the tile shape.
186-
self.tile_x, self.tile_y = 256, 256 # defaults
192+
self.tile_x, self.tile_y = 0, 0 # defaults
187193
self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
188194
("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
189195
# tooltip=get_tooltip("embedding", "tiling")
190196
)
191197
setting_values.layout().addLayout(layout)
192198

193199
# Create UI for the halo.
194-
self.halo_x, self.halo_y = 32, 32 # defaults
200+
self.halo_x, self.halo_y = 0, 0 # defaults
195201
self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
196202
("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
197203
# tooltip=get_tooltip("embedding", "halo")

synaptic_reconstruction/tools/util.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,46 @@
77
import pooch
88
import requests
99
from torch_em.util.prediction import predict_with_halo
10+
from synaptic_reconstruction.inference.vesicles import segment_vesicles
11+
from synaptic_reconstruction.inference.mitochondria import segment_mitochondria
12+
13+
14+
def run_segmentation(image, model_path, model_key, tiling=None):
15+
if model_key == "vesicles":
16+
segmentation = segment_vesicles(image, model_path=model_path, tiling=tiling)
17+
elif model_key == "mitochondria":
18+
segmentation = segment_mitochondria(image, model_path=model_path, tiling=tiling)
19+
return segmentation
20+
21+
22+
def load_model_weights(model, model_path):
23+
# Load the entire checkpoint
24+
checkpoint = torch.load(model_path, map_location=get_device()) # ["model_state"]
25+
checkpoint_state_dict = checkpoint.get("state_dict", checkpoint)
26+
# Filter out any non-model keys (e.g., training state, optimizer state)
27+
model_state_dict = {k: v for k, v in checkpoint_state_dict.items() if k in model.state_dict()}
28+
return model.load_state_dict(model_state_dict, strict=False).eval()
29+
30+
# Extract the 'state_dict' from the checkpoint
31+
state_dict = checkpoint.get('state_dict', checkpoint)
32+
33+
# Filter out unnecessary keys from the state_dict
34+
model_state_dict = model.state_dict()
35+
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
36+
37+
# Check for missing keys (optional)
38+
missing_keys = set(model_state_dict.keys()) - set(filtered_state_dict.keys())
39+
if missing_keys:
40+
print(f"Warning: Missing keys in state_dict: {missing_keys}")
41+
42+
# Check for unexpected keys (optional)
43+
unexpected_keys = set(filtered_state_dict.keys()) - set(model_state_dict.keys())
44+
if unexpected_keys:
45+
print(f"Warning: Unexpected keys in state_dict: {unexpected_keys}")
46+
47+
# Load the filtered state_dict into the model
48+
model.load_state_dict(filtered_state_dict)
49+
return model
1050

1151

1252
def download_and_organize_file(url, mode_type, app_name="synapse-net/models"):
@@ -30,23 +70,6 @@ def download_and_organize_file(url, mode_type, app_name="synapse-net/models"):
3070
os.rename(file_path, os.path.join(dir_name, mode_type, "best.pt"))
3171
return os.path.join(dir_name, mode_type, "best.pt")
3272

33-
34-
# # If file_path is a directory or has no extension, organize it into a folder with "best.pt"
35-
# if os.path.isdir(file_path) or os.path.splitext(file_path)[1] == "":
36-
# # Create a directory for the file if needed
37-
# dir_path = file_path if os.path.isdir(file_path) else os.path.splitext(file_path)[0]
38-
# os.makedirs(dir_path, exist_ok=True)
39-
40-
# # Move the file into the directory with the name "best.pt"
41-
# new_file_path = os.path.join(dir_path, "best.pt")
42-
# if file_path != new_file_path:
43-
# os.rename(file_path, new_file_path)
44-
45-
# return dir_path
46-
# else:
47-
# # Return the directory where the file resides
48-
# return os.path.dirname(file_path)
49-
5073

5174
def organize_file_path(path):
5275
# Check if path is a file or directory

0 commit comments

Comments
 (0)