Skip to content

Commit f56436a

Browse files
Update structure segmentation training
1 parent 74c3e69 commit f56436a

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

scripts/inner_ear/training/train_structure_segmentation.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,22 @@
1010
from tqdm import tqdm
1111

1212
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/inner_ear_data"
13+
ROOT_OTHER_TOMOS = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/other_tomograms/"
14+
1315
LABEL_KEY = "labels/inner_ear_structures"
16+
OTHER_NAMES = ["vesicle_pools", "tether", "rat"]
17+
18+
19+
def get_other_paths(name):
20+
assert name in OTHER_NAMES, f"Invalid name {name}"
21+
if name == "vesicle_pools":
22+
folder = "01_vesicle_pools"
23+
elif name == "tether":
24+
folder = "02_tether"
25+
else:
26+
folder = "03_ratten_tomos"
27+
paths = sorted(glob(os.path.join(ROOT, folder, "*.h5")))
28+
return paths
1429

1530

1631
def get_train_val_test_split(root):
@@ -34,8 +49,7 @@ def get_train_val_test_split(root):
3449
return train_tomos, val_tomos, test_tomos
3550

3651

37-
def preprocess_labels(tomograms):
38-
structure_keys = ("ribbon", "PD", "membrane")
52+
def preprocess_labels(tomograms, structure_keys=("ribbon", "PD", "membrane")):
3953
nc = len(structure_keys)
4054

4155
for tomo in tqdm(tomograms, desc="Preprocess labels"):
@@ -60,11 +74,11 @@ def noop(x):
6074
return x
6175

6276

63-
def train_inner_ear_structures(train_tomograms, val_tomograms):
77+
def train_inner_ear_structures(train_tomograms, val_tomograms, name):
6478
patch_shape = (64, 512, 512)
6579
sampler = MinForegroundSampler(min_fraction=0.05, p_reject=0.95)
6680
supervised_training(
67-
name="inner_ear_structure_model",
81+
name=name,
6882
train_paths=train_tomograms, val_paths=val_tomograms,
6983
label_key=LABEL_KEY, patch_shape=patch_shape, save_root=".",
7084
sampler=sampler, label_transform=noop, out_channels=3,
@@ -73,11 +87,27 @@ def train_inner_ear_structures(train_tomograms, val_tomograms):
7387
)
7488

7589

76-
def main():
90+
def training_v1():
91+
train_tomograms, val_tomograms, _ = get_train_val_test_split(ROOT)
92+
preprocess_labels(train_tomograms)
93+
preprocess_labels(val_tomograms)
94+
train_inner_ear_structures(train_tomograms, val_tomograms, name="inner_ear_structure_model")
95+
96+
97+
def training_v2():
7798
train_tomograms, val_tomograms, _ = get_train_val_test_split(ROOT)
7899
preprocess_labels(train_tomograms)
100+
for name in OTHER_NAMES:
101+
other_tomograms = get_other_paths(name)
102+
preprocess_labels(other_tomograms, structure_keys=("ribbons", "presynapse", "membrane"))
103+
train_tomograms.extend(other_tomograms)
79104
preprocess_labels(val_tomograms)
80-
train_inner_ear_structures(train_tomograms, val_tomograms)
105+
train_inner_ear_structures(train_tomograms, val_tomograms, name="inner_ear_structure_model_v2")
106+
107+
108+
def main():
109+
# training_v1()
110+
training_v2()
81111

82112

83113
if __name__ == "__main__":

0 commit comments

Comments
 (0)