
Commit aad578c

Update compartment model training
Parent: 6011472

File tree: 4 files changed, +56 -27 lines changed


scripts/cooper/ground_truth/compartments/run_prediction_04.py

Lines changed: 24 additions & 9 deletions

@@ -8,8 +8,13 @@
 from synaptic_reconstruction.inference.compartments import segment_compartments

 INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval"  # noqa
-MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_models/compartment_model_3d.pt"
-OUTPUT = "./predictions"
+# MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_models/compartment_model_3d.pt"
+MODEL_PATH = "/user/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/cooper/training/checkpoints/compartment_model_3d/v2"  # noqa
+OUTPUT = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_predictions"
+
+
+def label_transform_3d():
+    pass


 def segment_volume(input_path, model_path):
@@ -20,29 +25,39 @@ def segment_volume(input_path, model_path):
     scaler = _Scaler(scale, verbose=False)
     raw = scaler.scale_input(raw)

-    n_slices_exclude = 4
-    seg = segment_compartments(raw, model_path, verbose=False, n_slices_exclude=n_slices_exclude)
-    raw, seg = raw[n_slices_exclude:-n_slices_exclude], seg[n_slices_exclude:-n_slices_exclude]
+    n_slices_exclude = 2
+    seg, pred = segment_compartments(
+        raw, model_path, verbose=False, n_slices_exclude=n_slices_exclude, return_predictions=True
+    )
+    # raw, seg = raw[n_slices_exclude:-n_slices_exclude], seg[n_slices_exclude:-n_slices_exclude]

-    return raw, seg
+    return raw, seg, pred


 def main():
     inputs = sorted(glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True))
-    for input_path in tqdm(inputs):
+    inputs = [inp for inp in inputs if "cropped_for_2D" not in inp]
+
+    for input_path in tqdm(inputs, desc="Run prediction for 04."):
         ds_name, fname = os.path.split(input_path)
         ds_name = os.path.split(ds_name)[1]
-        output_folder = os.path.join(OUTPUT, ds_name)
+        output_folder = os.path.join(OUTPUT, "segmentation", ds_name)
         output_path = os.path.join(output_folder, fname)

         if os.path.exists(output_path):
             continue

-        raw, seg = segment_volume(input_path, MODEL_PATH)
+        pred_folder = os.path.join(OUTPUT, "prediction", ds_name)
+        os.makedirs(pred_folder, exist_ok=True)
+        pred_path = os.path.join(pred_folder, fname)
+
+        raw, seg, pred = segment_volume(input_path, MODEL_PATH)
         os.makedirs(output_folder, exist_ok=True)
         with h5py.File(output_path, "a") as f:
             f.create_dataset("raw", data=raw, compression="gzip")
             f.create_dataset("labels/compartments", data=seg, compression="gzip")
+        with h5py.File(pred_path, "a") as f:
+            f.create_dataset("prediction", data=pred, compression="gzip")


 if __name__ == "__main__":
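
Note on the output layout (not part of the commit itself): the script now writes segmentations and network predictions to parallel "segmentation" and "prediction" folder trees under OUTPUT. A minimal sketch for reading the results back; the dataset and file names below are hypothetical placeholders, only the folder layout and HDF5 keys come from the script above.

import os
import h5py

OUTPUT = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_predictions"
ds_name, fname = "some_dataset", "some_tomogram.h5"  # hypothetical example names

seg_path = os.path.join(OUTPUT, "segmentation", ds_name, fname)
pred_path = os.path.join(OUTPUT, "prediction", ds_name, fname)

# Raw data and compartment segmentation are stored together ...
with h5py.File(seg_path, "r") as f:
    print("raw:", f["raw"].shape, "segmentation:", f["labels/compartments"].shape)

# ... while the network prediction lives in the parallel prediction tree.
with h5py.File(pred_path, "r") as f:
    print("prediction:", f["prediction"].shape)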

scripts/cooper/training/train_compartments.py

Lines changed: 28 additions & 15 deletions

@@ -2,13 +2,14 @@
 from glob import glob

 import numpy as np
-from sklearn.model_selection import train_test_split
+import torch_em

+from sklearn.model_selection import train_test_split
 from skimage import img_as_ubyte
 from skimage.segmentation import find_boundaries
 from skimage.filters import gaussian, rank
 from skimage.morphology import disk
-from scipy.ndimage import binary_dilation
+from scipy.ndimage import binary_dilation, distance_transform_edt

 from synaptic_reconstruction.training import supervised_training

@@ -23,19 +24,26 @@ def get_paths_2d():

 def get_paths_3d():
     paths = sorted(glob(os.path.join(TRAIN_ROOT, "v2", "**", "*.h5"), recursive=True))
+    paths += sorted(glob(os.path.join(TRAIN_ROOT, "v3", "**", "*.h5"), recursive=True))
     return paths


 def label_transform_2d(seg):
-    boundaries = find_boundaries(seg).astype("float32")
-    boundaries = gaussian(boundaries, sigma=1.0)
+    boundaries = find_boundaries(seg)
+    distances = distance_transform_edt(~seg).astype("float32")
+    distances /= distances.max()
+
+    boundaries = gaussian(boundaries.astype("float32"), sigma=1.0)
     boundaries = rank.autolevel(img_as_ubyte(boundaries), disk(8)).astype("float") / 255
-    mask = binary_dilation(seg != 0, iterations=8)
-    return np.stack([boundaries, mask])
+
+    distance_mask = seg != 0
+    boundary_mask = binary_dilation(distance_mask, iterations=8)
+
+    return np.stack([boundaries, distances, boundary_mask, distance_mask])


 def label_transform_3d(seg):
-    output = np.zeros((2,) + seg.shape, dtype="float32")
+    output = np.zeros((4,) + seg.shape, dtype="float32")
     for z in range(seg.shape[0]):
         out = label_transform_2d(seg[z])
         output[:, z] = out
@@ -70,18 +78,21 @@ def train_compartments_2d_v1():
     )


-def train_compartments_3d_v1():
+def train_compartments_3d_v2():
     paths = get_paths_3d()
-    train_paths, val_paths = train_test_split(paths, test_size=0.15, random_state=42)
+    train_paths, val_paths = train_test_split(paths, test_size=0.10, random_state=42)
+    print("Number of train paths:", len(train_paths))
+    print("Number of val paths:", len(val_paths))

     patch_shape = (64, 384, 384)
     batch_size = 1

     check = False
+    sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=2)

     save_root = "."
     supervised_training(
-        name="compartment_model_3d/v1",
+        name="compartment_model_3d/v2",
         train_paths=train_paths,
         val_paths=val_paths,
         label_key="/labels/compartments",
@@ -90,16 +101,18 @@ def train_compartments_3d_v1():
         save_root=save_root,
         label_transform=label_transform_3d,
         mask_channel=True,
-        n_samples_train=100,
-        n_samples_val=10,
-        n_iterations=int(2e4),
-        out_channels=1,
+        n_samples_train=250,
+        n_samples_val=25,
+        n_iterations=int(5e4),
+        out_channels=2,
+        sampler=sampler,
+        num_workers=8,
     )


 def main():
     # train_compartments_2d_v1()
-    train_compartments_3d_v1()
+    train_compartments_3d_v2()


 if __name__ == "__main__":
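
For reference, a standalone sketch (not taken from the repository) of the four-channel target that the updated label_transform_2d produces per slice. It mirrors the steps above but omits the rank.autolevel smoothing and assumes that label 0 is background.

import numpy as np
from scipy.ndimage import binary_dilation, distance_transform_edt
from skimage.filters import gaussian
from skimage.segmentation import find_boundaries

# Toy 2d segmentation with two compartment instances.
seg = np.zeros((64, 64), dtype="uint32")
seg[8:24, 8:24] = 1
seg[40:56, 40:56] = 2

# Channel 1: smoothed boundary target.
boundaries = gaussian(find_boundaries(seg).astype("float32"), sigma=1.0)

# Channel 2: distance to the nearest compartment, normalized to [0, 1].
distances = distance_transform_edt(seg == 0).astype("float32")
distances /= distances.max()

# Channels 3 and 4: dilated boundary mask and foreground (distance) mask.
distance_mask = seg != 0
boundary_mask = binary_dilation(distance_mask, iterations=8)

target = np.stack([boundaries, distances, boundary_mask, distance_mask])
print(target.shape)  # (4, 64, 64)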

synaptic_reconstruction/inference/compartments.py

Lines changed: 2 additions & 2 deletions

@@ -143,13 +143,13 @@ def segment_compartments(
     scaler = _Scaler(scale, verbose)
     input_volume = scaler.scale_input(input_volume)

-    # Run prediction.
+    # Run prediction. Support models with a single or multiple channels,
+    # assuming that the first channel is the boundary prediction.
     pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

     # Remove channel axis if necessary.
     if pred.ndim != input_volume.ndim:
         assert pred.ndim == input_volume.ndim + 1
-        assert pred.shape[0] == 1
         pred = pred[0]

     # Run the compartment segmentation.
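
With the shape assertion removed, multi-channel predictions are accepted here; only the first channel is passed on to the compartment segmentation. A minimal sketch of that channel handling (not the library code itself):

import numpy as np


def select_boundary_channel(pred, input_ndim):
    # Single-channel models already return a volume matching the input dimensionality.
    if pred.ndim == input_ndim:
        return pred
    # Multi-channel models return (channels, z, y, x); the first channel is assumed
    # to hold the boundary prediction, any additional channels are ignored here.
    assert pred.ndim == input_ndim + 1
    return pred[0]


pred = np.random.rand(2, 8, 64, 64).astype("float32")  # e.g. boundary + distance channel
boundaries = select_boundary_channel(pred, input_ndim=3)
print(boundaries.shape)  # (8, 64, 64)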

synaptic_reconstruction/training/supervised_training.py

Lines changed: 2 additions & 1 deletion

@@ -266,7 +266,8 @@ def supervised_training(
     elif mask_channel:
         loss = torch_em.loss.LossWrapper(
             loss=torch_em.loss.DiceLoss(),
-            transform=torch_em.loss.wrapper.ApplyAndRemoveMask()
+            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
+                masking_method="crop" if out_channels == 1 else "multiply")
         )
         metric = loss
     else:
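
Pulled out of supervised_training for readability, a sketch of the loss this now constructs when mask_channel is set (the wrapper classes are from torch_em, as in the diff; the out_channels value is just an example matching train_compartments_3d_v2 above):

import torch_em

out_channels = 2  # example value; "crop" is used for single-channel targets

loss = torch_em.loss.LossWrapper(
    loss=torch_em.loss.DiceLoss(),
    transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
        masking_method="crop" if out_channels == 1 else "multiply"
    ),
)
metric = loss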
