10
10
from tqdm import tqdm
11
11
12
12
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/inner_ear_data"
13
+ ROOT_OTHER_TOMOS = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/other_tomograms/"
14
+
13
15
LABEL_KEY = "labels/inner_ear_structures"
16
+ OTHER_NAMES = ["vesicle_pools" , "tether" , "rat" ]
17
+
18
+
19
+ def get_other_paths (name ):
20
+ assert name in OTHER_NAMES , f"Invalid name { name } "
21
+ if name == "vesicle_pools" :
22
+ folder = "01_vesicle_pools"
23
+ elif name == "tether" :
24
+ folder = "02_tether"
25
+ else :
26
+ folder = "03_ratten_tomos"
27
+ paths = sorted (glob (os .path .join (ROOT , folder , "*.h5" )))
28
+ return paths
14
29
15
30
16
31
def get_train_val_test_split (root ):
@@ -34,8 +49,7 @@ def get_train_val_test_split(root):
34
49
return train_tomos , val_tomos , test_tomos
35
50
36
51
37
- def preprocess_labels (tomograms ):
38
- structure_keys = ("ribbon" , "PD" , "membrane" )
52
+ def preprocess_labels (tomograms , structure_keys = ("ribbon" , "PD" , "membrane" )):
39
53
nc = len (structure_keys )
40
54
41
55
for tomo in tqdm (tomograms , desc = "Preprocess labels" ):
@@ -60,11 +74,11 @@ def noop(x):
60
74
return x
61
75
62
76
63
- def train_inner_ear_structures (train_tomograms , val_tomograms ):
77
+ def train_inner_ear_structures (train_tomograms , val_tomograms , name ):
64
78
patch_shape = (64 , 512 , 512 )
65
79
sampler = MinForegroundSampler (min_fraction = 0.05 , p_reject = 0.95 )
66
80
supervised_training (
67
- name = "inner_ear_structure_model" ,
81
+ name = name ,
68
82
train_paths = train_tomograms , val_paths = val_tomograms ,
69
83
label_key = LABEL_KEY , patch_shape = patch_shape , save_root = "." ,
70
84
sampler = sampler , label_transform = noop , out_channels = 3 ,
@@ -73,11 +87,27 @@ def train_inner_ear_structures(train_tomograms, val_tomograms):
73
87
)
74
88
75
89
76
- def main ():
90
+ def training_v1 ():
91
+ train_tomograms , val_tomograms , _ = get_train_val_test_split (ROOT )
92
+ preprocess_labels (train_tomograms )
93
+ preprocess_labels (val_tomograms )
94
+ train_inner_ear_structures (train_tomograms , val_tomograms , name = "inner_ear_structure_model" )
95
+
96
+
97
+ def training_v2 ():
77
98
train_tomograms , val_tomograms , _ = get_train_val_test_split (ROOT )
78
99
preprocess_labels (train_tomograms )
100
+ for name in OTHER_NAMES :
101
+ other_tomograms = get_other_paths (name )
102
+ preprocess_labels (other_tomograms , structure_keys = ("ribbons" , "presynapse" , "membrane" ))
103
+ train_tomograms .extend (other_tomograms )
79
104
preprocess_labels (val_tomograms )
80
- train_inner_ear_structures (train_tomograms , val_tomograms )
105
+ train_inner_ear_structures (train_tomograms , val_tomograms , name = "inner_ear_structure_model_v2" )
106
+
107
+
108
+ def main ():
109
+ # training_v1()
110
+ training_v2 ()
81
111
82
112
83
113
if __name__ == "__main__" :
0 commit comments