Skip to content

Implement scalable segmentation #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions synapse_net/inference/scalable_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import tempfile
from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
    """Wrapper that exposes a single channel of a channel-first array-like dataset.

    The wrapped volume is indexed along its first axis, so the reported
    `shape`, `chunks` and `ndim` all omit that leading axis.

    Args:
        volume: The array-like input dataset.
        channel: The index of the channel that will be selected.
    """
    def __init__(self, volume: np.typing.ArrayLike, channel: int):
        self.channel = channel

        def _pick_channel(data):
            # Read the channel from the instance at call time (not from the
            # constructor argument), so a later change of `self.channel` is honored.
            return data[self.channel]

        super().__init__(volume, _pick_channel, with_channels=True)

    @property
    def shape(self):
        """Shape of the wrapped volume without its leading axis."""
        full_shape = self._volume.shape
        return full_shape[1:]

    @property
    def chunks(self):
        """Chunk shape of the wrapped volume without its leading axis."""
        full_chunks = self._volume.chunks
        return full_chunks[1:]

    @property
    def ndim(self):
        """Number of dimensions of the wrapped volume, not counting the leading axis."""
        return self._volume.ndim - 1


def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    """Turn a two-channel prediction into an instance segmentation, block by block.

    Args:
        pred: The prediction; channel 0 is used as foreground, channel 1 as boundaries.
        output: The array-like object the segmentation is written to.
        seeds: The array-like object the connected-component seeds are written to.
        chunks: The block shape used for all block-wise operations.
        seed_threshold: The threshold applied to (foreground - boundaries) to obtain the seeds.
        min_size: Minimal segment size; the size filter is skipped if this is not positive.
        verbose: Passed on to the block-wise operations.
        original_shape: If not None, boundaries, seeds and mask are resized to this
            shape before the watershed is run.
    """
    # Views on the foreground (channel 0) and the boundary (channel 1) prediction.
    fg_channel = SelectChannel(pred, 0)
    bd_channel = SelectChannel(pred, 1)

    # Seeds: connected components of the thresholded difference (foreground - boundaries).
    fg_minus_bd = MultiTransformationWrapper(np.subtract, fg_channel, bd_channel)
    thresholded_diff = ThresholdWrapper(fg_minus_bd, seed_threshold)
    parallel.label(thresholded_diff, seeds, verbose=verbose, block_shape=chunks)

    # The watershed is restricted to the thresholded foreground.
    fg_mask = ThresholdWrapper(fg_channel, 0.5)

    # Bring all watershed inputs to the original shape if the input was rescaled.
    # NOTE(review): 'order' is presumably the interpolation order (1 for the
    # continuous boundary map, 0 for labels / binary data) - confirm against elf.
    needs_resize = original_shape is not None
    if needs_resize:
        bd_channel = ResizedVolume(bd_channel, original_shape, order=1)
        seeds = ResizedVolume(seeds, original_shape, order=0)
        fg_mask = ResizedVolume(fg_mask, original_shape, order=0)

    # Grow the seeds back to the boundaries.
    watershed_halo = (16, 16, 16)
    parallel.seeded_watershed(
        bd_channel,
        seeds=seeds,
        out=output,
        verbose=verbose,
        mask=fg_mask,
        block_shape=chunks,
        halo=watershed_halo,
    )

    # Filter by size; this reads from and writes to the output.
    if min_size > 0:
        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)


def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
) -> None:
    """Run segmentation based on a prediction with foreground and boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function runs for large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction. It must have exactly two output channels.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
            Must contain one factor per input dimension. If the input is rescaled,
            the watershed inputs are resized back to the original input shape.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a vesicle to be considered.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar,
            with shape (2,) + the shape of the (rescaled) input.
            If not given will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: Not supported yet; must be None.

    Raises:
        NotImplementedError: If a mask is passed.
        ValueError: If the model does not have two output channels, if the length of 'scale'
            does not match the input dimensionality, or if 'prediction' has the wrong shape.
    """
    # Validate with explicit exceptions rather than 'assert', so that the checks
    # are not stripped when Python runs with optimizations (-O).
    if mask is not None:
        raise NotImplementedError("Passing a mask is not yet supported by 'scalable_segmentation'.")
    if model.out_channels != 2:
        raise ValueError(
            f"Expected a model with 2 output channels (foreground, boundaries), got {model.out_channels}."
        )

    # Rescale the input if a scale factor that differs from 1 is given.
    if scale is None or np.allclose(scale, 1.0, atol=1e-3):
        original_shape = None
    else:
        # zip would silently drop dimensions if the lengths differ.
        if len(scale) != len(input_.shape):
            raise ValueError(
                f"Expected one scale factor per input dimension, got {len(scale)} "
                f"factors for an input with {len(input_.shape)} dimensions."
            )
        original_shape = input_.shape
        new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
        input_ = ResizedVolume(input_, shape=new_shape, order=1)

    expected_pred_shape = (2,) + tuple(input_.shape)
    if prediction is not None and tuple(prediction.shape) != expected_pred_shape:
        raise ValueError(
            f"Expected a prediction of shape {expected_pred_shape}, got {tuple(prediction.shape)}."
        )

    chunks = (128,) * 3
    # Create a temporary directory for storing the seeds and, if needed, the prediction.
    with tempfile.TemporaryDirectory() as tmp_dir:

        if prediction is None:
            # Create the dataset for storing the prediction.
            # The container gets its own name, so that it is not rebound
            # (and possibly released) while the dataset is still in use.
            pred_file = open_file(os.path.join(tmp_dir, "prediction.n5"), mode="a")
            prediction = pred_file.create_dataset(
                "pred", shape=expected_pred_shape, dtype="float32", chunks=(1,) + chunks
            )

        # Create temporary storage for the seeds.
        seed_file = open_file(os.path.join(tmp_dir, "seeds.n5"), mode="a")
        seeds = seed_file.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
71 changes: 48 additions & 23 deletions synapse_net/inference/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# import xarray

from elf.io import open_file
from numpy.typing import ArrayLike
from scipy.ndimage import binary_closing
from skimage.measure import regionprops
from skimage.morphology import remove_small_holes
Expand Down Expand Up @@ -99,16 +100,32 @@ def rescale_output(self, output, is_segmentation):
return output


def _preprocess(input_volume, with_channels, channels_to_standardize):
# We standardize the data for the whole volume beforehand.
# If we have channels then the standardization is done independently per channel.
if with_channels:
input_volume = input_volume.astype(np.float32, copy=False)
# TODO Check that this is the correct axis.
if channels_to_standardize is None: # assume all channels
channels_to_standardize = range(input_volume.shape[0])
for ch in channels_to_standardize:
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
else:
input_volume = torch_em.transform.raw.standardize(input_volume)
return input_volume


def get_prediction(
input_volume: np.ndarray, # [z, y, x]
input_volume: ArrayLike, # [z, y, x]
tiling: Optional[Dict[str, Dict[str, int]]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
model_path: Optional[str] = None,
model: Optional[torch.nn.Module] = None,
verbose: bool = True,
with_channels: bool = False,
channels_to_standardize: Optional[List[int]] = None,
mask: Optional[np.ndarray] = None,
) -> np.ndarray:
mask: Optional[ArrayLike] = None,
prediction: Optional[ArrayLike] = None,
) -> ArrayLike:
"""Run prediction on a given volume.

This function will automatically choose the correct prediction implementation,
Expand All @@ -124,6 +141,8 @@ def get_prediction(
channels_to_standardize: List of channels to standardize. Defaults to None.
mask: Optional binary mask. If given, the prediction will only be run in
the foreground region of the mask.
prediction: An array like object for writing the prediction.
If not given, the prediction will be computed in memory.

Returns:
The predicted volume.
Expand All @@ -140,17 +159,11 @@ def get_prediction(
if tiling is None:
tiling = get_default_tiling()

# We standardize the data for the whole volume beforehand.
# If we have channels then the standardization is done independently per channel.
if with_channels:
input_volume = input_volume.astype(np.float32, copy=False)
# TODO Check that this is the correct axis.
if channels_to_standardize is None: # assume all channels
channels_to_standardize = range(input_volume.shape[0])
for ch in channels_to_standardize:
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
else:
input_volume = torch_em.transform.raw.standardize(input_volume)
# Normalize the whole input volume if it is a numpy array.
# Otherwise we have a zarr array or similar as input, and can't normalize it en bloc.
# Normalization will be applied later per block in this case.
if isinstance(input_volume, np.ndarray):
input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

# Run prediction with the bioimage.io library.
if is_bioimageio:
Expand All @@ -174,21 +187,23 @@ def get_prediction(
for dim in tiling["tile"]:
updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
# print(f"updated_tiling {updated_tiling}")
pred = get_prediction_torch_em(
input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
prediction = get_prediction_torch_em(
input_volume, updated_tiling, model_path, model, verbose, with_channels,
mask=mask, prediction=prediction,
)

return pred
return prediction


def get_prediction_torch_em(
input_volume: np.ndarray, # [z, y, x]
input_volume: ArrayLike, # [z, y, x]
tiling: Dict[str, Dict[str, int]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
model_path: Optional[str] = None,
model: Optional[torch.nn.Module] = None,
verbose: bool = True,
with_channels: bool = False,
mask: Optional[np.ndarray] = None,
mask: Optional[ArrayLike] = None,
prediction: Optional[ArrayLike] = None,
) -> np.ndarray:
"""Run prediction using torch-em on a given volume.

Expand All @@ -201,6 +216,8 @@ def get_prediction_torch_em(
with_channels: Whether to predict with channels.
mask: Optional binary mask. If given, the prediction will only be run in
the foreground region of the mask.
prediction: An array like object for writing the prediction.
If not given, the prediction will be computed in memory.

Returns:
The predicted volume.
Expand Down Expand Up @@ -234,14 +251,16 @@ def get_prediction_torch_em(
print("Run prediction with mask.")
mask = mask.astype("bool")

pred = predict_with_halo(
preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
prediction = predict_with_halo(
input_volume, model, gpu_ids=[device],
block_shape=block_shape, halo=halo,
preprocess=None, with_channels=with_channels, mask=mask,
preprocess=preprocess, with_channels=with_channels, mask=mask,
output=prediction,
)
if verbose:
print("Prediction time in", time.time() - t0, "s")
return pred
return prediction


def _get_file_paths(input_path, ext=".mrc"):
Expand Down Expand Up @@ -325,6 +344,7 @@ def inference_helper(
output_key: Optional[str] = None,
model_resolution: Optional[Tuple[float, float, float]] = None,
scale: Optional[Tuple[float, float, float]] = None,
allocate_output: bool = False,
) -> None:
"""Helper function to run segmentation for mrc files.

Expand All @@ -347,6 +367,7 @@ def inference_helper(
model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
allocate_output: Whether to allocate the output for the segmentation function.
"""
if (scale is not None) and (model_resolution is not None):
raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
Expand Down Expand Up @@ -412,7 +433,11 @@ def inference_helper(
this_scale = _derive_scale(img_path, model_resolution)

# Run the segmentation.
segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
if allocate_output:
segmentation = np.zeros(input_volume.shape, dtype="uint32")
segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
else:
segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

# Write the result to tif or h5.
os.makedirs(os.path.split(output_path)[0], exist_ok=True)
Expand Down
26 changes: 23 additions & 3 deletions synapse_net/tools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch_em
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
from ..inference.scalable_segmentation import scalable_segmentation
from ..inference.util import inference_helper, parse_tiling


Expand Down Expand Up @@ -152,6 +153,10 @@ def segmentation_cli():
"--verbose", "-v", action="store_true",
help="Whether to print verbose information about the segmentation progress."
)
parser.add_argument(
"--scalable", action="store_true", help="Use the scalable segmentation implementation. "
"Currently this only works for vesicles, mitochondria, or active zones."
)
args = parser.parse_args()

if args.checkpoint is None:
Expand Down Expand Up @@ -181,11 +186,26 @@ def segmentation_cli():
model_resolution = None
scale = (2 if is_2d else 3) * (args.scale,)

segmentation_function = partial(
run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
)
if args.scalable:
if not args.model.startswith(("vesicle", "mito", "active")):
raise ValueError(
"The scalable segmentation implementation is currently only supported for "
f"vesicles, mitochondria, or active zones, not for {args.model}."
)
segmentation_function = partial(
scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
)
allocate_output = True

else:
segmentation_function = partial(
run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
)
allocate_output = False

inference_helper(
args.input_path, args.output_path, segmentation_function,
mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
allocate_output=allocate_output
)
Loading