18
18
from ..inference .inference import get_model_path , compute_scale_from_voxel_size
19
19
from ..inference .util import _Scaler
20
20
21
-
22
21
def mean_teacher_adaptation (
23
22
name : str ,
24
23
unsupervised_train_paths : Tuple [str ],
@@ -37,9 +36,13 @@ def mean_teacher_adaptation(
37
36
n_iterations : int = int (1e4 ),
38
37
n_samples_train : Optional [int ] = None ,
39
38
n_samples_val : Optional [int ] = None ,
40
- sampler : Optional [callable ] = None ,
39
+ train_mask_paths : Optional [Tuple [str ]] = None ,
40
+ val_mask_paths : Optional [Tuple [str ]] = None ,
41
+ patch_sampler : Optional [callable ] = None ,
42
+ pseudo_label_sampler : Optional [callable ] = None ,
43
+ device : int = 0 ,
41
44
) -> None :
42
- """Run domain adapation to transfer a network trained on a source domain for a supervised
45
+ """Run domain adaptation to transfer a network trained on a source domain for a supervised
43
46
segmentation task to perform this task on a different target domain.
44
47
45
48
We support different domain adaptation settings:
@@ -82,6 +85,11 @@ def mean_teacher_adaptation(
82
85
based on the patch_shape and size of the volumes used for training.
83
86
n_samples_val: The number of val samples per epoch. By default this will be estimated
84
87
based on the patch_shape and size of the volumes used for validation.
88
+ train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
89
+ val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
90
+ patch_sampler: Accept or reject patches based on a condition.
91
+ pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
92
+ device: GPU ID for training.
85
93
"""
86
94
assert (supervised_train_paths is None ) == (supervised_val_paths is None )
87
95
is_2d , _ = _determine_ndim (patch_shape )
@@ -97,7 +105,7 @@ def mean_teacher_adaptation(
97
105
model = get_3d_model (out_channels = 2 )
98
106
reinit_teacher = True
99
107
else :
100
- print ("Mean teacehr training initialized from source model:" , source_checkpoint )
108
+ print ("Mean teacher training initialized from source model:" , source_checkpoint )
101
109
if os .path .isdir (source_checkpoint ):
102
110
model = torch_em .util .load_model (source_checkpoint )
103
111
else :
@@ -111,12 +119,24 @@ def mean_teacher_adaptation(
111
119
pseudo_labeler = self_training .DefaultPseudoLabeler (confidence_threshold = confidence_threshold )
112
120
loss = self_training .DefaultSelfTrainingLoss ()
113
121
loss_and_metric = self_training .DefaultSelfTrainingLossAndMetric ()
114
-
122
+
115
123
unsupervised_train_loader = get_unsupervised_loader (
116
- unsupervised_train_paths , raw_key , patch_shape , batch_size , n_samples = n_samples_train
124
+ data_paths = unsupervised_train_paths ,
125
+ raw_key = raw_key ,
126
+ patch_shape = patch_shape ,
127
+ batch_size = batch_size ,
128
+ n_samples = n_samples_train ,
129
+ sample_mask_paths = train_mask_paths ,
130
+ sampler = patch_sampler
117
131
)
118
132
unsupervised_val_loader = get_unsupervised_loader (
119
- unsupervised_val_paths , raw_key , patch_shape , batch_size , n_samples = n_samples_val
133
+ data_paths = unsupervised_val_paths ,
134
+ raw_key = raw_key ,
135
+ patch_shape = patch_shape ,
136
+ batch_size = batch_size ,
137
+ n_samples = n_samples_val ,
138
+ sample_mask_paths = val_mask_paths ,
139
+ sampler = patch_sampler
120
140
)
121
141
122
142
if supervised_train_paths is not None :
@@ -133,7 +153,7 @@ def mean_teacher_adaptation(
133
153
supervised_train_loader = None
134
154
supervised_val_loader = None
135
155
136
- device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
156
+ device = torch .device (f "cuda: { device } " ) if torch .cuda .is_available () else torch .device ("cpu" )
137
157
trainer = self_training .MeanTeacherTrainer (
138
158
name = name ,
139
159
model = model ,
@@ -155,11 +175,11 @@ def mean_teacher_adaptation(
155
175
device = device ,
156
176
reinit_teacher = reinit_teacher ,
157
177
save_root = save_root ,
158
- sampler = sampler ,
178
+ sampler = pseudo_label_sampler ,
159
179
)
160
180
trainer .fit (n_iterations )
161
-
162
-
181
+
182
+
163
183
# TODO patch shapes for other models
164
184
PATCH_SHAPES = {
165
185
"vesicles_3d" : [48 , 256 , 256 ],
@@ -228,7 +248,6 @@ def _parse_patch_shape(patch_shape, model_name):
228
248
patch_shape = PATCH_SHAPES [model_name ]
229
249
return patch_shape
230
250
231
-
232
251
def main ():
233
252
"""@private
234
253
"""
@@ -293,4 +312,4 @@ def main():
293
312
n_samples_train = args .n_samples_train ,
294
313
n_samples_val = args .n_samples_val ,
295
314
check = args .check ,
296
- )
315
+ )
0 commit comments