Skip to content

Commit bbecced

Browse files
2 parents 9aba9f3 + 7657af0 commit bbecced

File tree

6 files changed

+398
-13
lines changed

6 files changed

+398
-13
lines changed
Lines changed: 242 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,242 @@
1-
# The results look very good! Scores are misleading because of the artifacts
2-
# and can be significantly improved by post-processing.
3-
# TODO wait for vesicle segmentation, then apply post-processing and evaluate
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import pandas as pd
6+
7+
from elf.evaluation.dice import dice_score
8+
from synaptic_reconstruction.inference.vesicles import segment_vesicles
9+
from synaptic_reconstruction.inference.postprocessing.ribbon import segment_ribbon
10+
from synaptic_reconstruction.inference.postprocessing.presynaptic_density import segment_presynaptic_density
11+
from torch_em.util import load_model
12+
from tqdm import tqdm
13+
14+
from train_structure_segmentation import get_train_val_test_split
15+
16+
ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser"
17+
# ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser"
18+
MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-inner_ear-v2"
19+
OUTPUT_ROOT = "./predictions"
20+
21+
22+
def run_vesicle_segmentation(input_paths, model_path, name, is_nested=False):
    """Segment vesicles for every input tomogram and store the result in HDF5.

    Results are written to ``OUTPUT_ROOT/name[/subfolder]/<file>`` under the
    "vesicles" dataset. Files that already contain a "vesicles" dataset are
    skipped, and the model is only loaded once something actually needs to run.

    Args:
        input_paths: Paths to the input HDF5 files (tomogram under "raw").
        model_path: Path to the vesicle segmentation model checkpoint.
        name: Name of the output sub-directory below OUTPUT_ROOT.
        is_nested: Whether to mirror the input's parent folder in the output.
    """
    output_root = os.path.join(OUTPUT_ROOT, name)
    model = None  # Loaded lazily so cached runs don't pay the load cost.

    for input_path in input_paths:
        parent, file_name = os.path.split(input_path)
        if is_nested:
            output_folder = os.path.join(output_root, os.path.split(parent)[1])
        else:
            output_folder = output_root
        os.makedirs(output_folder, exist_ok=True)
        output_path = os.path.join(output_folder, file_name)

        # Skip tomograms that were already segmented in a previous run.
        already_done = False
        if os.path.exists(output_path):
            with h5py.File(output_path, "r") as f:
                already_done = "vesicles" in f
        if already_done:
            continue

        if model is None:
            model = load_model(model_path)

        with h5py.File(input_path, "r") as f:
            tomogram = f["raw"][:]

        seg = segment_vesicles(input_volume=tomogram, model=model)
        with h5py.File(output_path, "a") as f:
            f.create_dataset("vesicles", data=seg, compression="gzip")
def postprocess_structures(paths, name, prefix=None, is_nested=False):
    """Post-process ribbon and presynaptic-density predictions.

    Reads the raw "ribbon" / "PD" predictions and the vesicle segmentation
    from the prediction files under ``OUTPUT_ROOT/name`` and writes the
    post-processed results to "segmentation/ribbon" and "segmentation/PD"
    (below ``prefix`` if one is given). Files that already contain the
    segmentation group are skipped.

    Args:
        paths: The input tomogram paths (used to derive the output file names).
        name: Name of the prediction sub-directory below OUTPUT_ROOT.
        prefix: Optional group prefix for reading predictions / writing results.
        is_nested: Whether the output mirrors the input's parent folder.
    """
    output_root = os.path.join(OUTPUT_ROOT, name)

    # Group names inside the HDF5 files, depending on whether a prefix is set.
    seg_group = "segmentation" if prefix is None else f"{prefix}/segmentation"
    pred_group = "" if prefix is None else f"{prefix}/"

    for path in tqdm(paths):
        parent, file_name = os.path.split(path)
        if is_nested:
            output_folder = os.path.join(output_root, os.path.split(parent)[1])
        else:
            output_folder = output_root
        output_path = os.path.join(output_folder, file_name)

        with h5py.File(output_path, "r") as f:
            # Skip files that were already post-processed.
            if seg_group in f:
                continue

            vesicles = f["vesicles"][:]
            ribbon_pred = f[f"{pred_group}ribbon"][:]
            presyn_pred = f[f"{pred_group}PD"][:]

        ribbon = segment_ribbon(ribbon_pred, vesicles, n_slices_exclude=15, n_ribbons=1)
        presyn = segment_presynaptic_density(presyn_pred, ribbon, n_slices_exclude=15)

        with h5py.File(output_path, "a") as f:
            f.create_dataset(f"{seg_group}/ribbon", data=ribbon, compression="gzip")
            f.create_dataset(f"{seg_group}/PD", data=presyn, compression="gzip")
def visualize(input_paths, name, is_nested=False, label_names=None, prefixes=None):
    """Show raw data, ground-truth labels and post-processed predictions in napari.

    Args:
        input_paths: Paths to the input HDF5 files (raw data and labels).
        name: Name of the prediction sub-directory below OUTPUT_ROOT.
        is_nested: Whether the predictions mirror the input's parent folder.
        label_names: Names of the label datasets; defaults to the structure
            names ("ribbon", "PD"). Only as many labels as structures are shown.
        prefixes: Optional group prefixes for reading multiple prediction sets.

    Fixes compared to the previous version: the loop variables no longer
    shadow the ``name`` parameter, the unused zip partner is dropped, and
    prefixed ribbon layers (e.g. "Src/ribbon") now also get the orange
    colormap instead of always falling through to green.
    """
    import napari

    structure_names = ["ribbon", "PD"]
    if label_names is None:
        label_names = structure_names

    output_root = os.path.join(OUTPUT_ROOT, name)
    for path in input_paths:
        root, fname = os.path.split(path)
        if is_nested:
            output_folder = os.path.join(output_root, os.path.split(root)[1])
        else:
            output_folder = output_root
        output_path = os.path.join(output_folder, fname)

        # Load the ground-truth labels. The zip keeps the previous behavior of
        # showing at most one label per structure (extra label names, e.g.
        # "membrane", are ignored).
        labels = {}
        with h5py.File(path, "r") as f:
            raw = f["raw"][:]
            for label_name, _ in zip(label_names, structure_names):
                labels[label_name] = f[f"labels/{label_name}"][:]

        # Load the predicted segmentations, optionally for several prefixes.
        predictions = {}
        with h5py.File(output_path, "r") as f:
            if prefixes is None:
                for sname in structure_names:
                    predictions[sname] = f[f"segmentation/{sname}"][:]
            else:
                for prefix in prefixes:
                    for sname in structure_names:
                        predictions[f"{prefix}/{sname}"] = f[f"{prefix}/segmentation/{sname}"][:]

        v = napari.Viewer()
        v.add_image(raw)
        for label_name, seg in labels.items():
            v.add_labels(seg, name=f"labels/{label_name}", visible=False)
        for layer_name, seg in predictions.items():
            # Ribbons in orange, PDs in green; also match prefixed layer names.
            if layer_name.split("/")[-1] == "ribbon":
                cmap = {1: "orange"}
            else:
                cmap = {1: "green"}
            v.add_labels(seg, name=layer_name, colormap=cmap)
        v.title = fname
        napari.run()
def evaluate(input_paths, name, is_nested=False, prefix=None, save_path=None, label_names=None):
    """Compute dice scores of the post-processed structures against the labels.

    Args:
        input_paths: Paths to the input HDF5 files with the ground-truth labels.
        name: Name of the prediction sub-directory below OUTPUT_ROOT.
        is_nested: Whether the predictions mirror the input's parent folder.
        prefix: Optional group prefix for reading the segmentations; also used
            as the method name in the result table ("Src" when None).
        save_path: Optional CSV path; if it exists the cached result is
            returned, otherwise the computed table is written there.
        label_names: Names of the label datasets, zipped against the structure
            names ("ribbon", "PD"); defaults to the structure names.

    Returns:
        A pandas DataFrame with one row per tomogram and one score column
        per structure.
    """
    # Return the cached scores if they were already computed.
    if save_path is not None and os.path.exists(save_path):
        return pd.read_csv(save_path)

    structure_names = ["ribbon", "PD"]
    if label_names is None:
        label_names = structure_names
    output_root = os.path.join(OUTPUT_ROOT, name)

    method = "Src" if prefix is None else prefix
    results = {"method": [], "file_name": []}
    for structure in structure_names:
        results[structure] = []

    for path in tqdm(input_paths, desc="Run evaluation"):
        parent, file_name = os.path.split(path)
        if is_nested:
            folder = os.path.split(parent)[1]
            output_path = os.path.join(output_root, folder, file_name)
            results["file_name"].append(f"{folder}/{file_name}")
        else:
            output_path = os.path.join(output_root, file_name)
            results["file_name"].append(file_name)
        results["method"].append(method)

        with h5py.File(path, "r") as f_in, h5py.File(output_path, "r") as f_out:
            for structure, label_name in zip(structure_names, label_names):
                gt = f_in[f"labels/{label_name}"][:]
                if prefix is None:
                    seg_key = f"segmentation/{structure}"
                else:
                    seg_key = f"{prefix}/segmentation/{structure}"
                pred = f_out[seg_key][:]
                results[structure].append(dice_score(pred, gt))

    results = pd.DataFrame(results)
    if save_path is not None:
        results.to_csv(save_path, index=False)
    return results
def segment_train_domain():
    """Run segmentation, post-processing and evaluation on the source domain.

    Uses the test split of the inner-ear training data and prints the mean
    and standard deviation of the ribbon and PD dice scores.
    """
    _, _, paths = get_train_val_test_split(os.path.join(ROOT, "inner_ear_data"))
    print("Run evaluation on", len(paths), "tomos")

    name = "train_domain"
    run_vesicle_segmentation(paths, MODEL_PATH, name, is_nested=True)
    postprocess_structures(paths, name, is_nested=True)
    visualize(paths, name, is_nested=True)

    results = evaluate(paths, name, is_nested=True, save_path="./results/train_domain_postprocessed.csv")
    print(results)
    ribbon_scores, pd_scores = results["ribbon"], results["PD"]
    print("Ribbon segmentation:", ribbon_scores.mean(), "+-", ribbon_scores.std())
    print("PD segmentation:", pd_scores.mean(), "+-", pd_scores.std())
197+
def segment_vesicle_pools():
    """Run the pipeline on the vesicle pool tomograms (target domain).

    Post-processing and evaluation are run for both the source model ("Src")
    and the domain-adapted model ("Adapted") predictions.
    """
    name = "vesicle_pools"
    paths = sorted(glob(os.path.join(ROOT, "other_tomograms/01_vesicle_pools", "*.h5")))
    run_vesicle_segmentation(paths, MODEL_PATH, name)

    prefixes = ("Src", "Adapted")
    # NOTE: only the first two label names are paired with the structures
    # ("ribbon", "PD") inside evaluate; "membrane" is currently unused.
    label_names = ["ribbons", "presynapse", "membrane"]

    for prefix in prefixes:
        postprocess_structures(paths, name, prefix=prefix, is_nested=False)

        results = evaluate(
            paths, name, prefix=prefix,
            save_path=f"./results/{name}_{prefix}.csv", label_names=label_names,
        )
        print("Results for", name, prefix, ":")
        print(results)

    # visualize(paths, name, label_names=label_names, prefixes=prefixes)
216+
def segment_rat():
    """Run the pipeline on the rat tomograms (target domain).

    Post-processing and evaluation are run for both the source model ("Src")
    and the domain-adapted model ("Adapted") predictions.
    """
    name = "rat"
    paths = sorted(glob(os.path.join(ROOT, "other_tomograms/03_ratten_tomos", "*.h5")))
    run_vesicle_segmentation(paths, MODEL_PATH, name)

    prefixes = ("Src", "Adapted")
    # NOTE: only the first two label names are paired with the structures
    # ("ribbon", "PD") inside evaluate; "membrane" is currently unused.
    label_names = ["ribbons", "presynapse", "membrane"]

    for prefix in prefixes:
        postprocess_structures(paths, name, prefix=prefix, is_nested=False)

        results = evaluate(
            paths, name, prefix=prefix,
            save_path=f"./results/{name}_{prefix}.csv", label_names=label_names,
        )
        print("Results for", name, prefix, ":")
        print(results)

    # visualize(paths, name, label_names=label_names, prefixes=prefixes)
235+
def main():
    """Run the source-domain experiment; target-domain runs are opt-in."""
    segment_train_domain()
    # Enable as needed for the target-domain experiments:
    # segment_vesicle_pools()
    # segment_rat()


if __name__ == "__main__":
    main()

scripts/inner_ear/training/structure_prediction_and_evaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from train_structure_segmentation import get_train_val_test_split
1212
from train_structure_segmentation import noop # noqa
1313

14-
# ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser"
15-
ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser"
14+
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser"
15+
# ROOT = "/home/pape/Work/data/synaptic_reconstruction/moser"
1616
OUTPUT_ROOT = "./predictions"
1717

1818

@@ -176,13 +176,13 @@ def predict_and_evaluate_target_domain(paths, name, adapted_model_path):
176176

177177
def predict_and_evaluate_vesicle_pools():
178178
paths = sorted(glob(os.path.join(ROOT, "other_tomograms/01_vesicle_pools", "*.h5")))
179-
adapted_model_path = "./checkpoints/structure-model-adapt-vesicle_pools"
179+
adapted_model_path = "./checkpoints/structure-model-adapt-vesicle_pools-v2"
180180
predict_and_evaluate_target_domain(paths, "vesicle_pools", adapted_model_path)
181181

182182

183183
def predict_and_evaluate_rat():
184184
paths = sorted(glob(os.path.join(ROOT, "other_tomograms/03_ratten_tomos", "*.h5")))
185-
adapted_model_path = "./checkpoints/structure-model-adapt-rat"
185+
adapted_model_path = "./checkpoints/structure-model-adapt-rat-v2"
186186
predict_and_evaluate_target_domain(paths, "rat", adapted_model_path)
187187

188188

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def summarize_source_domain():
    """Summarize the source-domain dice scores as one-row tables.

    Reads ./results/train_domain_postprocessed.csv and returns two DataFrames
    (ribbon, PD), each with the dataset name, the "mean +- std" score of the
    source model in percent, and an empty target-model column.
    """
    results = pd.read_csv("./results/train_domain_postprocessed.csv")

    def _single_row(column):
        # Format the score of one structure as "mean +- std" in percent.
        mean = np.round(results[column].mean() * 100, 2)
        std = np.round(results[column].std() * 100, 2)
        return pd.DataFrame({
            "dataset": ["source"],
            "source_model": [f"{mean} +- {std}"],
            "target_model": [""],
        })

    return _single_row("ribbon"), _single_row("PD")
36+
def summarize_rat():
    """Summarize the rat dice scores for the source and the adapted model.

    Reads results/rat_Src.csv and results/rat_Adapted.csv and returns two
    one-row DataFrames (ribbon, PD) with "mean +- std" scores in percent.

    Fix: the dataset column was mislabeled as "source" (copy-paste from the
    source-domain summary); these results belong to the rat tomograms.
    """
    result_paths = {
        "source_model": "results/rat_Src.csv",
        "target_model": "results/rat_Adapted.csv",
    }

    ribbon_results = {"dataset": ["rat"], "source_model": [], "target_model": []}
    PD_results = {"dataset": ["rat"], "source_model": [], "target_model": []}

    for model, result_path in result_paths.items():
        results = pd.read_csv(result_path)
        # Same "mean +- std" formatting (in percent) for both structures.
        for structure, summary in (("ribbon", ribbon_results), ("PD", PD_results)):
            mean = np.round(results[structure].mean() * 100, 2)
            std = np.round(results[structure].std() * 100, 2)
            summary[model].append(f"{mean} +- {std}")

    return pd.DataFrame(ribbon_results), pd.DataFrame(PD_results)
71+
def summarize_ves_pool():
    """Summarize the vesicle-pool dice scores for source and adapted model.

    Reads results/vesicle_pools_Src.csv and results/vesicle_pools_Adapted.csv
    and returns two one-row DataFrames (ribbon, PD) with "mean +- std" scores
    in percent.

    Fix: the dataset column was mislabeled as "source" (copy-paste from the
    source-domain summary); these results belong to the vesicle-pool tomograms.
    """
    result_paths = {
        "source_model": "results/vesicle_pools_Src.csv",
        "target_model": "results/vesicle_pools_Adapted.csv",
    }

    ribbon_results = {"dataset": ["vesicle_pools"], "source_model": [], "target_model": []}
    PD_results = {"dataset": ["vesicle_pools"], "source_model": [], "target_model": []}

    for model, result_path in result_paths.items():
        results = pd.read_csv(result_path)
        # Same "mean +- std" formatting (in percent) for both structures.
        for structure, summary in (("ribbon", ribbon_results), ("PD", PD_results)):
            mean = np.round(results[structure].mean() * 100, 2)
            std = np.round(results[structure].std() * 100, 2)
            summary[model].append(f"{mean} +- {std}")

    return pd.DataFrame(ribbon_results), pd.DataFrame(PD_results)
106+
def main():
    """Print the ribbon and PD summary tables for the source domain."""
    ribbon_results, PD_results = summarize_source_domain()
    # Swap in the line below to summarize the vesicle pool results instead.
    # ribbon_results, PD_results = summarize_ves_pool()
    for title, table in (("Ribbon", ribbon_results), ("PD", PD_results)):
        print(title)
        print(table)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)