|
| 1 | +import os |
| 2 | +import h5py |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | + |
| 6 | +from synapse_net.inference.inference import get_model |
| 7 | +from synapse_net.inference.compartments import segment_compartments |
| 8 | +from skimage.segmentation import find_boundaries |
| 9 | + |
| 10 | +from elf.evaluation.matching import matching |
| 11 | + |
| 12 | +from train_compartments import get_paths_3d |
| 13 | +from sklearn.model_selection import train_test_split |
| 14 | + |
| 15 | + |
| 16 | +def run_prediction(paths): |
| 17 | + output_folder = "./compartment_eval" |
| 18 | + os.makedirs(output_folder, exist_ok=True) |
| 19 | + |
| 20 | + model = get_model("compartments") |
| 21 | + for path in paths: |
| 22 | + with h5py.File(path, "r") as f: |
| 23 | + input_vol = f["raw"][:] |
| 24 | + seg, pred = segment_compartments(input_vol, model=model, return_predictions=True) |
| 25 | + fname = os.path.basename(path) |
| 26 | + out = os.path.join(output_folder, fname) |
| 27 | + with h5py.File(out, "a") as f: |
| 28 | + f.create_dataset("seg", data=seg, compression="gzip") |
| 29 | + f.create_dataset("pred", data=pred, compression="gzip") |
| 30 | + |
| 31 | + |
| 32 | +def binary_recall(gt, pred): |
| 33 | + tp = np.logical_and(gt, pred).sum() |
| 34 | + fn = np.logical_and(gt, ~pred).sum() |
| 35 | + return float(tp) / (tp + fn) if (tp + fn) else 0.0 |
| 36 | + |
| 37 | + |
| 38 | +def run_evaluation(paths): |
| 39 | + output_folder = "./compartment_eval" |
| 40 | + |
| 41 | + results = { |
| 42 | + "name": [], |
| 43 | + "recall-pred": [], |
| 44 | + "recall-seg": [], |
| 45 | + } |
| 46 | + |
| 47 | + for path in paths: |
| 48 | + with h5py.File(path, "r") as f: |
| 49 | + labels = f["labels/compartments"][:] |
| 50 | + boundary_labels = find_boundaries(labels).astype("bool") |
| 51 | + |
| 52 | + fname = os.path.basename(path) |
| 53 | + out = os.path.join(output_folder, fname) |
| 54 | + with h5py.File(out, "a") as f: |
| 55 | + seg, pred = f["seg"][:], f["pred"][:] |
| 56 | + |
| 57 | + recall_pred = binary_recall(boundary_labels, pred > 0.5) |
| 58 | + recall_seg = matching(seg, labels)["recall"] |
| 59 | + |
| 60 | + results["name"].append(fname) |
| 61 | + results["recall-pred"].append(recall_pred) |
| 62 | + results["recall-seg"].append(recall_seg) |
| 63 | + |
| 64 | + results = pd.DataFrame(results) |
| 65 | + print(results) |
| 66 | + print(results[["recall-pred", "recall-seg"]].mean()) |
| 67 | + |
| 68 | + |
| 69 | +def check_predictions(paths): |
| 70 | + import napari |
| 71 | + output_folder = "./compartment_eval" |
| 72 | + |
| 73 | + for path in paths: |
| 74 | + with h5py.File(path, "r") as f: |
| 75 | + raw = f["raw"][:] |
| 76 | + labels = f["labels/compartments"][:] |
| 77 | + boundary_labels = find_boundaries(labels) |
| 78 | + |
| 79 | + fname = os.path.basename(path) |
| 80 | + out = os.path.join(output_folder, fname) |
| 81 | + with h5py.File(out, "a") as f: |
| 82 | + seg, pred = f["seg"][:], f["pred"][:] |
| 83 | + |
| 84 | + v = napari.Viewer() |
| 85 | + v.add_image(raw) |
| 86 | + v.add_image(pred) |
| 87 | + v.add_labels(labels) |
| 88 | + v.add_labels(boundary_labels) |
| 89 | + v.add_labels(seg) |
| 90 | + napari.run() |
| 91 | + |
| 92 | + |
| 93 | +def main(): |
| 94 | + paths = get_paths_3d() |
| 95 | + _, val_paths = train_test_split(paths, test_size=0.10, random_state=42) |
| 96 | + |
| 97 | + # run_prediction(val_paths) |
| 98 | + run_evaluation(val_paths) |
| 99 | + # check_predictions(val_paths) |
| 100 | + |
| 101 | + |
| 102 | +if __name__ == "__main__": |
| 103 | + main() |
0 commit comments