Commit 997d804

Update ground-truth annotation scripts
1 parent c590869 commit 997d804

5 files changed: +294 −33 lines

scripts/cooper/ground_truth/compartments/preprocess.py

Lines changed: 51 additions & 3 deletions
@@ -6,6 +6,7 @@
 from skimage.transform import rescale

 ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/original_imod_data/20240909_cp_datatransfer"  # noqa
+ROOT_CRYO = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/fernandez-busnadiego/vesicle_gt/v1"  # noqa


 def preprocess_tomogram(dataset, tomogram):
@@ -44,6 +45,44 @@ def preprocess_tomogram(dataset, tomogram):
     )


+def preprocess_cryo_tomogram(fname):
+    scale = (0.5, 0.5, 0.5)
+
+    dataset = "cryo"
+    output_root = f"./output/{dataset}"
+    output_tomos = os.path.join(output_root, "tomograms")
+    output_embed = os.path.join(output_root, "embeddings")
+    os.makedirs(output_tomos, exist_ok=True)
+    os.makedirs(output_embed, exist_ok=True)
+
+    tomogram = os.path.join(ROOT_CRYO, f"{fname}.h5")
+
+    input_path = os.path.join(output_tomos, f"{fname}.h5")
+    output_path = os.path.join(output_embed, f"{fname}.zarr")
+    if os.path.exists(output_path):
+        return
+
+    # 'tomogram' already holds the full path to the input file.
+    with open_file(tomogram, "r") as f:
+        tomo = f["raw"][:]
+
+    print("Resizing tomogram ...")
+    tomo = rescale(tomo, scale, preserve_range=True).astype(tomo.dtype)
+
+    with open_file(input_path, "a") as f:
+        f.create_dataset("data", data=tomo, compression="gzip")
+
+    print("Precompute state ...")
+    precompute_state(
+        input_path=input_path,
+        output_path=output_path,
+        model_type="vit_b",
+        key="data",
+        checkpoint_path="./checkpoints/compartment_model/best.pt",
+        ndim=3,
+    )
+
+
 def preprocess_05():
     dataset = "05_stem750_sv_training"
     tomograms = sorted(glob(os.path.join(ROOT, dataset, "*.mrc")))
@@ -65,10 +104,19 @@ def preprocess_09():
         preprocess_tomogram(dataset, os.path.basename(tomo))


+def preprocess_cryo():
+    fname = "vesicles-33K-L1"
+    preprocess_cryo_tomogram(fname)
+
+    fname = "vesicles-64K-LAM12"
+    preprocess_cryo_tomogram(fname)
+
+
 def main():
-    preprocess_05()
-    preprocess_06()
-    preprocess_09()
+    # preprocess_05()
+    # preprocess_06()
+    # preprocess_09()
+    preprocess_cryo()


 if __name__ == "__main__":
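
Note: preprocess_cryo_tomogram halves the cryo tomograms along every axis before the embeddings are precomputed. A minimal sketch of what skimage.transform.rescale does with scale=(0.5, 0.5, 0.5) (toy shape, not from the commit):

import numpy as np
from skimage.transform import rescale

# Toy volume standing in for a tomogram (hypothetical shape).
vol = np.random.rand(100, 200, 200).astype("float32")

# preserve_range keeps the original intensity range instead of
# rescaling to [0, 1]; each axis is halved.
out = rescale(vol, (0.5, 0.5, 0.5), preserve_range=True).astype(vol.dtype)
print(out.shape)  # (50, 100, 100)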
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+import os
+from glob import glob
+
+import imageio.v3 as imageio
+
+from elf.io import open_file
+from skimage.transform import rescale
+from tqdm import tqdm
+
+ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/original_imod_data/20240909_cp_datatransfer"  # noqa
+ROOT_CRYO = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/fernandez-busnadiego/vesicle_gt/v1"  # noqa
+OUTPUT_IMAGES = "./output/images"
+
+
+def process_tomogram(tomo_path, scale, tomo_key="data"):
+    with open_file(tomo_path, "r") as f:
+        tomo = f[tomo_key][:]
+
+    os.makedirs(OUTPUT_IMAGES, exist_ok=True)
+    offset = len(glob(os.path.join(OUTPUT_IMAGES, "*.tif")))
+
+    print("Resizing tomogram ...")
+    tomo = rescale(tomo, scale, preserve_range=True).astype(tomo.dtype)
+
+    z_max = tomo.shape[0]
+    slices = [z_max // 2, z_max // 4, 3 * z_max // 4]
+
+    for i, z in enumerate(slices):
+        im = tomo[z]
+        idx = i + offset
+        out_path = os.path.join(OUTPUT_IMAGES, f"image_{idx:05}.tif")
+        imageio.imwrite(out_path, im, compression="zlib")
+
+
+def preprocess_05():
+    scale = (0.25, 0.25, 0.25)
+    dataset = "05_stem750_sv_training"
+    tomograms = sorted(glob(os.path.join(ROOT, dataset, "*.mrc")))
+    for tomo in tqdm(tomograms):
+        process_tomogram(tomo, scale)
+
+
+def preprocess_06():
+    scale = (0.25, 0.25, 0.25)
+    dataset = "06_hoi_wt_stem750_fm"
+    tomograms = sorted(glob(os.path.join(ROOT, dataset, "*.mrc")))
+    for tomo in tqdm(tomograms):
+        process_tomogram(tomo, scale)
+
+
+def preprocess_09():
+    scale = (0.25, 0.25, 0.25)
+    dataset = "09_stem750_66k"
+    tomograms = sorted(glob(os.path.join(ROOT, dataset, "*.mrc")))
+    for tomo in tqdm(tomograms):
+        process_tomogram(tomo, scale)
+
+
+def preprocess_cryo():
+    scale = (0.5, 0.5, 0.5)
+    tomograms = sorted(glob(os.path.join(ROOT_CRYO, "*.h5")))
+    for tomo in tqdm(tomograms):
+        process_tomogram(tomo, scale, tomo_key="raw")
+
+
+def precompute_state():
+    from micro_sam.util import get_sam_model
+    from micro_sam.precompute_state import _precompute_state_for_files
+
+    images = sorted(glob(os.path.join(OUTPUT_IMAGES, "*.tif")))
+    embedding_path = "./output/embeddings"
+
+    predictor = get_sam_model(model_type="vit_b", checkpoint_path="./checkpoints/compartment_model/best.pt")
+    precompute_amg_state = False
+    decoder = None
+
+    _precompute_state_for_files(
+        predictor, images, embedding_path, ndim=2, tile_shape=None, halo=None,
+        precompute_amg_state=precompute_amg_state, decoder=decoder,
+    )
+
+
+def main():
+    # preprocess_05()
+    # preprocess_06()
+    # preprocess_09()
+    # preprocess_cryo()
+    precompute_state()
+
+
+if __name__ == "__main__":
+    main()
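
The offset counter in process_tomogram numbers the exported slices continuously across datasets: each call counts the TIFFs already in OUTPUT_IMAGES and continues from there, so repeated runs append instead of overwriting. The naming logic in isolation (hypothetical directory contents):

import os
from glob import glob

OUTPUT_IMAGES = "./output/images"  # same layout as the script above
os.makedirs(OUTPUT_IMAGES, exist_ok=True)

# With three slices already exported (image_00000.tif .. image_00002.tif),
# offset is 3 and the next files become image_00003.tif, image_00004.tif, ...
offset = len(glob(os.path.join(OUTPUT_IMAGES, "*.tif")))
for i in range(3):
    print(os.path.join(OUTPUT_IMAGES, f"image_{i + offset:05}.tif"))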
Lines changed: 87 additions & 30 deletions
@@ -1,35 +1,92 @@
+import os
+from glob import glob
+
 import numpy as np
+
 from micro_sam.training import train_sam, default_sam_dataset
+from sklearn.model_selection import train_test_split
 from torch_em.data.sampler import MinInstanceSampler
 from torch_em.segmentation import get_data_loader

-data_path = "./segmentation.h5"
-
-with_segmentation_decoder = False
-patch_shape = [1, 462, 462]
-z_split = 400
-
-train_ds = default_sam_dataset(
-    raw_paths=data_path, raw_key="raw_downscaled",
-    label_paths=data_path, label_key="segmentation/compartments",
-    patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
-    sampler=MinInstanceSampler(2), rois=np.s_[z_split:, :, :],
-    n_samples=200,
-)
-train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)
-
-val_ds = default_sam_dataset(
-    raw_paths=data_path, raw_key="raw_downscaled",
-    label_paths=data_path, label_key="segmentation/compartments",
-    patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
-    sampler=MinInstanceSampler(2), rois=np.s_[:z_split, :, :],
-    is_train=False, n_samples=25,
-)
-val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
-
-train_sam(
-    name="compartment_model", model_type="vit_b",
-    train_loader=train_loader, val_loader=val_loader,
-    n_epochs=100, n_objects_per_batch=10,
-    with_segmentation_decoder=with_segmentation_decoder,
-)
+
+
+def train_v1():
+    data_path = "./segmentation.h5"
+
+    with_segmentation_decoder = False
+    patch_shape = [1, 462, 462]
+    z_split = 400
+
+    train_ds = default_sam_dataset(
+        raw_paths=data_path, raw_key="raw_downscaled",
+        label_paths=data_path, label_key="segmentation/compartments",
+        patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
+        sampler=MinInstanceSampler(2), rois=np.s_[z_split:, :, :],
+        n_samples=200,
+    )
+    train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)
+
+    val_ds = default_sam_dataset(
+        raw_paths=data_path, raw_key="raw_downscaled",
+        label_paths=data_path, label_key="segmentation/compartments",
+        patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
+        sampler=MinInstanceSampler(2), rois=np.s_[:z_split, :, :],
+        is_train=False, n_samples=25,
+    )
+    val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
+
+    train_sam(
+        name="compartment_model", model_type="vit_b",
+        train_loader=train_loader, val_loader=val_loader,
+        n_epochs=100, n_objects_per_batch=10,
+        with_segmentation_decoder=with_segmentation_decoder,
+    )
+
+
+def normalize_trafo(raw):
+    raw = raw.astype("float32")
+    raw -= raw.min()
+    raw /= raw.max()
+    raw *= 255
+    return raw
+
+
+def train_v2():
+    data_root = "./output/postprocessed_annotations"
+    paths = glob(os.path.join(data_root, "*.h5"))
+    train_paths, val_paths = train_test_split(paths, test_size=0.1, random_state=42)
+
+    with_segmentation_decoder = True
+    patch_shape = (462, 462)
+
+    train_ds = default_sam_dataset(
+        raw_paths=train_paths, raw_key="data",
+        label_paths=train_paths, label_key="labels/compartments",
+        patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
+        sampler=MinInstanceSampler(2), n_samples=250,
+        raw_transform=normalize_trafo,
+    )
+    train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)
+
+    val_ds = default_sam_dataset(
+        raw_paths=val_paths, raw_key="data",
+        label_paths=val_paths, label_key="labels/compartments",
+        patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder,
+        sampler=MinInstanceSampler(2), is_train=False, n_samples=25,
+        raw_transform=normalize_trafo,
+    )
+    val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
+
+    train_sam(
+        name="compartment_model_v2", model_type="vit_b",
+        train_loader=train_loader, val_loader=val_loader,
+        n_epochs=100, n_objects_per_batch=10,
+        with_segmentation_decoder=with_segmentation_decoder,
+    )
+
+
+def main():
+    train_v2()
+
+
+if __name__ == "__main__":
+    main()
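
normalize_trafo performs a min-max normalization and stretches the result to [0, 255], presumably to match the 8-bit intensity range SAM-style models are typically trained on. A worked toy example of the transform (assumes non-constant input, since a constant array would divide by zero):

import numpy as np

def normalize_trafo(raw):
    raw = raw.astype("float32")
    raw -= raw.min()
    raw /= raw.max()
    raw *= 255
    return raw

x = np.array([100, 150, 300], dtype="int16")
# (x - 100) / 200 * 255 -> [0., 63.75, 255.]
print(normalize_trafo(x))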
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import os
+from glob import glob
+
+import h5py
+import numpy as np
+
+from tqdm import tqdm
+
+
+ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2"
+SKIP_PREFIX = ("06", "08", "09")
+
+
+def main():
+    n_tomograms = {}
+    n_vesicles_imod = {}
+    n_vesicles_auto = {}
+    n_vesicles_total = {}
+
+    datasets = sorted(glob(os.path.join(ROOT, "*")))
+
+    for ds in tqdm(datasets):
+        ds_name = os.path.basename(ds)
+        if ds_name.startswith(SKIP_PREFIX):
+            continue
+        tomograms = glob(os.path.join(ds, "*.h5"))
+
+        n_ves_imod, n_ves_auto = 0, 0
+        for tomo in tomograms:
+            with h5py.File(tomo, "r") as f:
+                ves_imod = f["/labels/vesicles/imod"][:]
+                ves_auto = f["/labels/vesicles/additional_vesicles"][:]
+            n_ves_imod += (len(np.unique(ves_imod)) - 1)
+            n_ves_auto += (len(np.unique(ves_auto)) - 1)
+
+        n_tomograms[ds_name] = len(tomograms)
+        n_vesicles_imod[ds_name] = n_ves_imod
+        n_vesicles_auto[ds_name] = n_ves_auto
+        n_vesicles_total[ds_name] = n_ves_imod + n_ves_auto
+
+    print("Total number of tomograms:")
+    print(sum(n_tomograms.values()))
+
+    print("Total number of vesicles:")
+    print(sum(n_vesicles_total.values()))
+
+    # TODO analyze the number of vesicles from IMOD and auto annotation further for the methods tile
+
+
+if __name__ == "__main__":
+    main()
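
For the TODO at the end of the script, a possible starting point is to break the totals down per dataset into IMOD versus auto-annotated fractions. A hedged sketch using the dictionaries the script already builds (hypothetical continuation, not part of the commit):

for ds_name in sorted(n_vesicles_total):
    total = n_vesicles_total[ds_name]
    frac_imod = n_vesicles_imod[ds_name] / total if total else 0.0
    print(
        f"{ds_name}: {n_tomograms[ds_name]} tomograms, {total} vesicles "
        f"({frac_imod:.1%} IMOD, {1 - frac_imod:.1%} auto)"
    )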
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import os
+from glob import glob
+
+ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/inner_ear_data"
+
+
+def main():
+    tomograms = glob(os.path.join(ROOT, "**/*.h5"), recursive=True)
+    print("Number of tomograms:")
+    print(len(tomograms))
+
+
+main()
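
The double-star pattern with recursive=True matches .h5 files at any depth below ROOT. A one-liner illustration on a hypothetical layout:

from glob import glob
# Matches e.g. ROOT/a.h5, ROOT/session1/b.h5, ROOT/session1/day2/c.h5.
tomograms = glob("/path/to/inner_ear_data/**/*.h5", recursive=True)
print(len(tomograms))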
