Skip to content

Commit 64d56e5

Browse files
Update segmentation functionality and add more inference scripts
1 parent 993a006 commit 64d56e5

File tree

7 files changed

+117
-26
lines changed

7 files changed

+117
-26
lines changed

scripts/cooper/full_reconstruction/assort_az_and_vesicles.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
11
import os
22
from glob import glob
33

4+
import h5py
5+
from tqdm import tqdm
46

5-
INPUT_ROOT = ""
6-
OUTPUT_ROOT = ""
7+
8+
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04Dataset_for_vesicle_eval/model_segmentation" # noqa
9+
OUTPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction"
10+
11+
12+
def assort_az_and_vesicles(in_path, out_path):
    """Copy raw data plus vesicle and AZ segmentations into a new output file.

    Reads the raw volume and the two segmentation datasets from ``in_path``
    and writes them to ``out_path`` under the ``labels/`` group. Existing
    output files are left untouched so the script can be resumed.
    """
    # Skip files that were already converted.
    if os.path.exists(out_path):
        return

    with h5py.File(in_path, "r") as src:
        raw = src["raw"][:]
        vesicles = src["/vesicles/segment_from_combined_vesicles"][:]
        az = src["/AZ/segment_from_AZmodel_v3"][:]

    out_folder = os.path.split(out_path)[0]
    os.makedirs(out_folder, exist_ok=True)
    with h5py.File(out_path, "a") as dst:
        # Same datasets, same order and compression as before.
        for name, data in (
            ("raw", raw),
            ("labels/vesicles", vesicles),
            ("labels/active_zone", az),
        ):
            dst.create_dataset(name, data=data, compression="gzip")
726

827

928
def main():
    """Convert every h5 file under INPUT_ROOT, mirroring the layout in OUTPUT_ROOT."""
    input_files = sorted(glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True))
    for in_path in tqdm(input_files):
        relative_name = os.path.relpath(in_path, INPUT_ROOT)
        assort_az_and_vesicles(in_path, os.path.join(OUTPUT_ROOT, relative_name))
1134

1235

1336
if __name__ == "__main__":
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from synaptic_reconstruction.inference.compartments import segment_compartments
6+
from tqdm import tqdm
7+
8+
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction" # noqa
9+
MODEL_PATH = "/user/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/cooper/training/checkpoints/compartment_model_3d/v2" # noqa
10+
11+
12+
def label_transform_3d():
    """Placeholder for a 3d label transform; currently a no-op that returns None."""
    pass
14+
15+
16+
def run_seg(path, scale=(0.25, 0.25, 0.25)):
    """Run compartment segmentation for one h5 volume and store the result in-place.

    Args:
        path: Path to an h5 file with a "raw" dataset. The segmentation is
            written back to the same file under "labels/compartments".
        scale: Rescaling factor applied by the segmentation function before
            inference; the default matches the previous hard-coded value.
    """
    with h5py.File(path, "r") as f:
        # Skip volumes that were already segmented, so the script is resumable.
        if "labels/compartments" in f:
            return
        raw = f["raw"][:]

    seg = segment_compartments(raw, model_path=MODEL_PATH, scale=scale, verbose=False)
    with h5py.File(path, "a") as f:
        f.create_dataset("labels/compartments", data=seg, compression="gzip")
26+
27+
28+
def main():
    """Segment compartments for every h5 volume below ROOT."""
    volume_paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
    for volume_path in tqdm(volume_paths):
        run_seg(volume_path)
32+
33+
34+
# Guard the entry point so importing this module does not trigger a full
# segmentation run (previously main() was called unconditionally at import).
if __name__ == "__main__":
    main()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
from synaptic_reconstruction.inference.mitochondria import segment_mitochondria
6+
from tqdm import tqdm
7+
8+
ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/04_full_reconstruction" # noqa
9+
MODEL_PATH = "/scratch-grete/projects/nim00007/models/exports_for_cooper/mito_model_s2.pt" # noqa
10+
11+
12+
def run_seg(path, scale=(0.5, 0.5, 0.5)):
    """Run mitochondria segmentation for one h5 volume and store the result in-place.

    Args:
        path: Path to an h5 file with a "raw" dataset. The segmentation is
            written back to the same file under "labels/mitochondria".
        scale: Rescaling factor applied by the segmentation function before
            inference; the default matches the previous hard-coded value.
    """
    with h5py.File(path, "r") as f:
        # Skip volumes that were already segmented, so the script is resumable.
        if "labels/mitochondria" in f:
            return
        raw = f["raw"][:]

    seg = segment_mitochondria(raw, model_path=MODEL_PATH, scale=scale, verbose=False)
    with h5py.File(path, "a") as f:
        f.create_dataset("labels/mitochondria", data=seg, compression="gzip")
22+
23+
24+
def main():
    """Segment mitochondria for every h5 volume below ROOT."""
    volume_paths = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
    for volume_path in tqdm(volume_paths):
        run_seg(volume_path)
28+
29+
30+
# Guard the entry point so importing this module does not trigger a full
# segmentation run (previously main() was called unconditionally at import).
if __name__ == "__main__":
    main()

scripts/cooper/full_reconstruction/visualize_results.py

Whitespace-only changes.

synaptic_reconstruction/inference/compartments.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from skimage.segmentation import watershed
1414
from skimage.morphology import remove_small_holes
1515

16-
from synaptic_reconstruction.inference.util import get_prediction, _Scaler
16+
from synaptic_reconstruction.inference.util import get_prediction, _Scaler, _postprocess_seg_3d
1717

1818

1919
def _segment_compartments_2d(
@@ -112,27 +112,6 @@ def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
112112
return segmentation
113113

114114

115-
def _postprocess_seg_3d(seg):
116-
# Structure element for 2d dilation in 3d.
117-
structure_element = np.ones((3, 3)) # 3x3 structure for XY plane
118-
structure_3d = np.zeros((1, 3, 3)) # Only applied in the XY plane
119-
structure_3d[0] = structure_element
120-
121-
props = regionprops(seg)
122-
for prop in props:
123-
# Get bounding box and mask.
124-
bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
125-
mask = seg[bb] == prop.label
126-
127-
# Fill small holes and apply closing.
128-
mask = remove_small_holes(mask, area_threshold=1000)
129-
mask = np.logical_or(binary_closing(mask, iterations=4), mask)
130-
mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
131-
seg[bb][mask] = prop.label
132-
133-
return seg
134-
135-
136115
def _segment_compartments_3d(
137116
prediction,
138117
boundary_threshold=0.4,

synaptic_reconstruction/inference/mitochondria.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import torch
77

8-
from synaptic_reconstruction.inference.util import apply_size_filter, get_prediction, _Scaler
8+
from synaptic_reconstruction.inference.util import apply_size_filter, get_prediction, _Scaler, _postprocess_seg_3d
99

1010

1111
def _run_segmentation(
@@ -37,6 +37,7 @@ def _run_segmentation(
3737
print("Compute watershed in", time.time() - t0, "s")
3838

3939
seg = apply_size_filter(seg, min_size, verbose=verbose, block_shape=block_shape)
40+
seg = _postprocess_seg_3d(seg)
4041
return seg
4142

4243

synaptic_reconstruction/inference/util.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import xarray
1818

1919
from elf.io import open_file
20+
from scipy.ndimage import binary_closing
21+
from skimage.measure import regionprops
22+
from skimage.morphology import remove_small_holes
2023
from skimage.transform import rescale, resize
2124
from torch_em.util.prediction import predict_with_halo
2225
from tqdm import tqdm
@@ -465,3 +468,24 @@ def apply_size_filter(
465468
if verbose:
466469
print("Size filter in", time.time() - t0, "s")
467470
return segmentation
471+
472+
473+
def _postprocess_seg_3d(seg):
    """Fill small holes and close gaps in each segment of a segmentation volume.

    Args:
        seg: Integer label image; modified in-place and also returned.

    Returns:
        The post-processed segmentation.
    """
    # Structure element for 2d dilation in 3d: the closing is applied per
    # XY plane only, hence z-extent 1.
    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
    structure_3d[0] = structure_element

    # regionprops' bbox is (min_0, ..., min_{d-1}, max_0, ..., max_{d-1}),
    # so it must be split at ndim. The previous code split at 2
    # (bbox[:2] / bbox[2:]), which produces malformed slices for 3d input.
    ndim = seg.ndim
    props = regionprops(seg)
    for prop in props:
        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:ndim], prop.bbox[ndim:]))
        mask = seg[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=1000)
        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
        mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
        seg[bb][mask] = prop.label

    return seg

0 commit comments

Comments
 (0)