Skip to content

Commit 43eff47

Browse files
stmartineau99Sageconstantinpape
authored
boundary mask for unsupervised training (#132)
* boundary mask for unsupervised training * implement lamella mask * boundary mask for unsupervised training * boundary mask for unsupervised training * boundary mask for unsupervised training * Update synapse_net/training/domain_adaptation.py --------- Co-authored-by: Sage <sage@Sages-MacBook-Pro.local> Co-authored-by: Constantin Pape <c.pape@gmx.net>
1 parent bfccbf0 commit 43eff47

File tree

2 files changed

+92
-19
lines changed

2 files changed

+92
-19
lines changed

synapse_net/training/domain_adaptation.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
1919
from ..inference.util import _Scaler
2020

21-
2221
def mean_teacher_adaptation(
2322
name: str,
2423
unsupervised_train_paths: Tuple[str],
@@ -37,9 +36,13 @@ def mean_teacher_adaptation(
3736
n_iterations: int = int(1e4),
3837
n_samples_train: Optional[int] = None,
3938
n_samples_val: Optional[int] = None,
40-
sampler: Optional[callable] = None,
39+
train_mask_paths: Optional[Tuple[str]] = None,
40+
val_mask_paths: Optional[Tuple[str]] = None,
41+
patch_sampler: Optional[callable] = None,
42+
pseudo_label_sampler: Optional[callable] = None,
43+
device: int = 0,
4144
) -> None:
42-
"""Run domain adapation to transfer a network trained on a source domain for a supervised
45+
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
4346
segmentation task to perform this task on a different target domain.
4447
4548
We support different domain adaptation settings:
@@ -82,6 +85,11 @@ def mean_teacher_adaptation(
8285
based on the patch_shape and size of the volumes used for training.
8386
n_samples_val: The number of val samples per epoch. By default this will be estimated
8487
based on the patch_shape and size of the volumes used for validation.
88+
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
89+
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
90+
patch_sampler: Accept or reject patches based on a condition.
91+
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
92+
device: GPU ID for training.
8593
"""
8694
assert (supervised_train_paths is None) == (supervised_val_paths is None)
8795
is_2d, _ = _determine_ndim(patch_shape)
@@ -97,7 +105,7 @@ def mean_teacher_adaptation(
97105
model = get_3d_model(out_channels=2)
98106
reinit_teacher = True
99107
else:
100-
print("Mean teacehr training initialized from source model:", source_checkpoint)
108+
print("Mean teacher training initialized from source model:", source_checkpoint)
101109
if os.path.isdir(source_checkpoint):
102110
model = torch_em.util.load_model(source_checkpoint)
103111
else:
@@ -111,12 +119,24 @@ def mean_teacher_adaptation(
111119
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
112120
loss = self_training.DefaultSelfTrainingLoss()
113121
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
114-
122+
115123
unsupervised_train_loader = get_unsupervised_loader(
116-
unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
124+
data_paths=unsupervised_train_paths,
125+
raw_key=raw_key,
126+
patch_shape=patch_shape,
127+
batch_size=batch_size,
128+
n_samples=n_samples_train,
129+
sample_mask_paths=train_mask_paths,
130+
sampler=patch_sampler
117131
)
118132
unsupervised_val_loader = get_unsupervised_loader(
119-
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
133+
data_paths=unsupervised_val_paths,
134+
raw_key=raw_key,
135+
patch_shape=patch_shape,
136+
batch_size=batch_size,
137+
n_samples=n_samples_val,
138+
sample_mask_paths=val_mask_paths,
139+
sampler=patch_sampler
120140
)
121141

122142
if supervised_train_paths is not None:
@@ -133,7 +153,7 @@ def mean_teacher_adaptation(
133153
supervised_train_loader = None
134154
supervised_val_loader = None
135155

136-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
156+
device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
137157
trainer = self_training.MeanTeacherTrainer(
138158
name=name,
139159
model=model,
@@ -155,11 +175,11 @@ def mean_teacher_adaptation(
155175
device=device,
156176
reinit_teacher=reinit_teacher,
157177
save_root=save_root,
158-
sampler=sampler,
178+
sampler=pseudo_label_sampler,
159179
)
160180
trainer.fit(n_iterations)
161-
162-
181+
182+
163183
# TODO patch shapes for other models
164184
PATCH_SHAPES = {
165185
"vesicles_3d": [48, 256, 256],
@@ -228,7 +248,6 @@ def _parse_patch_shape(patch_shape, model_name):
228248
patch_shape = PATCH_SHAPES[model_name]
229249
return patch_shape
230250

231-
232251
def main():
233252
"""@private
234253
"""
@@ -293,4 +312,4 @@ def main():
293312
n_samples_train=args.n_samples_train,
294313
n_samples_val=args.n_samples_val,
295314
check=args.check,
296-
)
315+
)

synapse_net/training/semisupervised_training.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from typing import Optional, Tuple
22

33
import numpy as np
4+
import uuid
5+
import h5py
46
import torch
57
import torch_em
68
import torch_em.self_training as self_training
79
from torchvision import transforms
810

11+
from synapse_net.file_utils import read_mrc
912
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
1013

1114

@@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
2831
])
2932
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
3033

34+
def drop_mask_channel(x):
    """Return only the first entry of ``x`` along its leading axis.

    The slice ``[:1]`` is used instead of the index ``[0]``, so the leading
    axis is kept (with length one) rather than removed.

    Args:
        x: A sliceable object; presumably a channel-first array whose first
            channel is the raw data and whose second channel is the sample
            mask (verify against the caller that stacks them).

    Returns:
        ``x[:1]``.
    """
    return x[:1]
37+
38+
class ComposedTransform:
    """Chain several single-argument callables into one callable.

    The callables run in the order they were given; each one receives the
    return value of the previous one.
    """

    def __init__(self, *funcs):
        """Store the callables to apply, in application order."""
        self.funcs = funcs

    def __call__(self, x):
        """Apply every stored callable to ``x`` in turn and return the result.

        With no stored callables, ``x`` is returned unchanged.
        """
        result = x
        for transform in self.funcs:
            result = transform(result)
        return result
46+
47+
class ChannelSplitterSampler:
    """Adapt a two-argument ``sampler(raw, mask)`` to a single stacked input.

    The wrapped sampler is called with the first and second entries of the
    input along its leading axis. If no sampler is wrapped (``None``), every
    patch is accepted instead of failing with a ``TypeError`` when called.
    """

    def __init__(self, sampler):
        """Store the wrapped sampler.

        Args:
            sampler: Callable taking ``(raw, mask)`` and returning whether the
                patch is accepted, or ``None`` to accept every patch.
        """
        self.sampler = sampler

    def __call__(self, x):
        """Split ``x`` into its first two entries and delegate to the sampler.

        Args:
            x: Indexable input where ``x[0]`` is passed as the raw data and
                ``x[1]`` as the mask.

        Returns:
            The wrapped sampler's result, or ``True`` if no sampler is set.
        """
        # Without a sampler there is no rejection criterion, so accept the patch.
        if self.sampler is None:
            return True
        raw, mask = x[0], x[1]
        return self.sampler(raw, mask)
3154

3255
def get_unsupervised_loader(
3356
data_paths: Tuple[str],
3457
raw_key: str,
3558
patch_shape: Tuple[int, int, int],
3659
batch_size: int,
3760
n_samples: Optional[int],
38-
exclude_top_and_bottom: bool = False,
61+
sample_mask_paths: Optional[Tuple[str]] = None,
62+
sampler: Optional[callable] = None,
63+
exclude_top_and_bottom: bool = False,
3964
) -> torch.utils.data.DataLoader:
4065
"""Get a dataloader for unsupervised segmentation training.
4166
@@ -50,19 +75,46 @@ def get_unsupervised_loader(
5075
based on the patch_shape and size of the volumes used for training.
5176
exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
5277
avoid artifacts at the border of tomograms.
78+
sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
79+
sampler: Accept or reject patches based on a condition.
5380
5481
Returns:
5582
The PyTorch dataloader.
5683
"""
57-
5884
# We exclude the top and bottom slices where the tomogram reconstruction is bad.
85+
# TODO this seems unnecessary if we have a boundary mask - remove?
5986
if exclude_top_and_bottom:
6087
roi = np.s_[5:-5, :, :]
6188
else:
6289
roi = None
90+
# stack tomograms and masks and write to temp files to use as input to RawDataset()
91+
if sample_mask_paths is not None:
92+
assert len(data_paths) == len(sample_mask_paths), \
93+
f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."
94+
95+
stacked_paths = []
96+
for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
97+
raw = read_mrc(data_path)[0]
98+
mask = read_mrc(mask_path)[0]
99+
stacked = np.stack([raw, mask], axis=0)
100+
101+
tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
102+
with h5py.File(tmp_path, "w") as f:
103+
f.create_dataset("raw", data=stacked, compression="gzip")
104+
stacked_paths.append(tmp_path)
105+
106+
# update variables for RawDataset()
107+
data_paths = tuple(stacked_paths)
108+
base_transform = torch_em.transform.get_raw_transform()
109+
raw_transform = ComposedTransform(base_transform, drop_mask_channel)
110+
sampler = ChannelSplitterSampler(sampler)
111+
with_channels = True
112+
else:
113+
raw_transform = torch_em.transform.get_raw_transform()
114+
with_channels = False
115+
sampler = None
63116

64117
_, ndim = _determine_ndim(patch_shape)
65-
raw_transform = torch_em.transform.get_raw_transform()
66118
transform = torch_em.transform.get_augmentations(ndim=ndim)
67119

68120
if n_samples is None:
@@ -71,15 +123,17 @@ def get_unsupervised_loader(
71123
n_samples_per_ds = int(n_samples / len(data_paths))
72124

73125
augmentations = (weak_augmentations(), weak_augmentations())
126+
74127
datasets = [
75-
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
76-
augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
128+
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
129+
n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
77130
for path in data_paths
78131
]
79132
ds = torch.utils.data.ConcatDataset(datasets)
80133

81134
num_workers = 4 * batch_size
82-
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
135+
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
136+
num_workers=num_workers, shuffle=True)
83137
return loader
84138

85139

0 commit comments

Comments
 (0)