
background mask option for pseudo labeler - unsupervised training #135


Open
wants to merge 12 commits into main
186 changes: 170 additions & 16 deletions synapse_net/training/domain_adaptation.py
@@ -2,12 +2,14 @@
import tempfile
from glob import glob
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable
import time

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from torch_em.self_training.logger import SelfTrainingTensorboardLogger
from elf.io import open_file
from sklearn.model_selection import train_test_split

@@ -18,6 +20,147 @@
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler

class PseudoLabelerWithBackgroundMask(self_training.DefaultPseudoLabeler):
"""Subclass of DefaultPseudoLabeler, which can subtract background from the pseudo labels if a background mask is provided.
By default, assumes that the first channel contains the transformed raw data and the second channel contains the background mask.

Args:
confidence_mask_channel: A specific channel to use for computing the confidence mask.
By default the confidence mask is computed across all channels independently.
This is useful if only one of the channels encodes a probability.
raw_channel: Channel index of the raw data, which will be used as input to the teacher model.
background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
kwargs: Additional keyword arguments for `self_training.DefaultPseudoLabeler`.
"""
def __init__(
self,
confidence_mask_channel: Optional[int] = None,
raw_channel: Optional[int] = 0,
background_mask_channel: Optional[int] = 1,
**kwargs
):
super().__init__(**kwargs)
self.confidence_mask_channel = confidence_mask_channel
self.raw_channel = raw_channel
self.background_mask_channel = background_mask_channel

def _subtract_background(self, pseudo_labels: torch.Tensor, background_mask: torch.Tensor):
bool_mask = background_mask.bool()
return pseudo_labels.masked_fill(bool_mask, 0)

def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
"""Compute pseudo-labels.

Args:
teacher: The teacher model.
input_: The input for this batch.

Returns:
The pseudo-labels.
"""
if input_.ndim != 5:
raise ValueError(f"Expect data with 5 dimensions (B, C, D, H, W), got shape {input_.shape}.")

has_background_mask = input_.shape[1] > 1

if has_background_mask:
if self.background_mask_channel >= input_.shape[1]:
raise ValueError(f"Channel index {self.background_mask_channel} is out of bounds for shape {input_.shape}.")

background_mask = input_[:, self.background_mask_channel].unsqueeze(1)
input_ = input_[:, self.raw_channel].unsqueeze(1)

pseudo_labels = teacher(input_)

if self.activation is not None:
pseudo_labels = self.activation(pseudo_labels)
if self.confidence_threshold is None:
label_mask = None
else:
mask_input = pseudo_labels if self.confidence_mask_channel is None\
else pseudo_labels[self.confidence_mask_channel:(self.confidence_mask_channel+1)]
label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides\
else self._compute_label_mask_one_side(mask_input)
if self.confidence_mask_channel is not None:
size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
label_mask = label_mask.expand(*size)

if has_background_mask:
pseudo_labels = self._subtract_background(pseudo_labels, background_mask)

return pseudo_labels, label_mask
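
The following is a minimal, hypothetical usage sketch (not part of this diff) of the two-channel convention the pseudo labeler assumes: channel 0 carries the raw data, channel 1 the background mask. The dummy teacher, tensor shapes, and threshold value are placeholders; only the confidence_threshold and background_mask_channel keywords are taken from this PR.

import torch

# Hypothetical usage sketch: channel 0 = raw data, channel 1 = background mask.
labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=0.9, background_mask_channel=1)
teacher = torch.nn.Conv3d(1, 1, kernel_size=1)    # stand-in for the real teacher model
raw = torch.rand(2, 1, 8, 32, 32)                 # (B, C, D, H, W)
background = (raw < 0.1).float()                  # dummy background mask, 1 = background
pseudo_labels, label_mask = labeler(teacher, torch.cat([raw, background], dim=1))
# Voxels flagged as background are zeroed in the pseudo labels.
assert pseudo_labels[background.bool()].abs().sum() == 0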

class MeanTeacherTrainerWithBackgroundMask(self_training.MeanTeacherTrainer):
"""Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.
Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists.
The second channel of the student input is also dropped, if it exists, since it is not needed for training.

Args:
kwargs: Additional keyword arguments for `self_training.MeanTeacherTrainer`.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)

def _train_epoch_unsupervised(self, progress, forward_context, backprop):
self.model.train()

n_iter = 0
t_per_iter = time.time()

# Sample from the unsupervised loader.
for xu1, xu2 in self.unsupervised_train_loader:

# Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input).
Contributor:
I would add an assertion here that we have two channels. Otherwise the normal MeanTeacherTrainer should have been used and this may indicate some error in the code logic. So it's better to fail in that case.
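
A minimal sketch of the check the reviewer suggests (hypothetical, not part of the diff), assuming the two-channel convention used by PseudoLabelerWithBackgroundMask:

# Fail loudly if the input does not carry exactly two channels (raw + background mask);
# with a single channel, the plain self_training.MeanTeacherTrainer should be used instead.
if xu2.shape[1] != 2:
    raise ValueError(
        f"Expected inputs with 2 channels (raw data, background mask), got shape {xu2.shape}. "
        "Use self_training.MeanTeacherTrainer for single-channel inputs."
    )
xu2 = xu2[:, :1].contiguous()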

if xu2.shape[1] > 1:
xu2 = xu2[:, :1].contiguous()

xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)

teacher_input, model_input = xu1, xu2

with forward_context(), torch.no_grad():
# Compute the pseudo labels.
pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)

# Drop the second channel for xu1 (teacher input) after computing the pseudo labels.
if xu1.shape[1] > 1:
xu1 = xu1[:, :1].contiguous()

# If we have a sampler then check if the current batch matches the condition for inclusion in training.
if self.sampler is not None:
keep_batch = self.sampler(pseudo_labels, label_filter)
if not keep_batch:
continue

self.optimizer.zero_grad()
# Perform unsupervised training
with forward_context():
loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
backprop(loss)

if self.logger is not None:
with torch.no_grad(), forward_context():
pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
self.logger.log_train_unsupervised(
self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
)
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
self.logger.log_lr(self._iteration, lr)
if self.pseudo_labeler.confidence_threshold is not None:
self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)

with torch.no_grad():
self._momentum_update()

self._iteration += 1
n_iter += 1
if self._iteration >= self.max_iteration:
break
progress.update(1)

t_per_iter = (time.time() - t_per_iter) / n_iter
return t_per_iter

def mean_teacher_adaptation(
name: str,
unsupervised_train_paths: Tuple[str],
@@ -36,13 +179,14 @@ def mean_teacher_adaptation(
n_iterations: int = int(1e4),
n_samples_train: Optional[int] = None,
n_samples_val: Optional[int] = None,
train_mask_paths: Optional[Tuple[str]] = None,
val_mask_paths: Optional[Tuple[str]] = None,
train_sample_mask_paths: Optional[Tuple[str]] = None,
val_sample_mask_paths: Optional[Tuple[str]] = None,
train_background_mask_paths: Optional[Tuple[str]] = None,
patch_sampler: Optional[callable] = None,
pseudo_label_sampler: Optional[callable] = None,
device: int = 0,
) -> None:
"""Run domain adaptation to transfer a network trained on a source domain for a supervised
"""Run domain adapation to transfer a network trained on a source domain for a supervised
segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:
@@ -85,10 +229,14 @@ def mean_teacher_adaptation(
based on the patch_shape and size of the volumes used for training.
n_samples_val: The number of val samples per epoch. By default this will be estimated
based on the patch_shape and size of the volumes used for validation.
train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
patch_sampler: Accept or reject patches based on a condition.
pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
train_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
patches for training.
val_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
patches for validation.
train_background_mask_paths: Filepaths to the background masks used for training.
Background masks are used to zero out background regions of the pseudo labels before they are used in the unsupervised loss.
patch_sampler: A sampler for rejecting patches based on a defined condition.
pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
device: GPU ID for training.
"""
assert (supervised_train_paths is None) == (supervised_val_paths is None)
@@ -116,17 +264,24 @@
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

# self training functionality
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
if train_background_mask_paths is not None:
pseudo_labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=confidence_threshold, background_mask_channel=1)
trainer_class = MeanTeacherTrainerWithBackgroundMask
else:
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
trainer_class = self_training.MeanTeacherTrainer

loss = self_training.DefaultSelfTrainingLoss()
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

unsupervised_train_loader = get_unsupervised_loader(
data_paths=unsupervised_train_paths,
raw_key=raw_key,
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_train,
sample_mask_paths=train_mask_paths,
sample_mask_paths=train_sample_mask_paths,
background_mask_paths=train_background_mask_paths,
sampler=patch_sampler
)
unsupervised_val_loader = get_unsupervised_loader(
@@ -135,7 +290,8 @@
patch_shape=patch_shape,
batch_size=batch_size,
n_samples=n_samples_val,
sample_mask_paths=val_mask_paths,
sample_mask_paths=val_sample_mask_paths,
background_mask_paths=None,
sampler=patch_sampler
)

@@ -154,7 +310,7 @@
supervised_val_loader = None

device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
trainer = self_training.MeanTeacherTrainer(
trainer = trainer_class(
name=name,
model=model,
optimizer=optimizer,
@@ -178,16 +334,14 @@
sampler=pseudo_label_sampler,
)
trainer.fit(n_iterations)



# TODO patch shapes for other models
PATCH_SHAPES = {
"vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
if len(files) == 0:
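
As a closing illustration, here is a hypothetical call showing how the new train_background_mask_paths option might be used. Only keyword names visible in this diff are guaranteed; file paths, raw_key, and patch_shape values are placeholders, and other required arguments (such as validation paths) are omitted.

# Hypothetical usage sketch of the new background-mask option; paths and values are
# placeholders, and raw_key/patch_shape are assumed parameters of the function.
from synapse_net.training.domain_adaptation import mean_teacher_adaptation

mean_teacher_adaptation(
    name="vesicle-da-with-background-mask",
    unsupervised_train_paths=("/data/target/train/tomogram.h5",),
    raw_key="raw",
    patch_shape=(48, 256, 256),
    n_iterations=int(1e4),
    # New in this PR: background masks are stacked as a second input channel and
    # zero out background regions in the teacher's pseudo labels.
    train_background_mask_paths=("/data/target/train/background_mask.h5",),
    device=0,
)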