Skip to content

boundary mask for unsupervised training #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions synapse_net/training/domain_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler


def mean_teacher_adaptation(
name: str,
unsupervised_train_paths: Tuple[str],
Expand All @@ -37,9 +36,13 @@ def mean_teacher_adaptation(
n_iterations: int = int(1e4),
n_samples_train: Optional[int] = None,
n_samples_val: Optional[int] = None,
sampler: Optional[callable] = None,
train_mask_paths: Optional[Tuple[str]] = None,
val_mask_paths: Optional[Tuple[str]] = None,
patch_sampler: Optional[callable] = None,
pseudo_label_sampler: Optional[callable] = None,
device: int = 0,
) -> None:
"""Run domain adapation to transfer a network trained on a source domain for a supervised
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:
Expand Down Expand Up @@ -82,6 +85,11 @@ def mean_teacher_adaptation(
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
patch_sampler: Accept or reject patches based on a condition.
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
device: GPU ID for training.
"""
assert (supervised_train_paths is None) == (supervised_val_paths is None)
is_2d, _ = _determine_ndim(patch_shape)
Expand All @@ -97,7 +105,7 @@ def mean_teacher_adaptation(
model = get_3d_model(out_channels=2)
reinit_teacher = True
else:
print("Mean teacehr training initialized from source model:", source_checkpoint)
print("Mean teacher training initialized from source model:", source_checkpoint)
if os.path.isdir(source_checkpoint):
model = torch_em.util.load_model(source_checkpoint)
else:
Expand All @@ -111,12 +119,24 @@ def mean_teacher_adaptation(
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
loss = self_training.DefaultSelfTrainingLoss()
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

unsupervised_train_loader = get_unsupervised_loader(
unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
data_paths=unsupervised_train_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_train,
sample_mask_paths=train_mask_paths,
sampler=patch_sampler
)
unsupervised_val_loader = get_unsupervised_loader(
unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
data_paths=unsupervised_val_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_val,
sample_mask_paths=val_mask_paths,
sampler=patch_sampler
)

if supervised_train_paths is not None:
Expand All @@ -133,7 +153,7 @@ def mean_teacher_adaptation(
supervised_train_loader = None
supervised_val_loader = None

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
trainer = self_training.MeanTeacherTrainer(
name=name,
model=model,
Expand All @@ -155,11 +175,11 @@ def mean_teacher_adaptation(
device=device,
reinit_teacher=reinit_teacher,
save_root=save_root,
sampler=sampler,
sampler=pseudo_label_sampler,
)
trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
"vesicles_3d": [48, 256, 256],
Expand Down Expand Up @@ -228,7 +248,6 @@ def _parse_patch_shape(patch_shape, model_name):
patch_shape = PATCH_SHAPES[model_name]
return patch_shape


def main():
"""@private
"""
Expand Down Expand Up @@ -293,4 +312,4 @@ def main():
n_samples_train=args.n_samples_train,
n_samples_val=args.n_samples_val,
check=args.check,
)
)
66 changes: 60 additions & 6 deletions synapse_net/training/semisupervised_training.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Optional, Tuple

import numpy as np
import uuid
import h5py
import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from synapse_net.file_utils import read_mrc
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


Expand All @@ -28,14 +31,36 @@ def weak_augmentations(p: float = 0.75) -> callable:
])
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)

def drop_mask_channel(x):
x = x[:1]
return x

class ComposedTransform:
def __init__(self, *funcs):
self.funcs = funcs

def __call__(self, x):
for f in self.funcs:
x = f(x)
return x

class ChannelSplitterSampler:
def __init__(self, sampler):
self.sampler = sampler

def __call__(self, x):
raw, mask = x[0], x[1]
return self.sampler(raw, mask)

def get_unsupervised_loader(
data_paths: Tuple[str],
raw_key: str,
patch_shape: Tuple[int, int, int],
batch_size: int,
n_samples: Optional[int],
exclude_top_and_bottom: bool = False,
sample_mask_paths: Optional[Tuple[str]] = None,
sampler: Optional[callable] = None,
exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
"""Get a dataloader for unsupervised segmentation training.

Expand All @@ -50,19 +75,46 @@ def get_unsupervised_loader(
based on the patch_shape and size of the volumes used for training.
exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
avoid artifacts at the border of tomograms.
sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
sampler: Accept or reject patches based on a condition.

Returns:
The PyTorch dataloader.
"""

# We exclude the top and bottom slices where the tomogram reconstruction is bad.
# TODO this seems unneccesary if we have a boundary mask - remove?
if exclude_top_and_bottom:
roi = np.s_[5:-5, :, :]
else:
roi = None
# stack tomograms and masks and write to temp files to use as input to RawDataset()
if sample_mask_paths is not None:
assert len(data_paths) == len(sample_mask_paths), \
f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."

stacked_paths = []
for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
raw = read_mrc(data_path)[0]
mask = read_mrc(mask_path)[0]
stacked = np.stack([raw, mask], axis=0)

tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
with h5py.File(tmp_path, "w") as f:
f.create_dataset("raw", data=stacked, compression="gzip")
stacked_paths.append(tmp_path)

# update variables for RawDataset()
data_paths = tuple(stacked_paths)
base_transform = torch_em.transform.get_raw_transform()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be adapted to only act on channel 0 (the actual raw data.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's leave this for the next PR.

raw_transform = ComposedTransform(base_transform, drop_mask_channel)
sampler = ChannelSplitterSampler(sampler)
with_channels = True
else:
raw_transform = torch_em.transform.get_raw_transform()
with_channels = False
sampler = None

_, ndim = _determine_ndim(patch_shape)
raw_transform = torch_em.transform.get_raw_transform()
transform = torch_em.transform.get_augmentations(ndim=ndim)

if n_samples is None:
Expand All @@ -71,15 +123,17 @@ def get_unsupervised_loader(
n_samples_per_ds = int(n_samples / len(data_paths))

augmentations = (weak_augmentations(), weak_augmentations())

datasets = [
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
for path in data_paths
]
ds = torch.utils.data.ConcatDataset(datasets)

num_workers = 4 * batch_size
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
num_workers=num_workers, shuffle=True)
return loader


Expand Down
Loading