Skip to content

Commit 74c3e69

Browse files
Merge pull request #46 from computational-cell-analytics/update-comp-seg
Update compartment segmentation and add code for full vesicle reconstruction
2 parents bc981e2 + 59e01dd commit 74c3e69

File tree

12 files changed

+556
-65
lines changed

12 files changed

+556
-65
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from tqdm import tqdm
6+
7+
8+
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04Dataset_for_vesicle_eval/model_segmentation" # noqa
9+
OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction"
10+
11+
12+
def assort_az_and_vesicles(in_path, out_path):
13+
if os.path.exists(out_path):
14+
return
15+
16+
with h5py.File(in_path, "r") as f:
17+
raw = f["raw"][:]
18+
vesicles = f["/vesicles/segment_from_combined_vesicles"][:]
19+
az = f["/AZ/segment_from_AZmodel_v3"][:]
20+
21+
os.makedirs(os.path.split(out_path)[0], exist_ok=True)
22+
with h5py.File(out_path, "a") as f:
23+
f.create_dataset("raw", data=raw, compression="gzip")
24+
f.create_dataset("labels/vesicles", data=vesicles, compression="gzip")
25+
f.create_dataset("labels/active_zone", data=az, compression="gzip")
26+
27+
28+
def main():
29+
paths = sorted(glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True))
30+
for path in tqdm(paths):
31+
fname = os.path.relpath(path, INPUT_ROOT)
32+
out_path = os.path.join(OUTPUT_ROOT, fname)
33+
assort_az_and_vesicles(path, out_path)
34+
35+
36+
if __name__ == "__main__":
37+
main()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from synaptic_reconstruction.inference.compartments import segment_compartments
6+
from tqdm import tqdm
7+
8+
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction" # noqa
9+
MODEL_PATH = "/user/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/cooper/training/checkpoints/compartment_model_3d/v2" # noqa
10+
11+
12+
def label_transform_3d():
13+
pass
14+
15+
16+
def run_seg(path):
17+
with h5py.File(path, "r") as f:
18+
if "labels/compartments" in f:
19+
return
20+
raw = f["raw"][:]
21+
22+
scale = (0.25, 0.25, 0.25)
23+
seg = segment_compartments(raw, model_path=MODEL_PATH, scale=scale, verbose=False)
24+
with h5py.File(path, "a") as f:
25+
f.create_dataset("labels/compartments", data=seg, compression="gzip")
26+
27+
28+
def main():
29+
paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
30+
for path in tqdm(paths):
31+
run_seg(path)
32+
33+
34+
main()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from synaptic_reconstruction.inference.mitochondria import segment_mitochondria
6+
from tqdm import tqdm
7+
8+
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction" # noqa
9+
MODEL_PATH = "/scratch-grete/projects/nim00007/models/exports_for_cooper/mito_model_s2.pt" # noqa
10+
11+
12+
def run_seg(path):
13+
with h5py.File(path, "r") as f:
14+
if "labels/mitochondria" in f:
15+
return
16+
raw = f["raw"][:]
17+
18+
scale = (0.5, 0.5, 0.5)
19+
seg = segment_mitochondria(raw, model_path=MODEL_PATH, scale=scale, verbose=False)
20+
with h5py.File(path, "a") as f:
21+
f.create_dataset("labels/mitochondria", data=seg, compression="gzip")
22+
23+
24+
def main():
25+
paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
26+
for path in tqdm(paths):
27+
run_seg(path)
28+
29+
30+
main()
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import napari
6+
import numpy as np
7+
import pandas as pd
8+
9+
ROOT = "./04_full_reconstruction"
10+
TABLE = "/home/pape/Desktop/sfb1286/mboc_synapse/draft_figures/full_reconstruction.xlsx"
11+
12+
# Skip datasets for which all figures were already done.
13+
SKIP_DS = ["20241019_Tomo-eval_MF_Synapse"]
14+
15+
16+
def _get_name_and_row(path, table):
17+
ds_name, name = os.path.split(path)
18+
ds_name = os.path.split(ds_name)[1]
19+
row = table[(table["dataset"] == ds_name) & (table["tomogram"] == name)]
20+
return ds_name, name, row
21+
22+
23+
def _get_compartment_ids(row):
24+
compartment_ids = []
25+
for comp in ("Compartment 1", "Compartment 2", "Compartment 3", "Compartment 4"):
26+
comp_ids = row[comp].values[0]
27+
try:
28+
comp_ids = list(map(int, comp_ids.split(", ")))
29+
except AttributeError:
30+
pass
31+
32+
if np.isnan(comp_ids).all():
33+
compartment_ids.append(None)
34+
continue
35+
36+
if isinstance(comp_ids, int):
37+
comp_ids = [comp_ids]
38+
compartment_ids.append(comp_ids)
39+
40+
return compartment_ids
41+
42+
43+
def visualize_result(path, table):
44+
ds_name, name, row = _get_name_and_row(path, table)
45+
46+
if ds_name in SKIP_DS:
47+
return
48+
49+
# if row["Use for vis"].values[0] == "yes":
50+
if row["Use for vis"].values[0] in ("yes", "no"):
51+
return
52+
compartment_ids = _get_compartment_ids(row)
53+
54+
# access = np.s_[:]
55+
access = np.s_[::2, ::2, ::2]
56+
57+
with h5py.File(path, "r") as f:
58+
raw = f["raw"][access]
59+
vesicles = f["labels/vesicles"][access]
60+
active_zone = f["labels/active_zone"][access]
61+
mitos = f["labels/mitochondria"][access]
62+
compartments = f["labels/compartments"][access]
63+
64+
if any(comp_ids is not None for comp_ids in compartment_ids):
65+
mask = np.zeros(raw.shape, dtype="bool")
66+
compartments_new = np.zeros_like(compartments)
67+
68+
print("Filtering compartments:")
69+
for i, comp_ids in enumerate(compartment_ids, 1):
70+
if comp_ids is None:
71+
continue
72+
print(i, comp_ids)
73+
this_mask = np.isin(compartments, comp_ids)
74+
mask[this_mask] = 1
75+
compartments_new[this_mask] = i
76+
77+
vesicles[~mask] = 0
78+
mitos[~mask] = 0
79+
compartments = compartments_new
80+
81+
v = napari.Viewer()
82+
v.add_image(raw)
83+
v.add_labels(mitos)
84+
v.add_labels(vesicles)
85+
v.add_labels(compartments)
86+
v.add_labels(active_zone)
87+
v.title = f"{ds_name}/{name}"
88+
napari.run()
89+
90+
91+
def visualize_only_compartment(path, table):
92+
ds_name, name, row = _get_name_and_row(path, table)
93+
compartment_ids = _get_compartment_ids(row)
94+
95+
# Skip if we already have annotated the presynapse compartment(s)
96+
if any(comp_id is not None for comp_id in compartment_ids):
97+
print("Compartments already annotated for", ds_name, name)
98+
return
99+
100+
# access = np.s_[:]
101+
access = np.s_[::2, ::2, ::2]
102+
103+
with h5py.File(path, "r") as f:
104+
raw = f["raw"][access]
105+
compartments = f["labels/compartments"][access]
106+
107+
v = napari.Viewer()
108+
v.add_image(raw)
109+
v.add_labels(compartments)
110+
v.title = f"{ds_name}/{name}"
111+
napari.run()
112+
113+
114+
def main():
115+
paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
116+
table = pd.read_excel(TABLE)
117+
for path in paths:
118+
visualize_result(path, table)
119+
# visualize_only_compartment(path, table)
120+
121+
122+
if __name__ == "__main__":
123+
main()
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import napari
6+
7+
from synaptic_reconstruction.inference.compartments import _segment_compartments_3d
8+
9+
10+
def check_pred(path, pred_path, name):
11+
with h5py.File(path, "r") as f:
12+
raw = f["raw"][:]
13+
# seg = f["labels/compartments"][:]
14+
15+
with h5py.File(pred_path, "r") as f:
16+
pred = f["prediction"][:]
17+
18+
print("Run segmentation ...")
19+
seg_new = _segment_compartments_3d(pred)
20+
print("done")
21+
22+
v = napari.Viewer()
23+
v.add_image(raw)
24+
v.add_image(pred, visible=False)
25+
# v.add_labels(seg, visible=False)
26+
v.add_labels(seg_new)
27+
v.title = name
28+
napari.run()
29+
30+
31+
def main():
32+
seg_paths = sorted(glob("./predictions/segmentation/**/*.h5", recursive=True))
33+
34+
for seg_path in seg_paths:
35+
ds_name, fname = os.path.split(seg_path)
36+
ds_name = os.path.split(ds_name)[1]
37+
38+
# if ds_name in ("20241019_Tomo-eval_MF_Synapse", "20241019_Tomo-eval_PS_Synapse"):
39+
# continue
40+
41+
name = f"{ds_name}/{fname}"
42+
pred_path = os.path.join("./predictions/prediction", ds_name, fname)
43+
assert os.path.exists(pred_path), pred_path
44+
check_pred(seg_path, pred_path, name)
45+
46+
47+
main()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import napari
6+
from magicgui import magicgui
7+
8+
9+
def run_annotation(input_path, output_path):
10+
with h5py.File(input_path, "r") as f:
11+
raw = f["raw"][:]
12+
seg = f["labels/compartments"][:]
13+
14+
v = napari.Viewer()
15+
v.add_image(raw)
16+
v.add_labels(seg)
17+
18+
@magicgui(call_button="Save Annotations")
19+
def save_annotations():
20+
seg = v.layers["seg"].data
21+
22+
if os.path.exists(output_path):
23+
with h5py.File(output_path, "a") as f:
24+
f["labels/compartments"][:] = seg
25+
else:
26+
with h5py.File(output_path, "a") as f:
27+
f.create_dataset("raw", data=raw, compression="gzip")
28+
f.create_dataset("labels/compartments", data=seg, compression="gzip")
29+
30+
v.window.add_dock_widget(save_annotations)
31+
32+
napari.run()
33+
34+
35+
def main():
36+
inputs = sorted(glob("./predictions/**/*.h5", recursive=True))
37+
38+
output_folder = "./annotations"
39+
40+
for input_path in inputs:
41+
ds_name, fname = os.path.split(input_path)
42+
ds_name = os.path.split(ds_name)[1]
43+
ds_folder = os.path.join(output_folder, ds_name)
44+
output_path = os.path.join(ds_folder, fname)
45+
46+
if os.path.exists(output_path):
47+
print("Skipping annotations for", output_path)
48+
continue
49+
50+
os.makedirs(ds_folder, exist_ok=True)
51+
run_annotation(input_path, output_path)
52+
53+
54+
if __name__ == "__main__":
55+
main()
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from tqdm import tqdm
6+
7+
from synaptic_reconstruction.inference.util import _Scaler
8+
from synaptic_reconstruction.inference.compartments import segment_compartments
9+
10+
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval" # noqa
11+
# MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_models/compartment_model_3d.pt"
12+
MODEL_PATH = "/user/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/cooper/training/checkpoints/compartment_model_3d/v2" # noqa
13+
OUTPUT = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_predictions"
14+
15+
16+
def label_transform_3d():
17+
pass
18+
19+
20+
def segment_volume(input_path, model_path):
21+
with h5py.File(input_path, "r") as f:
22+
raw = f["raw"][:]
23+
24+
scale = (0.25, 0.25, 0.25)
25+
scaler = _Scaler(scale, verbose=False)
26+
raw = scaler.scale_input(raw)
27+
28+
n_slices_exclude = 2
29+
seg, pred = segment_compartments(
30+
raw, model_path, verbose=False, n_slices_exclude=n_slices_exclude, return_predictions=True
31+
)
32+
# raw, seg = raw[n_slices_exclude:-n_slices_exclude], seg[n_slices_exclude:-n_slices_exclude]
33+
34+
return raw, seg, pred
35+
36+
37+
def main():
38+
inputs = sorted(glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True))
39+
inputs = [inp for inp in inputs if "cropped_for_2D" not in inp]
40+
41+
for input_path in tqdm(inputs, desc="Run prediction for 04."):
42+
ds_name, fname = os.path.split(input_path)
43+
ds_name = os.path.split(ds_name)[1]
44+
output_folder = os.path.join(OUTPUT, "segmentation", ds_name)
45+
output_path = os.path.join(output_folder, fname)
46+
47+
if os.path.exists(output_path):
48+
continue
49+
50+
pred_folder = os.path.join(OUTPUT, "prediction", ds_name)
51+
os.makedirs(pred_folder, exist_ok=True)
52+
pred_path = os.path.join(pred_folder, fname)
53+
54+
raw, seg, pred = segment_volume(input_path, MODEL_PATH)
55+
os.makedirs(output_folder, exist_ok=True)
56+
with h5py.File(output_path, "a") as f:
57+
f.create_dataset("raw", data=raw, compression="gzip")
58+
f.create_dataset("labels/compartments", data=seg, compression="gzip")
59+
with h5py.File(pred_path, "a") as f:
60+
f.create_dataset("prediction", data=pred, compression="gzip")
61+
62+
63+
if __name__ == "__main__":
64+
main()

0 commit comments

Comments
 (0)