
Commit eabd2aa

Merge pull request #41 from computational-cell-analytics/sm-dev
Sm dev
2 parents fc7a3a1 + b58f1d3

11 files changed: +305 -15 lines


.gitignore

Lines changed: 3 additions & 1 deletion

@@ -8,4 +8,6 @@ models/*/
 run_sbatch.sbatch
 slurm/
 scripts/cooper/evaluation_results/
-scripts/cooper/training/copy_testset.py
+scripts/cooper/training/copy_testset.py
+scripts/rizzoli/upsample_data.py
+scripts/cooper/training/find_rec_testset.py

scripts/cooper/training/evaluation.py

Lines changed: 2 additions & 2 deletions

@@ -30,8 +30,8 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key)
     #get the labels and vesicles
     with h5py.File(labels_path) as label_file:
         labels = label_file["labels"]
-        vesicles = labels["vesicles"]
-        gt = vesicles[anno_key][:]
+        #vesicles = labels["vesicles"]
+        gt = labels[anno_key][:]

     with h5py.File(vesicles_path) as seg_file:
         segmentation = seg_file["vesicles"]
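This change assumes the ground-truth annotations now live directly under the "labels" group rather than in the nested "labels/vesicles" group. A minimal sketch of the two layouts, with a hypothetical file name and annotation key:

    import h5py

    with h5py.File("labels_file.h5", "r") as f:
        gt = f["labels"]["manual_annotations"][:]                # new layout: /labels/<anno_key>
        # gt = f["labels"]["vesicles"]["manual_annotations"][:]  # old layout: /labels/vesicles/<anno_key>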

scripts/cooper/training/train_AZ.py

Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@ (new file; all lines added)

import os
from glob import glob
import argparse
import json

import torch_em
import torch

from sklearn.model_selection import train_test_split

from synaptic_reconstruction.training import supervised_training
from synaptic_reconstruction.training import semisupervised_training

TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/exported_imod_objects"
OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/training_AZ_v1"


def _require_train_val_test_split(datasets):
    train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1
    if len(datasets) < 10:
        train_ratio, val_ratio, test_ratio = 0.5, 0.25, 0.25

    def _train_val_test_split(names):
        train, test = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        _ratio = test_ratio / (test_ratio + val_ratio)
        val, test = train_test_split(test, test_size=_ratio)
        return train, val, test

    for ds in datasets:
        print(ds)
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            continue

        file_paths = sorted(glob(os.path.join(TRAIN_ROOT, ds, "*.h5")))
        file_names = [os.path.basename(path) for path in file_paths]

        train, val, test = _train_val_test_split(file_names)

        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val, "test": test}, f)


def _require_train_val_split(datasets):
    train_ratio, val_ratio = 0.8, 0.2

    def _train_val_split(names):
        train, val = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
        return train, val

    for ds in datasets:
        print(ds)
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        if os.path.exists(split_path):
            continue

        file_paths = sorted(glob(os.path.join(TRAIN_ROOT, ds, "*.h5")))
        file_names = [os.path.basename(path) for path in file_paths]

        train, val = _train_val_split(file_names)

        with open(split_path, "w") as f:
            json.dump({"train": train, "val": val}, f)


def get_paths(split, datasets, testset=True):
    if testset:
        _require_train_val_test_split(datasets)
    else:
        _require_train_val_split(datasets)

    paths = []
    for ds in datasets:
        split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json")
        with open(split_path) as f:
            names = json.load(f)[split]
        ds_paths = [os.path.join(TRAIN_ROOT, ds, name) for name in names]
        assert all(os.path.exists(path) for path in ds_paths)
        paths.extend(ds_paths)

    return paths


def train(key, ignore_label=None, training_2D=False, testset=True):
    datasets = [
        "01_hoi_maus_2020_incomplete",
        "06_hoi_wt_stem750_fm",
        "12_chemical_fix_cryopreparation"
    ]
    train_paths = get_paths("train", datasets=datasets, testset=testset)
    val_paths = get_paths("val", datasets=datasets, testset=testset)

    print("Start training with:")
    print(len(train_paths), "tomograms for training")
    print(len(val_paths), "tomograms for validation")

    patch_shape = [48, 256, 256]
    model_name = "3D-AZ-model-v1"

    # check for 2D training
    if training_2D:
        patch_shape = [1, 256, 256]
        model_name = "2D-AZ-model-v1"

    batch_size = 4
    check = False

    supervised_training(
        name=model_name,
        train_paths=train_paths,
        val_paths=val_paths,
        label_key=f"/labels/{key}",
        patch_shape=patch_shape, batch_size=batch_size,
        sampler=torch_em.data.sampler.MinInstanceSampler(min_num_instances=1),
        n_samples_train=None, n_samples_val=25,
        check=check,
        save_root="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/AZ_models",
        n_iterations=int(5e3),
        ignore_label=ignore_label,
        label_transform=torch_em.transform.label.labels_to_binary,
        out_channels=1,
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-k", "--key", required=True, help="The label key used for training, e.g. the name of the AZ annotations.")
    parser.add_argument("-m", "--mask", type=int, default=None, help="Label ID that will be ignored during training.")
    parser.add_argument("-2D", "--training_2D", action='store_true', help="Train a 2D model instead of the default 3D model.")
    parser.add_argument("-t", "--testset", action='store_false', help="Pass this flag to skip creating a test split.")
    args = parser.parse_args()
    train(args.key, args.mask, args.training_2D, args.testset)


if __name__ == "__main__":
    main()
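The nested _train_val_test_split helper splits in two stages: first it carves out the train set, then it divides the remainder between validation and test. A standalone sketch of that arithmetic with hypothetical file names (fewer than 10 datasets, so the 0.5/0.25/0.25 ratios apply):

    from sklearn.model_selection import train_test_split

    names = [f"tomo_{i}.h5" for i in range(8)]
    train_ratio, val_ratio, test_ratio = 0.5, 0.25, 0.25

    # stage 1: carve out the train set
    train, rest = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
    # stage 2: divide the remainder between val and test
    val, test = train_test_split(rest, test_size=test_ratio / (test_ratio + val_ratio))
    print(len(train), len(val), len(test))  # 4 2 2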

scripts/cooper/vesicle_segmentation_h5.py

Lines changed: 32 additions & 4 deletions

@@ -34,11 +34,19 @@ def get_volume(input_path):
         input_volume = f[key][:]
     return input_volume

-def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, halo, include_boundary, key_label):
+def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mask_key, tile_shape, halo, include_boundary, key_label):
     tiling = parse_tiling(tile_shape, halo)
     print(f"using tiling {tiling}")
     input = get_volume(input_path)
-    segmentation, prediction = segment_vesicles(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary)
+
+    # check if we have a restricting mask for the segmentation
+    if mask_path is not None:
+        with open_file(mask_path, "r") as f:
+            mask = f[mask_key][:]
+    else:
+        mask = None
+
+    segmentation, prediction = segment_vesicles(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary, mask=mask)
     foreground, boundaries = prediction[:2]

     seg_output = _require_output_folders(output_path)

@@ -63,6 +71,12 @@ def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, ha
         f.create_dataset(f"prediction_{key_label}/foreground", data = foreground, compression="gzip")
         f.create_dataset(f"prediction_{key_label}/boundaries", data = boundaries, compression="gzip")

+        if mask is not None:
+            if mask_key in f:
+                print("mask image already saved")
+            else:
+                f.create_dataset(mask_key, data=mask, compression="gzip")

@@ -75,7 +89,15 @@ def segment_folder(args):
     print(input_files)
     pbar = tqdm(input_files, desc="Run segmentation")
     for input_path in pbar:
-        run_vesicle_segmentation(input_path, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label)
+
+        filename = os.path.basename(input_path)
+        # os.path.join cannot raise here, so check for the mask file explicitly instead of using try/except
+        if args.mask_path is not None and os.path.exists(os.path.join(args.mask_path, filename)):
+            mask_path = os.path.join(args.mask_path, filename)
+        else:
+            print(f"Mask file not found for {input_path}")
+            mask_path = None
+
+        run_vesicle_segmentation(input_path, args.output_path, args.model_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label)

 def main():
     parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.")

@@ -90,6 +112,12 @@ def main():
     parser.add_argument(
         "--model_path", "-m", required=True, help="The filepath to the vesicle model."
     )
+    parser.add_argument(
+        "--mask_path", help="The filepath to an h5 file with a mask that restricts the segmentation. Must be used in combination with --mask_key."
+    )
+    parser.add_argument(
+        "--mask_key", help="The key (internal dataset path) of the mask segmentation."
+    )
     parser.add_argument(
         "--tile_shape", type=int, nargs=3,
         help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
     )

@@ -113,7 +141,7 @@ def main():
     if os.path.isdir(input_):
         segment_folder(args)
     else:
-        run_vesicle_segmentation(input_, args.output_path, args.model_path, args.tile_shape, args.halo, args.include_boundary, args.key_label)
+        run_vesicle_segmentation(input_, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label)

     print("Finished segmenting!")
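In folder mode, masks are matched to input tomograms purely by file name, so each mask file must mirror its input's basename inside --mask_path. A short sketch under that assumption (hypothetical paths and key):

    import h5py

    # input tomogram: /data/tomograms/tomo-001.h5
    # matching mask:  /data/masks/tomo-001.h5
    with h5py.File("/data/masks/tomo-001.h5", "r") as f:
        mask = f["labels/compartment"][:]  # dataset addressed by --mask_key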

scripts/cryo/vesicles/train_domain_adaptation.py

Lines changed: 2 additions & 1 deletion

@@ -84,7 +84,7 @@ def vesicle_domain_adaptation(teacher_model, testset = True):

     #adjustable parameters
     patch_shape = [48, 256, 256]
-    model_name = "vesicle-DA-cryo-v1"
+    model_name = "vesicle-DA-cryo-v2"

     model_root = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/models_v2/checkpoints/"
     checkpoint_path = os.path.join(model_root, teacher_model)

@@ -98,6 +98,7 @@ def vesicle_domain_adaptation(teacher_model, testset = True):
         save_root="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/DA_models",
         source_checkpoint=checkpoint_path,
         confidence_threshold=0.75,
+        n_iterations=int(5e4),
     )
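Both domain-adaptation scripts in this commit pass confidence_threshold=0.75 to the self-training. As a generic illustration of confidence filtering, not the torch_em implementation, pseudo-labels are only kept where the teacher is sufficiently certain either way:

    import torch

    def confidence_mask(teacher_probs: torch.Tensor, threshold: float = 0.75) -> torch.Tensor:
        # keep pixels predicted confidently as foreground or background
        return ((teacher_probs >= threshold) | (teacher_probs <= 1 - threshold)).float()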

scripts/inner_ear/training/train_domain_adaptation_vesicles.py

Lines changed: 2 additions & 1 deletion

@@ -119,7 +119,7 @@ def vesicle_domain_adaptation(teacher_model, testset = True):

     #adjustable parameters
     patch_shape = [48, 256, 256]
-    model_name = "vesicle-DA-inner_ear-v1"
+    model_name = "vesicle-DA-inner_ear-v2"

     model_root = "/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/models_v2/checkpoints/"
     checkpoint_path = os.path.join(model_root, teacher_model)

@@ -133,6 +133,7 @@ def vesicle_domain_adaptation(teacher_model, testset = True):
         save_root="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/DA_models",
         source_checkpoint=checkpoint_path,
         confidence_threshold=0.75,
+        n_iterations=int(1e5),
     )

scripts/rizzoli/2D_vesicle_segmentation.py

Lines changed: 11 additions & 3 deletions

@@ -6,6 +6,7 @@
 from tqdm import tqdm
 import torch
 import torch_em
+import numpy as np

 from synaptic_reconstruction.inference.vesicles import segment_vesicles
 from synaptic_reconstruction.inference.util import parse_tiling

@@ -73,13 +74,18 @@ def run_vesicle_segmentation(input_path, output_path, model_path, tile_shape, ha

     def process_slices(input_volume):
         processed_slices = []
+        foreground = []
+        boundaries = []
         for z in range(input_volume.shape[0]):
             slice_ = input_volume[z, :, :]
-            segmented_slice = segment_vesicles(input_volume=slice_, model=model, verbose=False, tiling=tiling, exclude_boundary=not include_boundary)
+            segmented_slice, prediction_slice = segment_vesicles(input_volume=slice_, model=model, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary)
             processed_slices.append(segmented_slice)
-        return processed_slices
+            foreground_pred_slice, boundaries_pred_slice = prediction_slice[:2]
+            foreground.append(foreground_pred_slice)
+            boundaries.append(boundaries_pred_slice)
+        return processed_slices, foreground, boundaries

-    segmentation = process_slices(input)
+    segmentation, foreground, boundaries = process_slices(input)

     seg_output = _require_output_folders(output_path)
     file_name = Path(input_path).stem

@@ -100,6 +106,8 @@ def process_slices(input_volume):
             print("Skipping", input_path, "because", key, "exists")
         else:
             f.create_dataset(key, data=segmentation, compression="gzip")
+            f.create_dataset(f"prediction_{key_label}/foreground", data=foreground, compression="gzip")
+            f.create_dataset(f"prediction_{key_label}/boundaries", data=boundaries, compression="gzip")
scripts/rizzoli/evaluation_2D.py

Lines changed: 2 additions & 2 deletions

@@ -58,8 +58,8 @@ def evaluate_file(labels_path, vesicles_path, model_name, segment_key, anno_key)
     #get the labels and vesicles
     with h5py.File(labels_path) as label_file:
         labels = label_file["labels"]
-        vesicles = labels["vesicles"]
-        gt = vesicles[anno_key][:]
+        #vesicles = labels["vesicles"]
+        gt = labels[anno_key][:]

     with h5py.File(vesicles_path) as seg_file:
         segmentation = seg_file["vesicles"]
