Skip to content

Commit 068a90b

Browse files
committed
PR #2 cosmetic changes
1 parent 7ebd59d commit 068a90b

File tree

1 file changed

+13
-48
lines changed

1 file changed

+13
-48
lines changed

synapse_net/training/domain_adaptation.py

Lines changed: 13 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,26 @@
2020
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
2121
from ..inference.util import _Scaler
2222

23-
class NewPseudoLabeler(self_training.DefaultPseudoLabeler):
23+
class PseudoLabelerWithBackgroundMask(self_training.DefaultPseudoLabeler):
2424
"""Subclass of DefaultPseudoLabeler, which can subtract background from the pseudo labels if a background mask is provided.
2525
By default, assumes that the first channel contains the transformed raw data and the second channel contains the background mask.
2626
2727
Args:
28-
activation: Activation function applied to the teacher prediction.
29-
confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
30-
If None is given no mask will be computed.
31-
threshold_from_both_sides: Whether to include both values bigger than the threshold
32-
and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
33-
The former should be used for binary labels, the latter for for multiclass labels.
3428
confidence_mask_channel: A specific channel to use for computing the confidence mask.
3529
By default the confidence mask is computed across all channels independently.
3630
This is useful, if only one of the channels encodes a probability.
3731
raw_channel: Channel index of the raw data, which will be used as input to the teacher model
3832
background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
33+
kwargs: Additional keyword arguments for `self_training.DefaultPseudoLabeler`.
3934
"""
4035
def __init__(
4136
self,
42-
activation: Optional[torch.nn.Module] = None,
43-
confidence_threshold: Optional[float] = None,
44-
threshold_from_both_sides: bool = True,
4537
confidence_mask_channel: Optional[int] = None,
4638
raw_channel: Optional[int] = 0,
4739
background_mask_channel: Optional[int] = 1,
40+
**kwargs
4841
):
49-
super().__init__(activation, confidence_threshold, threshold_from_both_sides)
42+
super().__init__(**kwargs)
5043
self.confidence_mask_channel = confidence_mask_channel
5144
self.raw_channel = raw_channel
5245
self.background_mask_channel = background_mask_channel
@@ -97,46 +90,16 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
9790

9891
return pseudo_labels, label_mask
9992

100-
class NewMeanTeacherTrainer(self_training.MeanTeacherTrainer):
93+
class MeanTeacherTrainerWithBackgroundMask(self_training.MeanTeacherTrainer):
10194
"""Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.
10295
Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists.
10396
The second channel of the student input is also dropped, if it exists, since it is not needed for training.
10497
10598
Args:
106-
activation: Activation function applied to the teacher prediction.
107-
confidence_threshold: Threshold for computing a mask for filtering the pseudo labels.
108-
If None is given no mask will be computed.
109-
threshold_from_both_sides: Whether to include both values bigger than the threshold
110-
and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask.
111-
The former should be used for binary labels, the latter for for multiclass labels.
112-
confidence_mask_channel: A specific channel to use for computing the confidence mask.
113-
By default the confidence mask is computed across all channels independently.
114-
This is useful, if only one of the channels encodes a probability.
115-
raw_channel: Channel index of the raw data to be used as input to the teacher model.
116-
background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
99+
kwargs: Additional keyword arguments for `self_training.MeanTeacherTrainer`.
117100
"""
118-
def __init__(
119-
self,
120-
model: torch.nn.Module,
121-
unsupervised_train_loader: torch.utils.data.DataLoader,
122-
unsupervised_loss: Callable,
123-
pseudo_labeler: Callable,
124-
supervised_train_loader: Optional[torch.utils.data.DataLoader] = None,
125-
unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
126-
supervised_val_loader: Optional[torch.utils.data.DataLoader] = None,
127-
supervised_loss: Optional[Callable] = None,
128-
unsupervised_loss_and_metric: Optional[Callable] = None,
129-
supervised_loss_and_metric: Optional[Callable] = None,
130-
logger=SelfTrainingTensorboardLogger,
131-
momentum: float = 0.999,
132-
reinit_teacher: Optional[bool] = None,
133-
sampler: Optional[Callable] = None,
134-
**kwargs,
135-
):
136-
super().__init__(model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler,
137-
supervised_train_loader, unsupervised_val_loader, supervised_val_loader,
138-
supervised_loss, unsupervised_loss_and_metric, supervised_loss_and_metric,
139-
logger, momentum, reinit_teacher, sampler, **kwargs)
101+
def __init__(self, **kwargs):
102+
super().__init__(**kwargs)
140103

141104
def _train_epoch_unsupervised(self, progress, forward_context, backprop):
142105
self.model.train()
@@ -294,17 +257,19 @@ def mean_teacher_adaptation(
294257
if os.path.isdir(source_checkpoint):
295258
model = torch_em.util.load_model(source_checkpoint)
296259
else:
297-
model = torch.load(source_checkpoint)
260+
model = torch.load(source_checkpoint, weights_only=False)
298261
reinit_teacher = False
299262

300263
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
301264
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
302265

303266
# self training functionality
304267
if train_background_mask_paths is not None:
305-
pseudo_labeler = NewPseudoLabeler(confidence_threshold=confidence_threshold, background_mask_channel=1)
268+
pseudo_labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=confidence_threshold, background_mask_channel=1)
269+
trainer_class = MeanTeacherTrainerWithBackgroundMask
306270
else:
307271
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
272+
trainer_class = self_training.MeanTeacherTrainer
308273

309274
loss = self_training.DefaultSelfTrainingLoss()
310275
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
@@ -345,7 +310,7 @@ def mean_teacher_adaptation(
345310
supervised_val_loader = None
346311

347312
device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
348-
trainer = self_training.MeanTeacherTrainer(
313+
trainer = trainer_class(
349314
name=name,
350315
model=model,
351316
optimizer=optimizer,

0 commit comments

Comments
 (0)