Skip to content

Commit a0c31a8

Browse files
Fix issue in data aggregation
1 parent cb693b1 commit a0c31a8

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

scripts/aggregate_data_information.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,24 @@
1212
stem = "STEM"
1313

1414

15-
def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions):
15+
def aggregate_vesicle_train_data(roots, conditions, resolutions):
1616
tomo_names = []
1717
tomo_vesicles_all, tomo_vesicles_imod = [], []
1818
tomo_condition = []
1919
tomo_resolution = []
2020
tomo_train = []
2121

22-
for ds, root in roots.items():
23-
print("Aggregate data for", ds)
24-
train_root = root["train"]
25-
if train_root == "":
26-
test_root = root["test"]
27-
tomograms = sorted(glob(os.path.join(test_root, "2024**", "*.h5"), recursive=True))
28-
this_test_tomograms = [os.path.basename(tomo) for tomo in tomograms]
22+
def aggregate_split(ds, split_root, split):
23+
if ds.startswith("04"):
24+
tomograms = sorted(glob(os.path.join(split_root, "2024**", "*.h5"), recursive=True))
2925
else:
30-
# This is only the case for 04, which is also nested
31-
tomograms = sorted(glob(os.path.join(train_root, "*.h5")))
32-
this_test_tomograms = test_tomograms[ds]
26+
tomograms = sorted(glob(os.path.join(split_root, "*.h5")))
3327

3428
assert len(tomograms) > 0, ds
3529
this_condition = conditions[ds]
3630
this_resolution = resolutions[ds][0]
3731

38-
for tomo_path in tqdm(tomograms):
32+
for tomo_path in tqdm(tomograms, desc=f"Aggregate {split}"):
3933
fname = os.path.basename(tomo_path)
4034
with h5py.File(tomo_path, "r") as f:
4135
try:
@@ -58,7 +52,16 @@ def aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions)
5852
tomo_vesicles_imod.append(n_vesicles_imod)
5953
tomo_condition.append(this_condition)
6054
tomo_resolution.append(this_resolution)
61-
tomo_train.append("test" if fname in this_test_tomograms else "train/val")
55+
tomo_train.append(split)
56+
57+
for ds, root in roots.items():
58+
print("Aggregate data for", ds)
59+
train_root = root["train"]
60+
if train_root != "":
61+
aggregate_split(ds, train_root, "train/val")
62+
test_root = root["test"]
63+
if test_root != "":
64+
aggregate_split(ds, test_root, "test")
6265

6366
df = pd.DataFrame({
6467
"tomogram": tomo_names,
@@ -117,19 +120,6 @@ def vesicle_train_data():
117120
},
118121
}
119122

120-
test_tomograms = {
121-
"01": ["tomogram-009.h5", "tomogram-038.h5", "tomogram-049.h5", "tomogram-052.h5", "tomogram-057.h5", "tomogram-060.h5", "tomogram-067.h5", "tomogram-074.h5", "tomogram-076.h5", "tomogram-083.h5", "tomogram-133.h5", "tomogram-136.h5", "tomogram-145.h5", "tomogram-149.h5", "tomogram-150.h5"], # noqa
122-
"02": ["tomogram-004.h5", "tomogram-008.h5"],
123-
"03": ["tomogram-003.h5", "tomogram-004.h5", "tomogram-008.h5",],
124-
"04": [], # all used for test
125-
"05": ["tomogram-003.h5", "tomogram-005.h5",],
126-
"07": ["tomogram-006.h5", "tomogram-017.h5",],
127-
"09": [], # no test data
128-
"10": ["tomogram-001.h5", "tomogram-002.h5", "tomogram-007.h5"],
129-
"11": ["tomogram-001.h5 tomogram-007.h5 tomogram-008.h5"],
130-
"12": ["tomogram-004.h5", "tomogram-021.h5", "tomogram-022.h5",],
131-
}
132-
133123
conditions = {
134124
"01": single_ax_tem,
135125
"02": dual_ax_tem,
@@ -156,7 +146,7 @@ def vesicle_train_data():
156146
"12": (1.554, 1.554, 1.554)
157147
}
158148

159-
aggregate_vesicle_train_data(roots, test_tomograms, conditions, resolutions)
149+
aggregate_vesicle_train_data(roots, conditions, resolutions)
160150

161151

162152
def aggregate_az_train_data(roots, test_tomograms, conditions, resolutions):
913 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)