background mask option for pseudo labeler - unsupervised training #135
Open

stmartineau99 wants to merge 12 commits into computational-cell-analytics:main from stmartineau99:main (base: main)

Changes from all commits

Commits (12):
4c17c0d  boundary mask for unsupervised training
6e865c3  implement lamella mask (stmartineau99)
2df5a6c  boundary mask for unsupervised training (stmartineau99)
0257cdb  boundary mask for unsupervised training (stmartineau99)
3e454d7  boundary mask for unsupervised training (stmartineau99)
702138f  Update synapse_net/training/domain_adaptation.py (constantinpape)
712714f  optional background mask for unsupervised training (stmartineau99)
292e450  fix domain_adaptation.py (stmartineau99)
e6f86fc  background mask for unsupervised training (stmartineau99)
aa8cd15  background mask for unsupervised training (stmartineau99)
7ebd59d  create subclass NewMeanTeacherTrainer (stmartineau99)
068a90b  PR #2 cosmetic changes (stmartineau99)
File: synapse_net/training/domain_adaptation.py
@@ -2,12 +2,14 @@
 import tempfile
 from glob import glob
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Callable
+import time

 import mrcfile
 import torch
 import torch_em
 import torch_em.self_training as self_training
+from torch_em.self_training.logger import SelfTrainingTensorboardLogger
 from elf.io import open_file
 from sklearn.model_selection import train_test_split

@@ -18,6 +20,147 @@
 from ..inference.inference import get_model_path, compute_scale_from_voxel_size
 from ..inference.util import _Scaler

+class PseudoLabelerWithBackgroundMask(self_training.DefaultPseudoLabeler):
+    """Subclass of DefaultPseudoLabeler, which can subtract background from the pseudo labels
+    if a background mask is provided. By default, it assumes that the first channel contains the
+    transformed raw data and the second channel contains the background mask.
+
+    Args:
+        confidence_mask_channel: A specific channel to use for computing the confidence mask.
+            By default the confidence mask is computed across all channels independently.
+            This is useful if only one of the channels encodes a probability.
+        raw_channel: Channel index of the raw data, which will be used as input to the teacher model.
+        background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
+        kwargs: Additional keyword arguments for `self_training.DefaultPseudoLabeler`.
+    """
+    def __init__(
+        self,
+        confidence_mask_channel: Optional[int] = None,
+        raw_channel: Optional[int] = 0,
+        background_mask_channel: Optional[int] = 1,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.confidence_mask_channel = confidence_mask_channel
+        self.raw_channel = raw_channel
+        self.background_mask_channel = background_mask_channel
+
+    def _subtract_background(self, pseudo_labels: torch.Tensor, background_mask: torch.Tensor) -> torch.Tensor:
+        bool_mask = background_mask.bool()
+        return pseudo_labels.masked_fill(bool_mask, 0)
+
+    def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Compute pseudo-labels.
+
+        Args:
+            teacher: The teacher model.
+            input_: The input for this batch.
+
+        Returns:
+            The pseudo-labels and the label mask (None if no confidence threshold is set).
+        """
+        if input_.ndim != 5:
+            raise ValueError(f"Expect data with 5 dimensions (B, C, D, H, W), got shape {input_.shape}.")
+
+        has_background_mask = input_.shape[1] > 1
+
+        if has_background_mask:
+            if self.background_mask_channel >= input_.shape[1]:
+                raise ValueError(f"Channel index {self.background_mask_channel} is out of bounds for shape {input_.shape}.")
+
+            background_mask = input_[:, self.background_mask_channel].unsqueeze(1)
+            input_ = input_[:, self.raw_channel].unsqueeze(1)
+
+        pseudo_labels = teacher(input_)
+
+        if self.activation is not None:
+            pseudo_labels = self.activation(pseudo_labels)
+        if self.confidence_threshold is None:
+            label_mask = None
+        else:
+            mask_input = pseudo_labels if self.confidence_mask_channel is None \
+                else pseudo_labels[self.confidence_mask_channel:(self.confidence_mask_channel + 1)]
+            label_mask = self._compute_label_mask_both_sides(mask_input) if self.threshold_from_both_sides \
+                else self._compute_label_mask_one_side(mask_input)
+            if self.confidence_mask_channel is not None:
+                size = (pseudo_labels.shape[0], pseudo_labels.shape[1], *([-1] * (pseudo_labels.ndim - 2)))
+                label_mask = label_mask.expand(*size)
+
+        if has_background_mask:
+            pseudo_labels = self._subtract_background(pseudo_labels, background_mask)
+
+        return pseudo_labels, label_mask
+
+
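A minimal usage sketch for the pseudo labeler above (not part of the diff): the identity teacher and the random tensors are stand-in assumptions; it requires torch, torch_em (for the base class), and the class defined in this PR.

    import torch

    # Batch of shape (B, C, D, H, W): channel 0 holds the raw data,
    # channel 1 a binary background mask (1 = background).
    raw = torch.rand(1, 1, 8, 32, 32)
    background = (torch.rand(1, 1, 8, 32, 32) > 0.8).float()
    batch = torch.cat([raw, background], dim=1)

    labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=0.9, background_mask_channel=1)
    teacher = torch.nn.Identity()  # stand-in for the EMA teacher model
    pseudo_labels, label_mask = labeler(teacher, batch)
    assert pseudo_labels.shape == raw.shape  # background voxels have been zeroed out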
+class MeanTeacherTrainerWithBackgroundMask(self_training.MeanTeacherTrainer):
+    """Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.
+    Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists.
+    The second channel of the student input is also dropped, if it exists, since it is not needed for training.
+
+    Args:
+        kwargs: Additional keyword arguments for `self_training.MeanTeacherTrainer`.
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def _train_epoch_unsupervised(self, progress, forward_context, backprop):
+        self.model.train()
+
+        n_iter = 0
+        t_per_iter = time.time()
+
+        # Sample from the unsupervised loader.
+        for xu1, xu2 in self.unsupervised_train_loader:
+
+            # Assuming shape (B, C, D, H, W), only keep the first channel of xu2 (the student input).
[Review comment: I would add an assertion here that we have two channels. Otherwise the normal ...]
+            if xu2.shape[1] > 1:
+                xu2 = xu2[:, :1].contiguous()
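A sketch of the assertion suggested in the review comment above (hypothetical, not part of the diff), which would sit just before the channel check; this trainer is only selected when background masks are provided, so two input channels are expected at that point.

    assert xu2.shape[1] == 2, f"Expected raw data plus background mask channel, got shape {xu2.shape}."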
+
+            xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True)
+
+            teacher_input, model_input = xu1, xu2
+
+            with forward_context(), torch.no_grad():
+                # Compute the pseudo labels.
+                pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
+
+            # Drop the second channel of xu1 (the teacher input) after computing the pseudo labels.
+            if xu1.shape[1] > 1:
+                xu1 = xu1[:, :1].contiguous()
+
+            # If we have a sampler, check if the current batch matches the condition for inclusion in training.
+            if self.sampler is not None:
+                keep_batch = self.sampler(pseudo_labels, label_filter)
+                if not keep_batch:
+                    continue
+
+            self.optimizer.zero_grad()
+            # Perform unsupervised training.
+            with forward_context():
+                loss = self.unsupervised_loss(self.model, model_input, pseudo_labels, label_filter)
+            backprop(loss)
+
+            if self.logger is not None:
+                with torch.no_grad(), forward_context():
+                    pred = self.model(model_input) if self._iteration % self.log_image_interval == 0 else None
+                self.logger.log_train_unsupervised(
+                    self._iteration, loss, xu1, xu2, pred, pseudo_labels, label_filter
+                )
+                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
+                self.logger.log_lr(self._iteration, lr)
+                if self.pseudo_labeler.confidence_threshold is not None:
+                    self.logger.log_ct(self._iteration, self.pseudo_labeler.confidence_threshold)
+
+            with torch.no_grad():
+                self._momentum_update()
+
+            self._iteration += 1
+            n_iter += 1
+            if self._iteration >= self.max_iteration:
+                break
+            progress.update(1)
+
+        t_per_iter = (time.time() - t_per_iter) / n_iter
+        return t_per_iter
+
+
 def mean_teacher_adaptation(
     name: str,
     unsupervised_train_paths: Tuple[str],

@@ -36,13 +179,14 @@ def mean_teacher_adaptation(
     n_iterations: int = int(1e4),
     n_samples_train: Optional[int] = None,
     n_samples_val: Optional[int] = None,
-    train_mask_paths: Optional[Tuple[str]] = None,
-    val_mask_paths: Optional[Tuple[str]] = None,
+    train_sample_mask_paths: Optional[Tuple[str]] = None,
[constantinpape marked this conversation as resolved.]
+    val_sample_mask_paths: Optional[Tuple[str]] = None,
+    train_background_mask_paths: Optional[Tuple[str]] = None,
     patch_sampler: Optional[callable] = None,
     pseudo_label_sampler: Optional[callable] = None,
     device: int = 0,
 ) -> None:
-    """Run domain adaptation to transfer a network trained on a source domain for a supervised
+    """Run domain adapation to transfer a network trained on a source domain for a supervised
[constantinpape marked this conversation as resolved.]
     segmentation task to perform this task on a different target domain.

     We support different domain adaptation settings:

@@ -85,10 +229,14 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for training.
         n_samples_val: The number of val samples per epoch. By default this will be estimated
             based on the patch_shape and size of the volumes used for validation.
-        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
-        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
-        patch_sampler: Accept or reject patches based on a condition.
-        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
+        train_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
+            patches for training.
+        val_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
+            patches for validation.
+        train_background_mask_paths: Filepaths to the background masks used for training.
+            Background masks are used to subtract background from the pseudo labels before the forward pass.
+        patch_sampler: A sampler for rejecting patches based on a defined condition.
+        pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
         device: GPU ID for training.
     """
     assert (supervised_train_paths is None) == (supervised_val_paths is None)

@@ -116,17 +264,24 @@ def mean_teacher_adaptation(
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

     # self training functionality
-    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
+    if train_background_mask_paths is not None:
[stmartineau99 marked this conversation as resolved.]
+        pseudo_labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=confidence_threshold, background_mask_channel=1)
+        trainer_class = MeanTeacherTrainerWithBackgroundMask
+    else:
[stmartineau99 marked this conversation as resolved.]
+        pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
+        trainer_class = self_training.MeanTeacherTrainer

     loss = self_training.DefaultSelfTrainingLoss()
     loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

     unsupervised_train_loader = get_unsupervised_loader(
         data_paths=unsupervised_train_paths,
         raw_key=raw_key,
         patch_shape=patch_shape,
         batch_size=batch_size,
         n_samples=n_samples_train,
-        sample_mask_paths=train_mask_paths,
+        sample_mask_paths=train_sample_mask_paths,
[constantinpape marked this conversation as resolved.]
+        background_mask_paths=train_background_mask_paths,
         sampler=patch_sampler
     )
     unsupervised_val_loader = get_unsupervised_loader(

@@ -135,7 +290,8 @@ def mean_teacher_adaptation(
         patch_shape=patch_shape,
         batch_size=batch_size,
         n_samples=n_samples_val,
-        sample_mask_paths=val_mask_paths,
+        sample_mask_paths=val_sample_mask_paths,
+        background_mask_paths=None,
         sampler=patch_sampler
     )

@@ -154,7 +310,7 @@ def mean_teacher_adaptation(
         supervised_val_loader = None

     device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
-    trainer = self_training.MeanTeacherTrainer(
+    trainer = trainer_class(
         name=name,
         model=model,
         optimizer=optimizer,

@@ -178,16 +334,14 @@ def mean_teacher_adaptation(
         sampler=pseudo_label_sampler,
     )
     trainer.fit(n_iterations)


 # TODO patch shapes for other models
 PATCH_SHAPES = {
     "vesicles_3d": [48, 256, 256],
 }
 """@private
 """


 def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
     files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
     if len(files) == 0: