Skip to content

Commit 6011472

Browse files
Update compartment segmentation code
1 parent c5fceb7 commit 6011472

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

scripts/cooper/ground_truth/compartments/run_prediction_04.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import h5py
55
from tqdm import tqdm
66

7-
from synaptic_reconstruction.inference.utils import _Scaler
7+
from synaptic_reconstruction.inference.util import _Scaler
88
from synaptic_reconstruction.inference.compartments import segment_compartments
99

10-
INPUT_ROOT = ""
11-
MODEL_PATH = ""
10+
INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval" # noqa
11+
MODEL_PATH = "/mnt/lustre-emmy-hdd/projects/nim00007/compartment_models/compartment_model_3d.pt"
1212
OUTPUT = "./predictions"
1313

1414

@@ -17,10 +17,12 @@ def segment_volume(input_path, model_path):
1717
raw = f["raw"][:]
1818

1919
scale = (0.25, 0.25, 0.25)
20-
scaler = _Scaler(scale)
20+
scaler = _Scaler(scale, verbose=False)
2121
raw = scaler.scale_input(raw)
2222

23-
seg = segment_compartments(raw, model_path, verbose=False)
23+
n_slices_exclude = 4
24+
seg = segment_compartments(raw, model_path, verbose=False, n_slices_exclude=n_slices_exclude)
25+
raw, seg = raw[n_slices_exclude:-n_slices_exclude], seg[n_slices_exclude:-n_slices_exclude]
2426

2527
return raw, seg
2628

synaptic_reconstruction/inference/compartments.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def segment_compartments(
117117
return_predictions: bool = False,
118118
scale: Optional[List[float]] = None,
119119
mask: Optional[np.ndarray] = None,
120+
n_slices_exclude: int = 5,
120121
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
121122
"""
122123
Segment synaptic compartments in an input volume.
@@ -129,6 +130,7 @@ def segment_compartments(
129130
verbose: Whether to print timing information.
130131
return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
131132
scale: The scale factor to use for rescaling the input volume before prediction.
133+
n_slices_exclude:
132134
133135
Returns:
134136
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
@@ -156,7 +158,7 @@ def segment_compartments(
156158
if input_volume.ndim == 2:
157159
seg = _segment_compartments_2d(pred)
158160
else:
159-
seg = _segment_compartments_3d(pred)
161+
seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude)
160162
if verbose:
161163
print("Run segmentation in", time.time() - t0, "s")
162164
seg = scaler.rescale_output(seg, is_segmentation=True)

synaptic_reconstruction/inference/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def get_default_tiling():
388388
tile = {"x": 352, "y": 352, "z": 48}
389389
else:
390390
# TODO determine tilings for smaller VRAM
391-
raise NotImplementedError
391+
raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.")
392392

393393
print(f"Determined tile size: {tile}")
394394
tiling = {"tile": tile, "halo": halo}

0 commit comments

Comments
 (0)