|
20 | 20 | from ..inference.inference import get_model_path, compute_scale_from_voxel_size
|
21 | 21 | from ..inference.util import _Scaler
|
22 | 22 |
|
23 |
| -class NewPseudoLabeler(self_training.DefaultPseudoLabeler): |
| 23 | +class PseudoLabelerWithBackgroundMask(self_training.DefaultPseudoLabeler): |
24 | 24 | """Subclass of DefaultPseudoLabeler, which can subtract background from the pseudo labels if a background mask is provided.
|
25 | 25 | By default, assumes that the first channel contains the transformed raw data and the second channel contains the background mask.
|
26 | 26 |
|
27 | 27 | Args:
|
28 |
| - activation: Activation function applied to the teacher prediction. |
29 |
| - confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. |
30 |
| - If None is given no mask will be computed. |
31 |
| - threshold_from_both_sides: Whether to include both values bigger than the threshold |
32 |
| - and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask. |
33 |
| - The former should be used for binary labels, the latter for for multiclass labels. |
34 | 28 | confidence_mask_channel: A specific channel to use for computing the confidence mask.
|
35 | 29 | By default the confidence mask is computed across all channels independently.
|
36 | 30 | This is useful, if only one of the channels encodes a probability.
|
37 | 31 | raw_channel: Channel index of the raw data, which will be used as input to the teacher model
|
38 | 32 | background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels.
|
| 33 | + kwargs: Additional keyword arguments for `self_training.DefaultPseudoLabeler`. |
39 | 34 | """
|
40 | 35 | def __init__(
|
41 | 36 | self,
|
42 |
| - activation: Optional[torch.nn.Module] = None, |
43 |
| - confidence_threshold: Optional[float] = None, |
44 |
| - threshold_from_both_sides: bool = True, |
45 | 37 | confidence_mask_channel: Optional[int] = None,
|
46 | 38 | raw_channel: Optional[int] = 0,
|
47 | 39 | background_mask_channel: Optional[int] = 1,
|
| 40 | + **kwargs |
48 | 41 | ):
|
49 |
| - super().__init__(activation, confidence_threshold, threshold_from_both_sides) |
| 42 | + super().__init__(**kwargs) |
50 | 43 | self.confidence_mask_channel = confidence_mask_channel
|
51 | 44 | self.raw_channel = raw_channel
|
52 | 45 | self.background_mask_channel = background_mask_channel
|
@@ -97,46 +90,16 @@ def __call__(self, teacher: torch.nn.Module, input_: torch.Tensor) -> torch.Tens
|
97 | 90 |
|
98 | 91 | return pseudo_labels, label_mask
|
99 | 92 |
|
100 |
| -class NewMeanTeacherTrainer(self_training.MeanTeacherTrainer): |
| 93 | +class MeanTeacherTrainerWithBackgroundMask(self_training.MeanTeacherTrainer): |
101 | 94 | """Subclass of MeanTeacherTrainer, updated to handle cases where the background mask is provided.
|
102 | 95 | Once the pseudo labels are computed, the second channel of the teacher input is dropped, if it exists.
|
103 | 96 | The second channel of the student input is also dropped, if it exists, since it is not needed for training.
|
104 | 97 |
|
105 | 98 | Args:
|
106 |
| - activation: Activation function applied to the teacher prediction. |
107 |
| - confidence_threshold: Threshold for computing a mask for filtering the pseudo labels. |
108 |
| - If None is given no mask will be computed. |
109 |
| - threshold_from_both_sides: Whether to include both values bigger than the threshold |
110 |
| - and smaller than 1 - the thrhesold, or only values bigger than the threshold, in the mask. |
111 |
| - The former should be used for binary labels, the latter for for multiclass labels. |
112 |
| - confidence_mask_channel: A specific channel to use for computing the confidence mask. |
113 |
| - By default the confidence mask is computed across all channels independently. |
114 |
| - This is useful, if only one of the channels encodes a probability. |
115 |
| - raw_channel: Channel index of the raw data to be used as input to the teacher model. |
116 |
| - background_mask_channel: Channel index of the background mask, which will be subtracted from the pseudo labels. |
| 99 | + kwargs: Additional keyword arguments for `self_training.MeanTeacherTrainer`. |
117 | 100 | """
|
118 |
| - def __init__( |
119 |
| - self, |
120 |
| - model: torch.nn.Module, |
121 |
| - unsupervised_train_loader: torch.utils.data.DataLoader, |
122 |
| - unsupervised_loss: Callable, |
123 |
| - pseudo_labeler: Callable, |
124 |
| - supervised_train_loader: Optional[torch.utils.data.DataLoader] = None, |
125 |
| - unsupervised_val_loader: Optional[torch.utils.data.DataLoader] = None, |
126 |
| - supervised_val_loader: Optional[torch.utils.data.DataLoader] = None, |
127 |
| - supervised_loss: Optional[Callable] = None, |
128 |
| - unsupervised_loss_and_metric: Optional[Callable] = None, |
129 |
| - supervised_loss_and_metric: Optional[Callable] = None, |
130 |
| - logger=SelfTrainingTensorboardLogger, |
131 |
| - momentum: float = 0.999, |
132 |
| - reinit_teacher: Optional[bool] = None, |
133 |
| - sampler: Optional[Callable] = None, |
134 |
| - **kwargs, |
135 |
| - ): |
136 |
| - super().__init__(model, unsupervised_train_loader, unsupervised_loss, pseudo_labeler, |
137 |
| - supervised_train_loader, unsupervised_val_loader, supervised_val_loader, |
138 |
| - supervised_loss, unsupervised_loss_and_metric, supervised_loss_and_metric, |
139 |
| - logger, momentum, reinit_teacher, sampler, **kwargs) |
| 101 | + def __init__(self, **kwargs): |
| 102 | + super().__init__(**kwargs) |
140 | 103 |
|
141 | 104 | def _train_epoch_unsupervised(self, progress, forward_context, backprop):
|
142 | 105 | self.model.train()
|
@@ -294,17 +257,19 @@ def mean_teacher_adaptation(
|
294 | 257 | if os.path.isdir(source_checkpoint):
|
295 | 258 | model = torch_em.util.load_model(source_checkpoint)
|
296 | 259 | else:
|
297 |
| - model = torch.load(source_checkpoint) |
| 260 | + model = torch.load(source_checkpoint, weights_only=False) |
298 | 261 | reinit_teacher = False
|
299 | 262 |
|
300 | 263 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
301 | 264 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
|
302 | 265 |
|
303 | 266 | # self training functionality
|
304 | 267 | if train_background_mask_paths is not None:
|
305 |
| - pseudo_labeler = NewPseudoLabeler(confidence_threshold=confidence_threshold, background_mask_channel=1) |
| 268 | + pseudo_labeler = PseudoLabelerWithBackgroundMask(confidence_threshold=confidence_threshold, background_mask_channel=1) |
| 269 | + trainer_class = MeanTeacherTrainerWithBackgroundMask |
306 | 270 | else:
|
307 | 271 | pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
|
| 272 | + trainer_class = self_training.MeanTeacherTrainer |
308 | 273 |
|
309 | 274 | loss = self_training.DefaultSelfTrainingLoss()
|
310 | 275 | loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
|
@@ -345,7 +310,7 @@ def mean_teacher_adaptation(
|
345 | 310 | supervised_val_loader = None
|
346 | 311 |
|
347 | 312 | device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
|
348 |
| - trainer = self_training.MeanTeacherTrainer( |
| 313 | + trainer = trainer_class( |
349 | 314 | name=name,
|
350 | 315 | model=model,
|
351 | 316 | optimizer=optimizer,
|
|
0 commit comments