Skip to content

Commit a974e65

Browse files
Add examples for network training (#61)
Add examples for network training
1 parent c396113 commit a974e65

File tree

7 files changed

+199
-36
lines changed

7 files changed

+199
-36
lines changed

examples/domain_adaptation.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""This script contains an example for using domain adptation to
2+
transfer a trained model for vesicle segmentation to a new dataset from a different data distribution,
3+
e.g. data from regular transmission electron microscopy (2D) instead of electron tomography or data from
4+
a different electron tomogram with different specimen and sample preparation.
5+
You don't need any annotations in the new domain to run this script.
6+
7+
You can download example data for this script from:
8+
- Adaptation to 2d TEM data: TODO zenodo link
9+
- Adaptation to different tomography data: TODO zenodo link
10+
"""
11+
12+
import os
13+
from glob import glob
14+
15+
from sklearn.model_selection import train_test_split
16+
from synaptic_reconstruction.training import mean_teacher_adaptation
17+
from synaptic_reconstruction.tools.util import get_model_path
18+
19+
20+
def main():
21+
# Choose whether to adapt the model to 2D or to 3D data.
22+
train_2d_model = True
23+
24+
# TODO adjust to zenodo downloads
25+
# These are the data folders for the example data downloaded from zenodo.
26+
# Update these paths to apply the script to your own data.
27+
# Check out the example data to see the data format for training.
28+
data_root_folder_2d = "./data/2d_tem/train_unlabeled"
29+
data_root_folder_3d = "./data/..."
30+
31+
# Choose the correct data folder depending on 2d/3d training.
32+
data_root_folder = data_root_folder_2d if train_2d_model else data_root_folder_3d
33+
34+
# Get all files with ending .h5 in the training folder.
35+
files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True))
36+
37+
# Crate a train / val split.
38+
train_ratio = 0.85
39+
train_paths, val_paths = train_test_split(files, test_size=1 - train_ratio, shuffle=True, random_state=42)
40+
41+
# Choose settings for the 2d or 3d domain adaptation.
42+
if train_2d_model:
43+
# This is the name of the checkpoint of the adapted model.
44+
# For the name here the checkpoint will be stored in './checkpoints/example-2d-adapted-model'
45+
model_name = "example-2d-adapted-model"
46+
# The training patch size.
47+
patch_shape = (256, 256)
48+
# The batch size for training. You can increase this if you have enough VRAM.
49+
batch_size = 4
50+
# Get the checkpoint of the pretrained model for 2d vesicle segmentation.
51+
source_checkpoint = get_model_path(model_type="vesicles_2d")
52+
else:
53+
# This is the name of the checkpoint of the adapted model.
54+
# For the name here the checkpoint will be stored in './checkpoints/example-3d-adapted-model'
55+
model_name = "example-3d-adapted-model"
56+
# The training patch size.
57+
patch_shape = (48, 256, 256)
58+
# The batch size for training. You can increase this if you have enough VRAM.
59+
batch_size = 1
60+
# Get the checkpoint of the pretrained model for d vesicle segmentation.
61+
source_checkpoint = get_model_path(model_type="vesicles_3d")
62+
63+
# We set the number of training iterations to 25,000.
64+
n_iterations = int(2.5e4)
65+
66+
# This function runs the domain adaptation. Check out its documentation for
67+
# advanced settings to update the training procedure.
68+
mean_teacher_adaptation(
69+
name=model_name,
70+
unsupervised_train_paths=train_paths,
71+
unsupervised_val_paths=val_paths,
72+
source_checkpoint=source_checkpoint,
73+
patch_shape=patch_shape,
74+
batch_size=batch_size,
75+
n_iterations=n_iterations,
76+
confidence_threshold=0.75,
77+
)
78+
79+
80+
if __name__ == "__main__":
81+
main()

examples/network_training.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""This script contains an example for how to train a network for
2+
a segmentation task with SynapseNet. This script covers the case of
3+
supervised training, i.e. your data needs to contain annotations for
4+
the structures you want to segment. If you want to use domain adaptation
5+
to adapt an already trained network to your data without the need for
6+
additional annotations then check out `domain_adaptation.py`.
7+
8+
You can download example data for this script from:
9+
TODO zenodo link to Single-Ax / Chemical Fix data.
10+
"""
11+
import os
12+
from glob import glob
13+
14+
from sklearn.model_selection import train_test_split
15+
from synaptic_reconstruction.training import supervised_training
16+
17+
18+
def main():
19+
# This is the folder that contains your training data.
20+
# The example was designed so that it runs for the sample data downloaded to './data'.
21+
# If you want to train on your own data than change this filepath accordingly.
22+
# TODO update to match zenodo download
23+
data_root_folder = "./data/vesicles/train"
24+
25+
# The training data should be saved as .h5 files, with:
26+
# an internal dataset called 'raw' that contains the image data
27+
# and another dataset that contains the training annotations.
28+
label_key = "labels/vesicles"
29+
30+
# Get all files with the ending .h5 in the training folder.
31+
files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True))
32+
33+
# Crate a train / val split.
34+
train_ratio = 0.85
35+
train_paths, val_paths = train_test_split(files, test_size=1 - train_ratio, shuffle=True, random_state=42)
36+
37+
# We can either train a 2d or a 3d model. Whether a 2d or a 3d model is trained is derived from the patch shape.
38+
# If your training data for 2d is stored as images (i.e. 2d data) them choose a patch shape of form Y x X,
39+
# e.g. (384, 384). If your data is stored in 3d, but you want to train a 2d model on it, choose a patch shape
40+
# of the form 1 x Y x X, e.g. (1, 384, 384).
41+
# If you want to train a 3d model then choose a patch shape of form Z x Y x X, e.g. (48, 256, 256).
42+
train_2d_model = True
43+
if train_2d_model:
44+
batch_size = 2 # You can increase the batch size if you have enough VRAM.
45+
# The model name determines the name of the checkpoint. E.g., for the name here the checkpoint will
46+
# be saved at: 'checkpoints/example-2d-vesicle-model/'.
47+
model_name = "example-2d-vesicle-model"
48+
# The patch shape for training. See futher explanations above.
49+
patch_shape = (1, 384, 384)
50+
else:
51+
batch_size = 1 # You can increase the batch size if you have enough VRAM.
52+
# See the explanations for model_name and patch_shape above.
53+
model_name = "example-3d-vesicle-model"
54+
patch_shape = (48, 256, 256)
55+
56+
# If check_loader is set to True the training samples will be visualized via napari
57+
# instead of starting a training. This is useful to validate that the training data
58+
# is read correctly.
59+
check_loader = False
60+
61+
# This function runs the training. Check out its documentation for
62+
# advanced settings to update the training procedure.
63+
supervised_training(
64+
name=model_name,
65+
train_paths=train_paths,
66+
val_paths=val_paths,
67+
label_key=label_key,
68+
patch_shape=patch_shape,
69+
batch_size=batch_size,
70+
n_samples_train=None,
71+
n_samples_val=25,
72+
check=check_loader,
73+
)
74+
75+
76+
if __name__ == "__main__":
77+
main()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
license="MIT",
1414
entry_points={
1515
"console_scripts": [
16-
"synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli"
16+
"synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli",
1717
],
1818
"napari.manifest": [
1919
"synaptic_reconstruction = synaptic_reconstruction:napari.yaml",

synaptic_reconstruction/tools/util.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
from ..inference.mitochondria import segment_mitochondria
1111

1212

13+
def get_model_path(model_type: str) -> str:
14+
"""Get the local path to a given model.
15+
16+
Args:
17+
The model type.
18+
19+
Returns:
20+
The local path to the model.
21+
"""
22+
model_registry = get_model_registry()
23+
model_path = model_registry.fetch(model_type)
24+
return model_path
25+
26+
1327
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
1428
"""Get the model for the given segmentation type.
1529
@@ -22,8 +36,7 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
2236
The model.
2337
"""
2438
device = get_device(device)
25-
model_registry = get_model_registry()
26-
model_path = model_registry.fetch(model_type)
39+
model_path = get_model_path(model_type)
2740
warnings.filterwarnings(
2841
"ignore",
2942
message="You are using `torch.load` with `weights_only=False`",

synaptic_reconstruction/training/domain_adaptation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1+
import os
12
from typing import Optional, Tuple
23

34
import torch
45
import torch_em
56
import torch_em.self_training as self_training
67

78
from .semisupervised_training import get_unsupervised_loader
8-
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader
9+
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, determine_ndim
910

1011

1112
def mean_teacher_adaptation(
1213
name: str,
1314
unsupervised_train_paths: Tuple[str],
1415
unsupervised_val_paths: Tuple[str],
1516
patch_shape: Tuple[int, int, int],
16-
save_root: str,
17+
save_root: Optional[str] = None,
1718
source_checkpoint: Optional[str] = None,
1819
supervised_train_paths: Optional[Tuple[str]] = None,
1920
supervised_val_paths: Optional[Tuple[str]] = None,
@@ -70,22 +71,24 @@ def mean_teacher_adaptation(
7071
based on the patch_shape and size of the volumes used for validation.
7172
"""
7273
assert (supervised_train_paths is None) == (supervised_val_paths is None)
74+
is_2d, _ = determine_ndim(patch_shape)
7375

7476
if source_checkpoint is None:
7577
# training from scratch only makes sense if we have supervised training data
7678
# that's why we have the assertion here.
7779
assert supervised_train_paths is not None
7880
print("Mean teacher training from scratch (AdaMT)")
79-
# TODO determine 2d vs 3d
80-
is_2d = False
8181
if is_2d:
8282
model = get_2d_model(out_channels=2)
8383
else:
8484
model = get_3d_model(out_channels=2)
8585
reinit_teacher = True
8686
else:
8787
print("Mean teacehr training initialized from source model:", source_checkpoint)
88-
model = torch_em.util.load_model(source_checkpoint)
88+
if os.path.isdir(source_checkpoint):
89+
model = torch_em.util.load_model(source_checkpoint)
90+
else:
91+
model = torch.load(source_checkpoint)
8992
reinit_teacher = False
9093

9194
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

synaptic_reconstruction/training/semisupervised_training.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch_em.self_training as self_training
77
from torchvision import transforms
88

9-
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader
9+
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, determine_ndim
1010

1111

1212
def weak_augmentations(p: float = 0.75) -> callable:
@@ -61,14 +61,7 @@ def get_unsupervised_loader(
6161
else:
6262
roi = None
6363

64-
if len(patch_shape) == 2:
65-
ndim = 2
66-
else:
67-
assert len(patch_shape) == 3
68-
z, y, x = patch_shape
69-
ndim = 2 if z == 1 else 3
70-
print("ndim is: ", ndim)
71-
64+
_, ndim = determine_ndim(patch_shape)
7265
raw_transform = torch_em.transform.get_raw_transform()
7366
transform = torch_em.transform.get_augmentations(ndim=ndim)
7467

synaptic_reconstruction/training/supervised_training.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ def adjust_patch_shape(data_shape, patch_shape):
6767
return patch_shape # Return the original patch_shape for 3D data
6868

6969

70+
def determine_ndim(patch_shape):
71+
# Check for 2D or 3D training
72+
try:
73+
z, y, x = patch_shape
74+
except ValueError:
75+
y, x = patch_shape
76+
z = 1
77+
is_2d = z == 1
78+
ndim = 2 if is_2d else 3
79+
return is_2d, ndim
80+
81+
7082
def get_supervised_loader(
7183
data_paths: Tuple[str],
7284
raw_key: str,
@@ -108,16 +120,7 @@ def get_supervised_loader(
108120
Returns:
109121
The PyTorch dataloader.
110122
"""
111-
112-
# Check for 2D or 3D training
113-
try:
114-
z, y, x = patch_shape
115-
ndim = 2 if z == 1 else 3
116-
except ValueError:
117-
y, x = patch_shape
118-
ndim = 2
119-
print("ndim is: ", ndim)
120-
123+
_, ndim = determine_ndim(patch_shape)
121124
if label_transform is not None: # A specific label transform was passed, do nothing.
122125
pass
123126
elif add_boundary_transform:
@@ -166,7 +169,7 @@ def supervised_training(
166169
val_paths: Tuple[str],
167170
label_key: str,
168171
patch_shape: Tuple[int, int, int],
169-
save_root: str,
172+
save_root: Optional[str] = None,
170173
raw_key: str = "raw",
171174
batch_size: int = 1,
172175
lr: float = 1e-4,
@@ -236,14 +239,7 @@ def supervised_training(
236239
check_loader(val_loader, n_samples=4)
237240
return
238241

239-
# Check for 2D or 3D training
240-
try:
241-
z, y, x = patch_shape
242-
except ValueError:
243-
y, x = patch_shape
244-
z = 1
245-
is_2d = z == 1
246-
242+
is_2d, _ = determine_ndim(patch_shape)
247243
if is_2d:
248244
model = get_2d_model(out_channels=out_channels)
249245
else:

0 commit comments

Comments
 (0)