Skip to content

Commit feadf6f

Browse files
committed
2 parents be7ba41 + fc7a3a1 commit feadf6f

File tree

20 files changed

+1192
-49
lines changed

20 files changed

+1192
-49
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# CryoVesNet
2+
3+
Scripts to run CryoVesNet on our data. See https://github.com/Zuber-group/CryoVesNet for details.
4+
5+
The code is currently not working due to this issue: https://github.com/Zuber-group/CryoVesNet/issues/6
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import os
2+
import tempfile
3+
from glob import glob
4+
from pathlib import Path
5+
from shutil import copyfile
6+
7+
import h5py
8+
import mrcfile
9+
import numpy as np
10+
11+
import cryovesnet
12+
13+
14+
# additional parameters?
15+
def _segment_vesicles(directory):
16+
pl = cryovesnet.Pipeline(directory, pattern="*.mrc")
17+
pl.setup_cryovesnet_dir(make_masks=False)
18+
19+
pl.run_deep()
20+
pl.rescale()
21+
pl.label_vesicles(within_segmentation_region=False)
22+
pl.label_vesicles_adaptive(separating=True)
23+
pl.make_spheres()
24+
pl.repair_spheres()
25+
26+
27+
def _prepare_input(path, output_folder, input_key, resolution, rel_folder=None):
28+
fname = Path(path).stem
29+
if rel_folder is None:
30+
sub_folder = os.path.join(output_folder, fname)
31+
else:
32+
sub_folder = os.path.join(output_folder, rel_folder, fname)
33+
34+
os.makedirs(sub_folder, exist_ok=True)
35+
out_path = os.path.join(sub_folder, f"{fname}.mrc")
36+
37+
if path.endswith(".h5"):
38+
assert resolution is not None
39+
with h5py.File(path, "r") as f:
40+
vol = f[input_key][:]
41+
mrcfile.new(out_path, data=vol)
42+
43+
# Copy the mrc file.
44+
elif path.endswith(".mrc"):
45+
copyfile(path, out_path)
46+
47+
# Update the resolution if it was given.
48+
if resolution is not None:
49+
with mrcfile.open(out_path, mode="r+") as f:
50+
f.voxel_size = resolution
51+
f.update_header_from_data()
52+
53+
return out_path, sub_folder
54+
55+
56+
def _process_output(tmp, tmp_file, output_folder, output_key, rel_folder=None, mask_file=None, mask_key=None):
57+
fname = Path(tmp_file).stem
58+
seg_path = os.path.join(tmp, "cryovesnet", f"{fname}_convex_labels.mrc")
59+
with mrcfile.open(seg_path, "r") as f:
60+
seg = f.data[:]
61+
62+
if mask_file is not None:
63+
with h5py.File(mask_file, "r") as f:
64+
mask = f[mask_key][:].astype("bool")
65+
# We need to make this copy, otherwise seg is assignment only.
66+
seg = np.asarray(seg).copy()
67+
seg[~mask] = 0
68+
69+
if rel_folder is None:
70+
this_output_folder = output_folder
71+
else:
72+
this_output_folder = os.path.join(output_folder, rel_folder)
73+
os.makedirs(this_output_folder, exist_ok=True)
74+
75+
out_path = os.path.join(this_output_folder, f"{fname}.h5")
76+
with h5py.File(out_path, "a") as f:
77+
f.create_dataset(output_key, data=seg, compression="gzip")
78+
79+
80+
def apply_cryo_vesnet(
81+
input_folder, output_folder, pattern, input_key,
82+
resolution=None, output_key="prediction/vesicles/cryovesnet",
83+
mask_folder=None, mask_key=None, nested=False,
84+
):
85+
os.makedirs(output_folder, exist_ok=True)
86+
if nested:
87+
files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
88+
else:
89+
files = sorted(glob(os.path.join(input_folder, pattern)))
90+
91+
if mask_folder is None:
92+
mask_files = None
93+
else:
94+
assert mask_key is not None
95+
if nested:
96+
mask_files = sorted(glob(os.path.join(mask_folder, "**", pattern), recursive=True))
97+
else:
98+
mask_files = sorted(glob(os.path.join(mask_folder, pattern)))
99+
assert len(mask_files) == len(files)
100+
101+
with tempfile.TemporaryDirectory() as tmp:
102+
103+
for i, file in enumerate(files):
104+
105+
# Get the resolution info for this file.
106+
if resolution is None:
107+
res = None
108+
else:
109+
fname = Path(file).stem
110+
res = resolution[fname] if isinstance(resolution, dict) else resolution
111+
112+
# Prepare the input files by copying them over or resaving them (if h5).
113+
rel_folder = os.path.split(os.path.relpath(file, input_folder))[0] if nested else None
114+
tmp_file, sub_folder = _prepare_input(file, tmp, input_key, res, rel_folder=rel_folder)
115+
116+
# Segment the vesicles in the file.
117+
_segment_vesicles(sub_folder)
118+
119+
# Write the output file.
120+
mask_file = None if mask_files is None else mask_files[i]
121+
_process_output(
122+
sub_folder, tmp_file, output_folder, output_key, rel_folder, mask_file=mask_file, mask_key=mask_key
123+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from common import apply_cryo_vesnet
2+
3+
4+
def main():
5+
input_folder = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets" # noqa
6+
output_folder = "./cryo-vesnet-test2"
7+
8+
# TODO determine the correct resolution (in angstrom) for each dataset
9+
resolution = (10, 10, 10)
10+
apply_cryo_vesnet(input_folder, output_folder, pattern="*.h5", input_key="raw", resolution=resolution, nested=True)
11+
12+
13+
if __name__ == "__main__":
14+
main()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from common import apply_cryo_vesnet
2+
3+
4+
def main():
5+
input_folder = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/fernandez-busnadiego/vesicle_gt/v2" # noqa
6+
output_folder = "./cryo-vesnet-test"
7+
8+
# Resolution in Angstrom in XYZ
9+
# The two tomograms have a different resolution.
10+
resolution = {
11+
"vesicles-33K-L1": (14.6, 14.6, 14.6),
12+
"vesicles-64K-LAM12": (7.56, 7.56, 7.56),
13+
}
14+
apply_cryo_vesnet(
15+
input_folder, output_folder,
16+
pattern="*.h5", input_key="raw",
17+
mask_folder=input_folder, mask_key="/labels/mask",
18+
resolution=resolution
19+
)
20+
21+
22+
if __name__ == "__main__":
23+
main()

scripts/cooper/export_vesicles_to_imod.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66

77
def export_vesicles_to_imod(args):
8-
export_function = partial(write_segmentation_to_imod_as_points, min_radius=args.min_radius, radius_factor=args.increase_radius)
8+
export_function = partial(
9+
write_segmentation_to_imod_as_points, min_radius=args.min_radius, radius_factor=args.increase_radius
10+
)
911
export_helper(args.input_path, args.segmentation_path, args.output_path, export_function, force=args.force)
1012

1113

@@ -32,7 +34,7 @@ def main():
3234
help="Whether to over-write already present export results."
3335
)
3436
parser.add_argument(
35-
"--increase_radius", type=float, default=1.5,
37+
"--increase_radius", type=float, default=1.3,
3638
help="The factor to increase the radius of the exported vesicles.",
3739
)
3840
args = parser.parse_args()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os
2+
from glob import glob
3+
from pathlib import Path
4+
5+
import imageio.v3 as imageio
6+
import h5py
7+
import numpy as np
8+
9+
from scipy.ndimage import binary_erosion, binary_closing
10+
from skimage.measure import label, regionprops
11+
from skimage.morphology import remove_small_holes
12+
from skimage.segmentation import watershed
13+
from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter
14+
from tqdm import tqdm
15+
16+
17+
def process_compartment_gt(im_path, ann_path, output_root, view=True, snap_to_bd=False):
18+
output_path = os.path.join(output_root, os.path.basename(im_path))
19+
if os.path.exists(output_path):
20+
return
21+
22+
seg = imageio.imread(ann_path)
23+
24+
with h5py.File(im_path, "r") as f:
25+
tomo = f["data"][:]
26+
27+
if snap_to_bd:
28+
hmap = edge_filter(tomo, sigma=3.0, per_slice=True)
29+
else:
30+
hmap = None
31+
32+
seg_pp = label(seg)
33+
props = regionprops(seg_pp)
34+
35+
# for dilation / eroision
36+
structure_element = np.ones((3, 3)) # 3x3 structure for XY plane
37+
structure_3d = np.zeros((1, 3, 3)) # Only applied in the XY plane
38+
structure_3d[0] = structure_element
39+
40+
# Apply the post-processing for each segment.
41+
min_size = 500
42+
for prop in props:
43+
# 1. size filter
44+
if prop.area < min_size:
45+
seg_pp[seg_pp == prop.label] = 0
46+
continue
47+
48+
# 2. get the box and mask for the current object
49+
bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
50+
mask = seg_pp[bb] == prop.label
51+
52+
# 3. filling smal holes and closing closing
53+
mask = remove_small_holes(mask, area_threshold=500)
54+
mask = np.logical_or(binary_closing(mask, iterations=4), mask)
55+
mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
56+
57+
# 4. snap to boundary
58+
if snap_to_bd:
59+
seeds = binary_erosion(mask, structure=structure_3d, iterations=3).astype("uint8")
60+
bg_seeds = binary_erosion(~mask, structure=structure_3d, iterations=3)
61+
seeds[bg_seeds] = 2
62+
mask = watershed(hmap[bb], markers=seeds) == 1
63+
64+
# 5. write back
65+
seg_pp[bb][mask] = prop.label
66+
67+
if view:
68+
import napari
69+
70+
v = napari.Viewer()
71+
v.add_image(tomo)
72+
if hmap is not None:
73+
v.add_image(hmap)
74+
v.add_labels(seg, visible=False)
75+
v.add_labels(seg_pp)
76+
napari.run()
77+
return
78+
79+
# Cut some border pixels to avoid artifacts.
80+
bb = np.s_[4:-4, 16:-16, 16:-16]
81+
tomo = tomo[bb]
82+
seg_pp = seg_pp[bb]
83+
84+
with h5py.File(output_path, "a") as f:
85+
f.create_dataset("raw", data=tomo, compression="gzip")
86+
f.create_dataset("labels/compartments", data=seg_pp, compression="gzip")
87+
88+
89+
def main():
90+
for ds in ["05_stem750_sv_training", "06_hoi_wt_stem750_fm"]:
91+
annotation_folder = f"output/{ds}/segmentations"
92+
annotations = sorted(glob(os.path.join(annotation_folder, "*.tif")))
93+
94+
output_root = f"output/compartment_gt/v2/{ds}"
95+
os.makedirs(output_root, exist_ok=True)
96+
97+
image_folder = f"output/{ds}/tomograms"
98+
for ann_path in tqdm(annotations):
99+
fname = Path(ann_path).stem
100+
im_path = os.path.join(image_folder, f"{fname}.h5")
101+
assert os.path.exists(im_path)
102+
process_compartment_gt(im_path, ann_path, output_root, view=False)
103+
104+
105+
if __name__ == "__main__":
106+
main()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
from glob import glob
3+
4+
import h5py
5+
import imageio.v3 as imageio
6+
import napari
7+
import numpy as np
8+
9+
from skimage.measure import label
10+
# from skimage.morphology import remove_small_holes
11+
from tqdm import tqdm
12+
13+
14+
def process_labels(labels):
15+
labels = label(labels)
16+
17+
min_size = 75
18+
ids, sizes = np.unique(labels, return_counts=True)
19+
filter_ids = ids[sizes < min_size]
20+
labels[np.isin(labels, filter_ids)] = 0
21+
22+
# labels = remove_small_holes(labels, area_threshold=min_size)
23+
return labels
24+
25+
26+
def postprocess_annotation(im_path, ann_path, output_folder, view=False):
27+
fname = os.path.basename(im_path)
28+
29+
out_path = os.path.join(output_folder, fname.replace(".tif", ".h5"))
30+
if os.path.exists(out_path):
31+
return
32+
33+
labels = imageio.imread(ann_path)
34+
35+
# Skip empty labels.
36+
if labels.max() == 0:
37+
print("Skipping", im_path)
38+
return
39+
40+
image = imageio.imread(im_path)
41+
labels = process_labels(labels)
42+
43+
if view:
44+
v = napari.Viewer()
45+
v.add_image(image)
46+
v.add_labels(labels)
47+
napari.run()
48+
return
49+
50+
with h5py.File(out_path, "a") as f:
51+
f.create_dataset("data", data=image, compression="gzip")
52+
f.create_dataset("labels/compartments", data=labels, compression="gzip")
53+
54+
55+
def postprocess_annotations(view):
56+
images = sorted(glob("output/images/*.tif"))
57+
annotations = sorted(glob("output/annotations/*.tif"))
58+
59+
output_folder = "output/postprocessed_annotations"
60+
os.makedirs(output_folder, exist_ok=True)
61+
for im, ann in tqdm(zip(images, annotations), total=len(images)):
62+
postprocess_annotation(im, ann, output_folder, view=view)
63+
64+
65+
def main():
66+
postprocess_annotations(view=False)
67+
68+
69+
if __name__ == "__main__":
70+
main()

0 commit comments

Comments
 (0)