1
+ import json
1
2
import os
2
3
from glob import glob
3
4
4
5
import h5py
5
6
import numpy as np
6
7
import pandas as pd
7
8
from elf .evaluation .matching import matching
9
+ from tqdm import tqdm
8
10
9
11
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets" # noqa
12
+ INPUT_04 = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval" # noqa
10
13
OUTPUT_ROOT = "./predictions/cooper" # noqa
11
14
12
15
DATASETS = [
@@ -31,8 +34,17 @@ def evaluate_dataset(ds_name):
31
34
return results
32
35
33
36
print ("Evaluating ds" , ds_name )
34
- input_files = sorted (glob (os .path .join (INPUT_ROOT , ds_name , "**/*.h5" ), recursive = True ))
37
+ if ds_name == "04" :
38
+ input_files = sorted (glob (os .path .join (INPUT_04 , "**/*.h5" ), recursive = True ))
39
+ seg_key = "labels/vesicles"
40
+ mask_key = "labels/compartment"
41
+ else :
42
+ input_files = sorted (glob (os .path .join (INPUT_ROOT , ds_name , "**/*.h5" ), recursive = True ))
43
+ seg_key = "/labels/vesicles/combined_vesicles"
44
+ mask_key = None
45
+
35
46
pred_files = sorted (glob (os .path .join (OUTPUT_ROOT , ds_name , "**/*.h5" ), recursive = True ))
47
+ assert len (input_files ) == len (pred_files ), f"{ len (input_files )} , { len (pred_files )} "
36
48
37
49
results = {
38
50
"dataset" : [],
@@ -41,16 +53,45 @@ def evaluate_dataset(ds_name):
41
53
"recall" : [],
42
54
"f1-score" : [],
43
55
}
44
- for inf , predf in zip (input_files , pred_files ):
56
+ for inf , predf in tqdm ( zip (input_files , pred_files ), total = len ( input_files ), desc = f"Evaluate { ds_name } " ):
45
57
fname = os .path .basename (inf )
46
-
47
- with h5py .File (inf , "r" ) as f :
48
- gt = f ["/labels/vesicles/combined_vesicles" ][:]
49
- with h5py .File (predf , "r" ) as f :
50
- seg = f ["/prediction/vesicles/cryovesnet" ][:]
51
- assert gt .shape == seg .shape
52
-
53
- scores = matching (seg , gt )
58
+ sub_res_path = os .path .join (result_folder , f"{ ds_name } _{ fname } .json" )
59
+
60
+ if os .path .exists (sub_res_path ):
61
+ print ("Loading scores from" , sub_res_path )
62
+ with open (sub_res_path , "r" ) as f :
63
+ scores = json .load (f )
64
+
65
+ else :
66
+ try :
67
+ with h5py .File (predf , "r" ) as f :
68
+ seg = f ["/prediction/vesicles/cryovesnet" ][:]
69
+ except Exception :
70
+ print ("Skipping" , predf )
71
+ continue
72
+
73
+ with h5py .File (inf , "r" ) as f :
74
+ gt = f [seg_key ][:]
75
+ if mask_key is None :
76
+ mask = None
77
+ else :
78
+ mask = f [mask_key ][:]
79
+
80
+ assert gt .shape == seg .shape
81
+
82
+ if mask is not None :
83
+ bb = np .where (mask != 0 )
84
+ bb = tuple (slice (
85
+ int (b .min ()), int (b .max ()) + 1
86
+ ) for b in bb )
87
+ seg , gt , mask = seg [bb ], gt [bb ], mask [bb ]
88
+ seg [mask == 0 ] = 0
89
+ gt [mask == 0 ] = 0
90
+
91
+ scores = matching (seg , gt )
92
+
93
+ with open (sub_res_path , "w" ) as f :
94
+ json .dump (scores , f )
54
95
55
96
results ["dataset" ].append (ds_name )
56
97
results ["file" ].append (fname )
0 commit comments