Skip to content

Commit bc981e2

Browse files
Finalize CryoVesNet evaluation
1 parent c35b04d commit bc981e2

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed
Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
# CryoVesNet
22

3-
Scripts to run CryoVesNet on our data. See https://github.com/Zuber-group/CryoVesNet for details.
4-
5-
The code is currently not working due to this issue: https://github.com/Zuber-group/CryoVesNet/issues/6
3+
Scripts to run CryoVesNet on our data and evaluate the results. See https://github.com/Zuber-group/CryoVesNet for details.

scripts/baselines/cryo_ves_net/evaluate_cooper.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import json
12
import os
23
from glob import glob
34

45
import h5py
56
import numpy as np
67
import pandas as pd
78
from elf.evaluation.matching import matching
9+
from tqdm import tqdm
810

911
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets" # noqa
12+
INPUT_04 = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval" # noqa
1013
OUTPUT_ROOT = "./predictions/cooper" # noqa
1114

1215
DATASETS = [
@@ -31,8 +34,17 @@ def evaluate_dataset(ds_name):
3134
return results
3235

3336
print("Evaluating ds", ds_name)
34-
input_files = sorted(glob(os.path.join(INPUT_ROOT, ds_name, "**/*.h5"), recursive=True))
37+
if ds_name == "04":
38+
input_files = sorted(glob(os.path.join(INPUT_04, "**/*.h5"), recursive=True))
39+
seg_key = "labels/vesicles"
40+
mask_key = "labels/compartment"
41+
else:
42+
input_files = sorted(glob(os.path.join(INPUT_ROOT, ds_name, "**/*.h5"), recursive=True))
43+
seg_key = "/labels/vesicles/combined_vesicles"
44+
mask_key = None
45+
3546
pred_files = sorted(glob(os.path.join(OUTPUT_ROOT, ds_name, "**/*.h5"), recursive=True))
47+
assert len(input_files) == len(pred_files), f"{len(input_files)}, {len(pred_files)}"
3648

3749
results = {
3850
"dataset": [],
@@ -41,16 +53,45 @@ def evaluate_dataset(ds_name):
4153
"recall": [],
4254
"f1-score": [],
4355
}
44-
for inf, predf in zip(input_files, pred_files):
56+
for inf, predf in tqdm(zip(input_files, pred_files), total=len(input_files), desc=f"Evaluate {ds_name}"):
4557
fname = os.path.basename(inf)
46-
47-
with h5py.File(inf, "r") as f:
48-
gt = f["/labels/vesicles/combined_vesicles"][:]
49-
with h5py.File(predf, "r") as f:
50-
seg = f["/prediction/vesicles/cryovesnet"][:]
51-
assert gt.shape == seg.shape
52-
53-
scores = matching(seg, gt)
58+
sub_res_path = os.path.join(result_folder, f"{ds_name}_{fname}.json")
59+
60+
if os.path.exists(sub_res_path):
61+
print("Loading scores from", sub_res_path)
62+
with open(sub_res_path, "r") as f:
63+
scores = json.load(f)
64+
65+
else:
66+
try:
67+
with h5py.File(predf, "r") as f:
68+
seg = f["/prediction/vesicles/cryovesnet"][:]
69+
except Exception:
70+
print("Skipping", predf)
71+
continue
72+
73+
with h5py.File(inf, "r") as f:
74+
gt = f[seg_key][:]
75+
if mask_key is None:
76+
mask = None
77+
else:
78+
mask = f[mask_key][:]
79+
80+
assert gt.shape == seg.shape
81+
82+
if mask is not None:
83+
bb = np.where(mask != 0)
84+
bb = tuple(slice(
85+
int(b.min()), int(b.max()) + 1
86+
) for b in bb)
87+
seg, gt, mask = seg[bb], gt[bb], mask[bb]
88+
seg[mask == 0] = 0
89+
gt[mask == 0] = 0
90+
91+
scores = matching(seg, gt)
92+
93+
with open(sub_res_path, "w") as f:
94+
json.dump(scores, f)
5495

5596
results["dataset"].append(ds_name)
5697
results["file"].append(fname)

0 commit comments

Comments
 (0)