Skip to content

Commit 475a9b5

Browse files
Merge pull request #40 from computational-cell-analytics/compartment-training
Compartment segmentation and CryoVesNet
2 parents d1b6bba + 4363ce4 commit 475a9b5

File tree

11 files changed

+633
-48
lines changed

11 files changed

+633
-48
lines changed

scripts/baselines/cryo_ves_net/common.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
import tempfile
33
from glob import glob
44
from pathlib import Path
5+
from shutil import copyfile
56

67
import h5py
78
import mrcfile
9+
import numpy as np
10+
811
import cryovesnet
912

1013

@@ -14,51 +17,107 @@ def _segment_vesicles(directory):
1417
pl.setup_cryovesnet_dir(make_masks=False)
1518

1619
pl.run_deep()
20+
pl.rescale()
1721
pl.label_vesicles(within_segmentation_region=False)
1822
pl.label_vesicles_adaptive(separating=True)
1923
pl.make_spheres()
2024
pl.repair_spheres()
2125

2226

23-
def _prepare_input(path, output_folder, input_key, resolution, rel_folder=None):
    """Stage one tomogram as an mrc file inside its own sub-folder of `output_folder`.

    Args:
        path: Path to the input tomogram, either an hdf5 file (".h5") or an mrc file (".mrc").
        output_folder: Root folder in which the per-tomogram sub-folder is created.
        input_key: Name of the hdf5 dataset that holds the volume. Only read for ".h5" inputs.
        resolution: Voxel size that is written to the mrc file. Required for ".h5" inputs;
            for ".mrc" inputs the copied file is only updated if it is given.
        rel_folder: Optional folder (relative to `output_folder`) below which the sub-folder is placed.

    Returns:
        The path to the staged mrc file and the sub-folder that contains it.

    Raises:
        ValueError: If an ".h5" input is passed without a resolution,
            or if the file extension is neither ".h5" nor ".mrc".
    """
    fname = Path(path).stem
    if rel_folder is None:
        sub_folder = os.path.join(output_folder, fname)
    else:
        sub_folder = os.path.join(output_folder, rel_folder, fname)

    os.makedirs(sub_folder, exist_ok=True)
    out_path = os.path.join(sub_folder, f"{fname}.mrc")

    if path.endswith(".h5"):
        if resolution is None:
            raise ValueError(f"A resolution is required to convert the hdf5 file {path} to mrc.")
        with h5py.File(path, "r") as f:
            vol = f[input_key][:]
        # Close the new file right away, so that the data is on disk before it is re-opened below.
        with mrcfile.new(out_path, data=vol):
            pass

    # Copy the mrc file.
    elif path.endswith(".mrc"):
        copyfile(path, out_path)

    # Fail early instead of returning the path to a file that was never created.
    else:
        raise ValueError(f"Unsupported file type for {path}, expected a '.h5' or '.mrc' file.")

    # Update the resolution if it was given.
    if resolution is not None:
        with mrcfile.open(out_path, mode="r+") as f:
            f.voxel_size = resolution
            f.update_header_from_data()

    return out_path, sub_folder
54+
55+
56+
def _process_output(tmp, tmp_file, output_folder, output_key, rel_folder=None, mask_file=None, mask_key=None):
    """Load the segmentation written for one tomogram and save it to an hdf5 file in the output folder.

    The segmentation is read from "cryovesnet/<name>_convex_labels.mrc" inside `tmp`,
    where <name> is the stem of `tmp_file`, and written to "<name>.h5" in the output folder.

    Args:
        tmp: Folder that contains the "cryovesnet" result folder for this tomogram.
        tmp_file: Path to the staged mrc file; only its stem is used.
        output_folder: Root folder for the results.
        output_key: Name of the hdf5 dataset the segmentation is written to.
        rel_folder: Optional folder (relative to `output_folder`) the result file is placed in.
        mask_file: Optional hdf5 file with a mask. Everything outside of the mask is set to zero.
        mask_key: Name of the mask dataset in `mask_file`. Required if `mask_file` is given.

    Raises:
        ValueError: If `mask_file` is given without `mask_key`.
    """
    if mask_file is not None and mask_key is None:
        raise ValueError("mask_key must be given if a mask_file is passed.")

    fname = Path(tmp_file).stem
    seg_path = os.path.join(tmp, "cryovesnet", f"{fname}_convex_labels.mrc")
    with mrcfile.open(seg_path, "r") as f:
        seg = f.data[:]

    if mask_file is not None:
        with h5py.File(mask_file, "r") as f:
            mask = f[mask_key][:].astype("bool")
        # The data loaded from the mrc file cannot be modified in place, so we work on a copy.
        seg = np.asarray(seg).copy()
        seg[~mask] = 0

    if rel_folder is None:
        this_output_folder = output_folder
    else:
        this_output_folder = os.path.join(output_folder, rel_folder)
        os.makedirs(this_output_folder, exist_ok=True)

    out_path = os.path.join(this_output_folder, f"{fname}.h5")
    with h5py.File(out_path, "a") as f:
        # The file is opened in append mode, so the dataset may be left over from a previous run.
        # Replace it instead of failing after the segmentation has already been computed.
        if output_key in f:
            del f[output_key]
        f.create_dataset(output_key, data=seg, compression="gzip")
4078

4179

42-
# TODO support nested
4380
def _find_files(folder, pattern, nested):
    """Return the sorted files in `folder` that match `pattern`, searching sub-folders if `nested`."""
    if nested:
        return sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    return sorted(glob(os.path.join(folder, pattern)))


def apply_cryo_vesnet(
    input_folder, output_folder, pattern, input_key,
    resolution=None, output_key="prediction/vesicles/cryovesnet",
    mask_folder=None, mask_key=None, nested=False,
):
    """Run the CryoVesNet vesicle segmentation for all matching files in a folder.

    Each file is staged in a temporary directory, segmented, and the result
    is written to an hdf5 file in `output_folder`.

    Args:
        input_folder: Folder with the input tomograms.
        output_folder: Folder the segmentation results are written to. It is created if needed.
        pattern: Glob pattern used to select the files in `input_folder` (and in `mask_folder`).
        input_key: Name of the hdf5 dataset holding the tomogram data.
        resolution: Either a single resolution that is used for all files, or a dict that maps
            the file name (without extension) to the resolution of that file. May be None.
        output_key: Name of the hdf5 dataset the segmentation is written to.
        mask_folder: Optional folder with mask files, selected with the same `pattern`.
            The sorted mask files are matched to the sorted input files by position.
        mask_key: Name of the hdf5 dataset holding the mask. Required if `mask_folder` is given.
        nested: Whether to search the folders recursively. In this case the folder structure
            below `input_folder` is reproduced in `output_folder`.

    Raises:
        ValueError: If `mask_folder` is given without `mask_key`,
            or if the number of mask files differs from the number of input files.
    """
    os.makedirs(output_folder, exist_ok=True)
    files = _find_files(input_folder, pattern, nested)

    if mask_folder is None:
        mask_files = None
    else:
        if mask_key is None:
            raise ValueError("mask_key must be given if a mask_folder is passed.")
        mask_files = _find_files(mask_folder, pattern, nested)
        if len(mask_files) != len(files):
            raise ValueError(
                f"Found {len(mask_files)} mask files in {mask_folder}, "
                f"but {len(files)} input files in {input_folder}."
            )

    with tempfile.TemporaryDirectory() as tmp:

        for i, file in enumerate(files):

            # Get the resolution info for this file.
            if resolution is None:
                res = None
            else:
                fname = Path(file).stem
                res = resolution[fname] if isinstance(resolution, dict) else resolution

            # Prepare the input files by copying them over or resaving them (if h5).
            rel_folder = os.path.split(os.path.relpath(file, input_folder))[0] if nested else None
            tmp_file, sub_folder = _prepare_input(file, tmp, input_key, res, rel_folder=rel_folder)

            # Segment the vesicles in the file.
            _segment_vesicles(sub_folder)

            # Write the output file.
            mask_file = None if mask_files is None else mask_files[i]
            _process_output(
                sub_folder, tmp_file, output_folder, output_key, rel_folder, mask_file=mask_file, mask_key=mask_key
            )
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from common import apply_cryo_vesnet
2+
3+
4+
def main():
    """Apply CryoVesNet to all h5 files found recursively in the test set folder."""
    # TODO determine the correct resolution (in angstrom) for each dataset
    apply_cryo_vesnet(
        "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets",  # noqa
        "./cryo-vesnet-test2",
        pattern="*.h5",
        input_key="raw",
        resolution=(10, 10, 10),
        nested=True,
    )
11+
12+
13+
if __name__ == "__main__":
14+
main()

scripts/baselines/cryo_ves_net/segment_cryo.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@ def main():
1111
"vesicles-33K-L1": (14.6, 14.6, 14.6),
1212
"vesicles-64K-LAM12": (7.56, 7.56, 7.56),
1313
}
14-
apply_cryo_vesnet(input_folder, output_folder, pattern="*.h5", input_key="raw", resolution=resolution)
14+
apply_cryo_vesnet(
15+
input_folder, output_folder,
16+
pattern="*.h5", input_key="raw",
17+
mask_folder=input_folder, mask_key="/labels/mask",
18+
resolution=resolution
19+
)
1520

1621

1722
if __name__ == "__main__":
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import os
2+
from glob import glob
3+
from pathlib import Path
4+
5+
import imageio.v3 as imageio
6+
import h5py
7+
import numpy as np
8+
9+
from scipy.ndimage import binary_erosion, binary_closing
10+
from skimage.measure import label, regionprops
11+
from skimage.morphology import remove_small_holes
12+
from skimage.segmentation import watershed
13+
from synaptic_reconstruction.ground_truth.shape_refinement import edge_filter
14+
from tqdm import tqdm
15+
16+
17+
def process_compartment_gt(im_path, ann_path, output_root, view=True, snap_to_bd=False):
    """Post-process a compartment annotation and export it together with the tomogram.

    Nothing is done if the output file already exists. If `view` is True the result
    is only shown in napari and NOT written to disk.

    Args:
        im_path: Path to the hdf5 file with the tomogram, stored in the dataset "data".
        ann_path: Path to the image file with the annotated segmentation.
        output_root: Folder for the output file, which gets the same file name as `im_path`.
        view: Whether to inspect the result in napari instead of saving it.
        snap_to_bd: Whether to refine each object with a watershed on an edge map of the tomogram.
    """
    output_path = os.path.join(output_root, os.path.basename(im_path))
    if os.path.exists(output_path):
        return

    annotation = imageio.imread(ann_path)
    with h5py.File(im_path, "r") as f:
        tomo = f["data"][:]

    # The edge map is only needed for snapping the objects to boundaries.
    hmap = edge_filter(tomo, sigma=3.0, per_slice=True) if snap_to_bd else None

    seg_pp = label(annotation)

    # Structuring element for erosion / closing that only acts within the XY plane.
    structure_3d = np.zeros((1, 3, 3))
    structure_3d[0] = np.ones((3, 3))

    # Apply the post-processing for each segment.
    min_size = 500
    for prop in regionprops(seg_pp):
        # 1. Size filter: remove objects that are too small.
        if prop.area < min_size:
            seg_pp[seg_pp == prop.label] = 0
            continue

        # 2. Get the bounding box and the mask of the current object.
        box = tuple(slice(lo, hi) for lo, hi in zip(prop.bbox[:3], prop.bbox[3:]))
        obj = seg_pp[box] == prop.label

        # 3. Fill small holes and close gaps, first with the default structure and then in-plane.
        obj = remove_small_holes(obj, area_threshold=500)
        obj = np.logical_or(binary_closing(obj, iterations=4), obj)
        obj = np.logical_or(binary_closing(obj, iterations=8, structure=structure_3d), obj)

        # 4. Snap to the boundary: seed 1 is the eroded object, seed 2 the eroded background.
        if snap_to_bd:
            markers = binary_erosion(obj, structure=structure_3d, iterations=3).astype("uint8")
            markers[binary_erosion(~obj, structure=structure_3d, iterations=3)] = 2
            obj = watershed(hmap[box], markers=markers) == 1

        # 5. Write the object back. Indexing with slices yields a view, so this updates seg_pp.
        seg_pp[box][obj] = prop.label

    if view:
        import napari

        viewer = napari.Viewer()
        viewer.add_image(tomo)
        if hmap is not None:
            viewer.add_image(hmap)
        viewer.add_labels(annotation, visible=False)
        viewer.add_labels(seg_pp)
        napari.run()
        return

    # Cut some border pixels to avoid artifacts.
    crop = np.s_[4:-4, 16:-16, 16:-16]
    with h5py.File(output_path, "a") as f:
        f.create_dataset("raw", data=tomo[crop], compression="gzip")
        f.create_dataset("labels/compartments", data=seg_pp[crop], compression="gzip")
87+
88+
89+
def main():
    """Export the post-processed compartment annotations for both annotated datasets."""
    for ds in ("05_stem750_sv_training", "06_hoi_wt_stem750_fm"):
        output_root = f"output/compartment_gt/v2/{ds}"
        os.makedirs(output_root, exist_ok=True)

        image_folder = f"output/{ds}/tomograms"
        annotations = sorted(glob(os.path.join(f"output/{ds}/segmentations", "*.tif")))
        for ann_path in tqdm(annotations):
            # The tomogram has the same name as its annotation, but is stored as h5.
            im_path = os.path.join(image_folder, f"{Path(ann_path).stem}.h5")
            assert os.path.exists(im_path)
            process_compartment_gt(im_path, ann_path, output_root, view=False)
103+
104+
105+
if __name__ == "__main__":
106+
main()

scripts/cooper/ground_truth/compartments/preprocess.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def preprocess_tomogram(dataset, tomogram):
4040
output_path=output_path,
4141
model_type="vit_b",
4242
key="data",
43-
checkpoint_path="./checkpoints/compartment_model/best.pt",
43+
checkpoint_path="./checkpoints/compartment_model_v2/best.pt",
4444
ndim=3,
45+
precompute_amg_state=True,
4546
)
4647

4748

@@ -78,7 +79,7 @@ def preprocess_cryo_tomogram(fname):
7879
output_path=output_path,
7980
model_type="vit_b",
8081
key="data",
81-
checkpoint_path="./checkpoints/compartment_model/best.pt",
82+
checkpoint_path="./checkpoints/compartment_model_v2/best.pt",
8283
ndim=3,
8384
)
8485

@@ -113,10 +114,10 @@ def preprocess_cryo():
113114

114115

115116
def main():
    """Run preprocess_cryo, preprocess_05, preprocess_06 and preprocess_09, in this order."""
    for preprocess in (preprocess_cryo, preprocess_05, preprocess_06, preprocess_09):
        preprocess()
120121

121122

122123
if __name__ == "__main__":

scripts/cooper/ground_truth/compartments/run_annotation.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
2+
from glob import glob
3+
from pathlib import Path
24

35
from elf.io import open_file
46
from micro_sam.sam_annotator import annotator_3d, image_folder_annotator
57

68

79
def run_volume_annotation(ds, name):
8-
checkpoint_path = "./checkpoints/compartment_model/best.pt"
10+
checkpoint_path = "./checkpoints/compartment_model_v2/best.pt"
911

1012
tomogram_path = f"./output/{ds}/tomograms/{name}.h5"
1113
embedding_path = f"./output/{ds}/embeddings/{name}.zarr"
@@ -25,16 +27,40 @@ def run_image_annotation():
2527
)
2628

2729

28-
def main():
29-
run_image_annotation()
30+
def annotate_cryo():
    """Run the volume annotation for the tomogram "vesicles-33K-L1" of the "cryo" dataset."""
    run_volume_annotation("cryo", "vesicles-33K-L1")
34+
35+
36+
def _series_annotation(ds):
    """Run the volume annotation for all tomograms of a dataset that are not segmented yet.

    A tomogram counts as segmented if "./output/<ds>/segmentations/<name>.tif" exists.

    Args:
        ds: Name of the dataset folder below "./output".
    """
    # Sort the tomograms so that they are visited in the same order in every run.
    images = sorted(glob(f"./output/{ds}/tomograms/*.h5"))
    for image in images:
        name = Path(image).stem
        seg_path = f"./output/{ds}/segmentations/{name}.tif"
        # Check for an existing result first, so that only tomograms
        # that are actually annotated are reported as being run.
        if os.path.exists(seg_path):
            print("Skipping", ds, name, "because it is already segmented.")
            continue
        print("Run segmentation for:", ds, name)
        run_volume_annotation(ds, name)
3047

31-
# ds = "09_stem750_66k"
32-
# name = "36859_J1_66K_TS_PS_05_rec_2kb1dawbp_crop"
3348

34-
# ds = "cryo"
35-
# name = "vesicles-64K-LAM12"
49+
def annotate_05():
    """Annotate all tomograms of the dataset "05_stem750_sv_training" that are still missing."""
    _series_annotation("05_stem750_sv_training")
3652

37-
# run_annotation(ds, name)
53+
54+
def annotate_06():
    """Annotate all tomograms of the dataset "06_hoi_wt_stem750_fm" that are still missing."""
    _series_annotation("06_hoi_wt_stem750_fm")
57+
58+
59+
def main():
    """Entry point: run the annotation step that is currently enabled."""
    # Alternative steps, enable the one that is needed:
    # run_image_annotation()
    # annotate_cryo()
    # annotate_05()
    annotate_06()
3864

3965

4066
if __name__ == "__main__":

0 commit comments

Comments
 (0)