
Commit 0257cdb

boundary mask for unsupervised training
1 parent 2df5a6c commit 0257cdb

File tree

1 file changed: +135 −0


synapse_net/training/domain_adaptation.py

Lines changed: 135 additions & 0 deletions
@@ -178,3 +178,138 @@ def mean_teacher_adaptation(
        sampler=pseudo_label_sampler,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have fewer than 4 files, we crop a part of each volume for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave the validation crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
                    "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
                    "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
                    "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
                    "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
                    "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir,
            args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
        )
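
For reference, here is a minimal sketch (not part of the commit) of the direct Python call that the new CLI wraps. The import path mirrors the file changed in this commit; the tomogram paths, the checkpoint path, and the train/val file lists are placeholders, and the keyword arguments are taken from the call in main() above.

# A minimal sketch, not part of the commit: calling mean_teacher_adaptation directly,
# skipping the CLI. All paths below are placeholders; in the CLI, _get_paths prepares
# the train/val file lists and get_model_path resolves the source checkpoint.
from synapse_net.training.domain_adaptation import mean_teacher_adaptation

# Hypothetical tomograms resaved as HDF5 with the raw data under the "data" key,
# matching what _get_paths writes to its temporary directory.
train_paths = ["/path/to/tomo1_train.h5", "/path/to/tomo2_train.h5"]
val_paths = ["/path/to/tomo1_val.h5", "/path/to/tomo2_val.h5"]

mean_teacher_adaptation(
    name="adapted_model",        # checkpoint is written to 'checkpoints/adapted_model'
    unsupervised_train_paths=train_paths,
    unsupervised_val_paths=val_paths,
    patch_shape=[48, 256, 256],  # the default patch shape for 'vesicles_3d'
    source_checkpoint="/path/to/source/checkpoint",  # placeholder for get_model_path("vesicles_3d")
    raw_key="data",
    n_iterations=int(1e4),
    batch_size=1,
)

Note the validation heuristic in _get_paths: with fewer than four input files, validation data is cropped from each volume instead of splitting by file. For example, with the default val_fraction of 0.15 and a 100-slice tomogram, slices 0-84 go to training and slices 85-99 to validation.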

0 commit comments
