
Commit 1e25907

Merge pull request #131 from computational-cell-analytics/revision3
Add new AZ model
2 parents 928f330 + 0ec2a50 commit 1e25907

4 files changed: +116 −5 lines changed
New file · Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import os

import h5py
import numpy as np
import pandas as pd

from synapse_net.inference.inference import get_model
from synapse_net.inference.compartments import segment_compartments
from skimage.segmentation import find_boundaries

from elf.evaluation.matching import matching

from train_compartments import get_paths_3d
from sklearn.model_selection import train_test_split


def run_prediction(paths):
    # Run the compartment model on the given tomograms and store the
    # segmentation and the raw boundary predictions.
    output_folder = "./compartment_eval"
    os.makedirs(output_folder, exist_ok=True)

    model = get_model("compartments")
    for path in paths:
        with h5py.File(path, "r") as f:
            input_vol = f["raw"][:]
        seg, pred = segment_compartments(input_vol, model=model, return_predictions=True)

        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            f.create_dataset("seg", data=seg, compression="gzip")
            f.create_dataset("pred", data=pred, compression="gzip")


def binary_recall(gt, pred):
    # Pixel-wise recall of the boundary prediction against the boundary labels.
    tp = np.logical_and(gt, pred).sum()
    fn = np.logical_and(gt, ~pred).sum()
    return float(tp) / (tp + fn) if (tp + fn) else 0.0


def run_evaluation(paths):
    output_folder = "./compartment_eval"

    results = {
        "name": [],
        "recall-pred": [],
        "recall-seg": [],
    }

    for path in paths:
        with h5py.File(path, "r") as f:
            labels = f["labels/compartments"][:]
        boundary_labels = find_boundaries(labels).astype("bool")

        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            seg, pred = f["seg"][:], f["pred"][:]

        # Recall of the thresholded boundary prediction and instance-level
        # recall of the segmentation, matched against the compartment labels.
        recall_pred = binary_recall(boundary_labels, pred > 0.5)
        recall_seg = matching(seg, labels)["recall"]

        results["name"].append(fname)
        results["recall-pred"].append(recall_pred)
        results["recall-seg"].append(recall_seg)

    results = pd.DataFrame(results)
    print(results)
    print(results[["recall-pred", "recall-seg"]].mean())


def check_predictions(paths):
    # Visually inspect raw data, labels, predictions and segmentations in napari.
    import napari
    output_folder = "./compartment_eval"

    for path in paths:
        with h5py.File(path, "r") as f:
            raw = f["raw"][:]
            labels = f["labels/compartments"][:]
        boundary_labels = find_boundaries(labels)

        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            seg, pred = f["seg"][:], f["pred"][:]

        v = napari.Viewer()
        v.add_image(raw)
        v.add_image(pred)
        v.add_labels(labels)
        v.add_labels(boundary_labels)
        v.add_labels(seg)
        napari.run()


def main():
    paths = get_paths_3d()
    _, val_paths = train_test_split(paths, test_size=0.10, random_state=42)

    # run_prediction(val_paths)
    run_evaluation(val_paths)
    # check_predictions(val_paths)


if __name__ == "__main__":
    main()
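
For reference (not part of the commit), a toy check of how the pixel-wise binary_recall defined above behaves; the arrays are invented for illustration and the function is assumed to be in scope:

import numpy as np

gt = np.array([[1, 1, 0], [0, 0, 0]], dtype=bool)    # two boundary pixels in the ground truth
pred = np.array([[1, 0, 0], [0, 1, 0]], dtype=bool)  # one of them recovered, plus one false positive
# tp = 1, fn = 1 -> recall = 0.5; false positives do not lower recall.
print(binary_recall(gt, pred))  # 0.5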

scripts/cooper/training/train_compartments.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
 from synapse_net.training import supervised_training

 TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/compartments"
-# TRAIN_ROOT = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/ground_truth/compartments/output/compartment_gt"  # noqa


 def get_paths_2d():

synapse_net/inference/inference.py

Lines changed: 3 additions & 3 deletions
@@ -22,7 +22,7 @@

 def _get_model_registry():
     registry = {
-        "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
+        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
         "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
         "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
         "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
@@ -37,7 +37,7 @@ def _get_model_registry():
         "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
     }
     urls = {
-        "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
+        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
         "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
         "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
         "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
@@ -109,7 +109,7 @@ def get_model_training_resolution(model_type: str) -> Dict[str, float]:
         Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
     """
     resolutions = {
-        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
+        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
         "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
         "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
         "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},

synapse_net/tools/cli.py

Lines changed: 10 additions & 1 deletion
@@ -1,7 +1,9 @@
 import argparse
+import os
 from functools import partial

 import torch
+import torch_em
 from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
 from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
 from ..inference.util import inference_helper, parse_tiling
@@ -155,7 +157,14 @@ def segmentation_cli():
     if args.checkpoint is None:
         model = get_model(args.model)
     else:
-        model = torch.load(args.checkpoint, weights_only=False)
+        checkpoint_path = args.checkpoint
+        if checkpoint_path.endswith("best.pt"):
+            checkpoint_path = os.path.split(checkpoint_path)[0]
+
+        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
+            model = torch_em.util.load_model(checkpoint=checkpoint_path)
+        else:
+            model = torch.load(checkpoint_path, weights_only=False)
     assert model is not None, f"The model from {args.checkpoint} could not be loaded."

     is_2d = "2d" in args.model
