Skip to content

Commit 78825be

Browse files
committed
add option to change boundary threshold for compartment seg
1 parent db3e654 commit 78825be

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

scripts/cooper/run_compartment_segmentation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def run_compartment_segmentation(args):
1515
model_path = args.model
1616

1717
segmentation_function = partial(
18-
segment_compartments, model_path=model_path, verbose=False, tiling=tiling, scale=[0.25, 0.25, 0.25]
18+
segment_compartments, model_path=model_path, verbose=False, tiling=tiling, scale=[0.25, 0.25, 0.25], boundary_threshold=args.boundary_threshold
1919
)
2020
inference_helper(
2121
args.input_path, args.output_path, segmentation_function, force=args.force, data_ext=args.data_ext
@@ -50,6 +50,9 @@ def main():
5050
parser.add_argument(
5151
"--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
5252
)
53+
parser.add_argument(
54+
"--boundary_threshold", type=float, default=0.4, help="Threshold that determines when the prediction of the network is foreground for the segmentation. Need higher threshold than default for TEM."
55+
)
5356

5457
args = parser.parse_args()
5558
run_compartment_segmentation(args)

synapse_net/inference/compartments.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def segment_compartments(
155155
scale: Optional[List[float]] = None,
156156
mask: Optional[np.ndarray] = None,
157157
n_slices_exclude: int = 0,
158+
boundary_threshold: float=0.4,
158159
**kwargs,
159160
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
160161
"""Segment synaptic compartments in an input volume.
@@ -168,6 +169,7 @@ def segment_compartments(
168169
return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
169170
scale: The scale factor to use for rescaling the input volume before prediction.
170171
n_slices_exclude:
172+
boundary_threshold: Threshold that determines when the prediction of the network is foreground for the segmentation. Need higher threshold than default for TEM.
171173
172174
Returns:
173175
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
@@ -193,9 +195,9 @@ def segment_compartments(
193195
# We may want to expose some of the parameters here.
194196
t0 = time.time()
195197
if input_volume.ndim == 2:
196-
seg = _segment_compartments_2d(pred)
198+
seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
197199
else:
198-
seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude)
200+
seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
199201
if verbose:
200202
print("Run segmentation in", time.time() - t0, "s")
201203

0 commit comments

Comments
 (0)