Skip to content

Commit e0dfda6

Browse files
Merge branch 'main' into more-inner-ear-analysis
2 parents 93a66c1 + e95dd45 commit e0dfda6

12 files changed

+472
-97
lines changed

examples/domain_adaptation.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""This script contains an example for using domain adptation to
2+
transfer a trained model for vesicle segmentation to a new dataset from a different data distribution,
3+
e.g. data from regular transmission electron microscopy (2D) instead of electron tomography or data from
4+
a different electron tomogram with different specimen and sample preparation.
5+
You don't need any annotations in the new domain to run this script.
6+
7+
You can download example data for this script from:
8+
- Adaptation to 2d TEM data: TODO zenodo link
9+
- Adaptation to different tomography data: TODO zenodo link
10+
"""
11+
12+
import os
13+
from glob import glob
14+
15+
from sklearn.model_selection import train_test_split
16+
from synaptic_reconstruction.training import mean_teacher_adaptation
17+
from synaptic_reconstruction.tools.util import get_model_path
18+
19+
20+
def main():
21+
# Choose whether to adapt the model to 2D or to 3D data.
22+
train_2d_model = True
23+
24+
# TODO adjust to zenodo downloads
25+
# These are the data folders for the example data downloaded from zenodo.
26+
# Update these paths to apply the script to your own data.
27+
# Check out the example data to see the data format for training.
28+
data_root_folder_2d = "./data/2d_tem/train_unlabeled"
29+
data_root_folder_3d = "./data/..."
30+
31+
# Choose the correct data folder depending on 2d/3d training.
32+
data_root_folder = data_root_folder_2d if train_2d_model else data_root_folder_3d
33+
34+
# Get all files with ending .h5 in the training folder.
35+
files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True))
36+
37+
# Crate a train / val split.
38+
train_ratio = 0.85
39+
train_paths, val_paths = train_test_split(files, test_size=1 - train_ratio, shuffle=True, random_state=42)
40+
41+
# Choose settings for the 2d or 3d domain adaptation.
42+
if train_2d_model:
43+
# This is the name of the checkpoint of the adapted model.
44+
# For the name here the checkpoint will be stored in './checkpoints/example-2d-adapted-model'
45+
model_name = "example-2d-adapted-model"
46+
# The training patch size.
47+
patch_shape = (256, 256)
48+
# The batch size for training. You can increase this if you have enough VRAM.
49+
batch_size = 4
50+
# Get the checkpoint of the pretrained model for 2d vesicle segmentation.
51+
source_checkpoint = get_model_path(model_type="vesicles_2d")
52+
else:
53+
# This is the name of the checkpoint of the adapted model.
54+
# For the name here the checkpoint will be stored in './checkpoints/example-3d-adapted-model'
55+
model_name = "example-3d-adapted-model"
56+
# The training patch size.
57+
patch_shape = (48, 256, 256)
58+
# The batch size for training. You can increase this if you have enough VRAM.
59+
batch_size = 1
60+
# Get the checkpoint of the pretrained model for d vesicle segmentation.
61+
source_checkpoint = get_model_path(model_type="vesicles_3d")
62+
63+
# We set the number of training iterations to 25,000.
64+
n_iterations = int(2.5e4)
65+
66+
# This function runs the domain adaptation. Check out its documentation for
67+
# advanced settings to update the training procedure.
68+
mean_teacher_adaptation(
69+
name=model_name,
70+
unsupervised_train_paths=train_paths,
71+
unsupervised_val_paths=val_paths,
72+
source_checkpoint=source_checkpoint,
73+
patch_shape=patch_shape,
74+
batch_size=batch_size,
75+
n_iterations=n_iterations,
76+
confidence_threshold=0.75,
77+
)
78+
79+
80+
if __name__ == "__main__":
81+
main()

examples/network_training.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""This script contains an example for how to train a network for
2+
a segmentation task with SynapseNet. This script covers the case of
3+
supervised training, i.e. your data needs to contain annotations for
4+
the structures you want to segment. If you want to use domain adaptation
5+
to adapt an already trained network to your data without the need for
6+
additional annotations then check out `domain_adaptation.py`.
7+
8+
You can download example data for this script from:
9+
TODO zenodo link to Single-Ax / Chemical Fix data.
10+
"""
11+
import os
12+
from glob import glob
13+
14+
from sklearn.model_selection import train_test_split
15+
from synaptic_reconstruction.training import supervised_training
16+
17+
18+
def main():
19+
# This is the folder that contains your training data.
20+
# The example was designed so that it runs for the sample data downloaded to './data'.
21+
# If you want to train on your own data than change this filepath accordingly.
22+
# TODO update to match zenodo download
23+
data_root_folder = "./data/vesicles/train"
24+
25+
# The training data should be saved as .h5 files, with:
26+
# an internal dataset called 'raw' that contains the image data
27+
# and another dataset that contains the training annotations.
28+
label_key = "labels/vesicles"
29+
30+
# Get all files with the ending .h5 in the training folder.
31+
files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True))
32+
33+
# Crate a train / val split.
34+
train_ratio = 0.85
35+
train_paths, val_paths = train_test_split(files, test_size=1 - train_ratio, shuffle=True, random_state=42)
36+
37+
# We can either train a 2d or a 3d model. Whether a 2d or a 3d model is trained is derived from the patch shape.
38+
# If your training data for 2d is stored as images (i.e. 2d data) them choose a patch shape of form Y x X,
39+
# e.g. (384, 384). If your data is stored in 3d, but you want to train a 2d model on it, choose a patch shape
40+
# of the form 1 x Y x X, e.g. (1, 384, 384).
41+
# If you want to train a 3d model then choose a patch shape of form Z x Y x X, e.g. (48, 256, 256).
42+
train_2d_model = True
43+
if train_2d_model:
44+
batch_size = 2 # You can increase the batch size if you have enough VRAM.
45+
# The model name determines the name of the checkpoint. E.g., for the name here the checkpoint will
46+
# be saved at: 'checkpoints/example-2d-vesicle-model/'.
47+
model_name = "example-2d-vesicle-model"
48+
# The patch shape for training. See futher explanations above.
49+
patch_shape = (1, 384, 384)
50+
else:
51+
batch_size = 1 # You can increase the batch size if you have enough VRAM.
52+
# See the explanations for model_name and patch_shape above.
53+
model_name = "example-3d-vesicle-model"
54+
patch_shape = (48, 256, 256)
55+
56+
# If check_loader is set to True the training samples will be visualized via napari
57+
# instead of starting a training. This is useful to validate that the training data
58+
# is read correctly.
59+
check_loader = False
60+
61+
# This function runs the training. Check out its documentation for
62+
# advanced settings to update the training procedure.
63+
supervised_training(
64+
name=model_name,
65+
train_paths=train_paths,
66+
val_paths=val_paths,
67+
label_key=label_key,
68+
patch_shape=patch_shape,
69+
batch_size=batch_size,
70+
n_samples_train=None,
71+
n_samples_val=25,
72+
check=check_loader,
73+
)
74+
75+
76+
if __name__ == "__main__":
77+
main()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
license="MIT",
1414
entry_points={
1515
"console_scripts": [
16-
"synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli"
16+
"synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli",
1717
],
1818
"napari.manifest": [
1919
"synaptic_reconstruction = synaptic_reconstruction:napari.yaml",

synaptic_reconstruction/tools/base_widget.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _update_selector(self, selector, layer_filter):
6161
image_layers = [layer.name for layer in self.viewer.layers if isinstance(layer, layer_filter)] # if isinstance(layer, napari.layers.Image)
6262
selector.addItems(image_layers)
6363

64-
def _get_layer_selector_data(self, selector_name):
64+
def _get_layer_selector_data(self, selector_name, return_metadata=False):
6565
"""Return the data for the layer currently selected in a given selector."""
6666
if selector_name in self.layer_selectors:
6767
selector_widget = self.layer_selectors[selector_name]
@@ -72,7 +72,10 @@ def _get_layer_selector_data(self, selector_name):
7272
if isinstance(image_selector, QComboBox):
7373
selected_layer_name = image_selector.currentText()
7474
if selected_layer_name in self.viewer.layers:
75-
return self.viewer.layers[selected_layer_name].data
75+
if return_metadata:
76+
return self.viewer.layers[selected_layer_name].metadata
77+
else:
78+
return self.viewer.layers[selected_layer_name].data
7679
return None # Return None if layer not found
7780

7881
def _add_string_param(self, name, value, title=None, placeholder=None, layout=None, tooltip=None):
@@ -169,6 +172,15 @@ def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None,
169172
title=title[1] if title is not None else title, tooltip=tooltip
170173
)
171174
layout.addLayout(y_layout)
175+
176+
if len(names) == 3:
177+
z_layout = QVBoxLayout()
178+
z_param, _ = self._add_int_param(
179+
names[2], values[2], min_val=min_val, max_val=max_val, layout=z_layout, step=step,
180+
title=title[2] if title is not None else title, tooltip=tooltip
181+
)
182+
layout.addLayout(z_layout)
183+
return x_param, y_param, z_param, layout
172184

173185
return x_param, y_param, layout
174186

synaptic_reconstruction/tools/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
# TODO: handle kwargs
9+
# TODO: add custom model path
910
def segmentation_cli():
1011
parser = argparse.ArgumentParser(description="Run segmentation.")
1112
parser.add_argument(
@@ -16,6 +17,7 @@ def segmentation_cli():
1617
"--output_path", "-o", required=True,
1718
help="The filepath to directory where the segmentations will be saved."
1819
)
20+
# TODO: list the availabel models here by parsing the keys of the model registry
1921
parser.add_argument(
2022
"--model", "-m", required=True, help="The model type."
2123
)

synaptic_reconstruction/tools/distance_measure_widget.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,33 +101,53 @@ def _add_lines_and_table(self, lines, properties, table_data, name):
101101
def on_measure_seg_to_object(self):
102102
segmentation = self._get_layer_selector_data(self.image_selector_name1)
103103
object_data = self._get_layer_selector_data(self.image_selector_name2)
104+
# get metadata from layer if available
105+
metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
106+
resolution = metadata.get("voxel_size", None)
107+
if resolution is not None:
108+
resolution = [v for v in resolution.values()]
109+
# if user input is present override metadata
110+
if self.voxel_size_param.value() != 0.0: # changed from default
111+
resolution = segmentation.ndim * [self.voxel_size_param.value()]
104112

105113
(distances,
106114
endpoints1,
107115
endpoints2,
108116
seg_ids) = distance_measurements.measure_segmentation_to_object_distances(
109117
segmentation=segmentation, segmented_object=object_data, distance_type="boundary",
118+
resolution=resolution
110119
)
111120
lines, properties = distance_measurements.create_object_distance_lines(
112121
distances=distances,
113122
endpoints1=endpoints1,
114123
endpoints2=endpoints2,
115-
seg_ids=seg_ids
124+
seg_ids=seg_ids,
116125
)
117126
table_data = self._to_table_data(distances, seg_ids, endpoints1, endpoints2)
118127
self._add_lines_and_table(lines, properties, table_data, name="distances")
119128

120129
def on_measure_pairwise(self):
121130
segmentation = self._get_layer_selector_data(self.image_selector_name1)
131+
if segmentation is None:
132+
show_info("Please choose a segmentation.")
133+
return
134+
# get metadata from layer if available
135+
metadata = self._get_layer_selector_data(self.image_selector_name1, return_metadata=True)
136+
resolution = metadata.get("voxel_size", None)
137+
if resolution is not None:
138+
resolution = [v for v in resolution.values()]
139+
# if user input is present override metadata
140+
if self.voxel_size_param.value() != 0.0: # changed from default
141+
resolution = segmentation.ndim * [self.voxel_size_param.value()]
122142

123143
(distances,
124144
endpoints1,
125145
endpoints2,
126146
seg_ids) = distance_measurements.measure_pairwise_object_distances(
127-
segmentation=segmentation, distance_type="boundary"
147+
segmentation=segmentation, distance_type="boundary", resolution=resolution
128148
)
129149
lines, properties = distance_measurements.create_pairwise_distance_lines(
130-
distances=distances, endpoints1=endpoints1, endpoints2=endpoints2, seg_ids=seg_ids.tolist()
150+
distances=distances, endpoints1=endpoints1, endpoints2=endpoints2, seg_ids=seg_ids.tolist(),
131151
)
132152
table_data = self._to_table_data(
133153
distances=properties["distance"],
@@ -142,5 +162,10 @@ def _create_settings_widget(self):
142162
self.save_path, layout = self._add_path_param(name="Save Table", select_type="file", value="")
143163
setting_values.layout().addLayout(layout)
144164

165+
self.voxel_size_param, layout = self._add_float_param(
166+
"voxel_size", 0.0, min_val=0.0, max_val=100.0,
167+
)
168+
setting_values.layout().addLayout(layout)
169+
145170
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
146171
return settings

0 commit comments

Comments
 (0)