Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 90 additions & 52 deletions biahub/segment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from pathlib import Path

import click
Expand Down Expand Up @@ -28,10 +30,8 @@ def segment_data(
segmentation_models: dict,
gpu: bool = True,
) -> np.ndarray:
from cellpose import models

"""
Segment a CZYX image using a Cellpose segmentation model
Segment a CZYX image using Cellpose segmentation models.

Parameters
----------
Expand All @@ -47,32 +47,50 @@ def segment_data(
np.ndarray
A CZYX segmentation image
"""
from cellpose import models

# Segmenetation in cpu or gpu
if gpu:
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
click.echo(f"Using GPU: {device}")
except torch.cuda.CudaError:
click.echo("No GPU available. Using CPU")
device = torch.device("cpu")
else:
device = torch.device("cpu")

click.echo(f"Using device: {device}")

# Pre-load unique models to avoid redundant loading
unique_models = {}
for model_name, model_args in segmentation_models.items():
model_path = model_args.path_to_model
if model_path not in unique_models:
click.echo(f"Loading model: {model_path}")
unique_models[model_path] = models.CellposeModel(
gpu=True if device.type == 'cuda' else False,
device=device,
pretrained_model=model_path,
)

czyx_segmentation = []
# Process each model in a loop
for i, (model_name, model_args) in enumerate(segmentation_models.items()):
click.echo(f"Segmenting with model {model_name}")

# Process each model
for model_name, model_args in segmentation_models.items():
click.echo(f"Starting segmentation with model {model_name}")

# Extract the data we need for this model 2D or 3D
z_slice_2D = model_args.z_slice_2D
czyx_data_to_segment = (
czyx_data[:, z_slice_2D : z_slice_2D + 1] if z_slice_2D is not None else czyx_data
)
# Apply preprocessing functions
preprocessing_functions = model_args.preprocessing
for preproc in preprocessing_functions:
if z_slice_2D is not None:
czyx_data_to_segment = czyx_data[:, z_slice_2D].copy()
z_axis = None
else:
czyx_data_to_segment = czyx_data.copy()
z_axis = 1
click.echo(f"Segmenting {model_name} with z_axis {z_axis}")
# Apply preprocessing specific to this model
for preproc in model_args.preprocessing:
func = preproc.function
kwargs = preproc.kwargs
kwargs = preproc.kwargs.copy()
c_idx = preproc.channel

# Convert list to tuple for out_range if needed
Expand All @@ -82,21 +100,37 @@ def segment_data(
click.echo(
f"Processing with {func.__name__} with kwargs {kwargs} to channel {c_idx}"
)
czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs)
czyx_data_to_segment[c_idx] = func(czyx_data_to_segment[c_idx], **kwargs)

# Apply the segmentation
model = models.CellposeModel(
model_type=model_args.path_to_model, gpu=gpu, device=device
# Prepare cellpose input
cellpose_czyx = np.zeros(
(3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the output datatype? If it's not stored in the same as input, it should always be integers whose width is determined automatically by cellpose. For example storing uint32 in a float32 container is lossy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The real solution here would be to save integer labels in a separate labels array in the zarr store. For now we save everything in one float32 array, which is not ideal. I'd suggest casting cellpose_czyx to float32 and leaving a note that we should fix that later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant if the output store is not the same as input it should be integers. There is no reason for a store that only have integer arrays to be initialized with float32 data type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And since we do a lot of 2D segmentation, they wouldn't be stored with the 3D input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can default it to uint32.

)
for i, channel in enumerate(model_args.channels):
if channel is not None:
cellpose_czyx[i] = czyx_data_to_segment[channel]

# Get pre-loaded model and run segmentation
click.echo(f"Running segmentation for {model_args.path_to_model}")
model = unique_models[model_args.path_to_model]
segmentation, _, _ = model.eval(
czyx_data_to_segment, channel_axis=0, z_axis=1, **model_args.eval_args
) # noqa: python-no-eval
cellpose_czyx, channel_axis=0, z_axis=z_axis, **model_args.eval_args
)

# Handle 2D output formatting
if z_slice_2D is not None and isinstance(z_slice_2D, int):
segmentation = segmentation[np.newaxis, ...]

czyx_segmentation.append(segmentation)
czyx_segmentation = np.stack(czyx_segmentation, axis=0)

return czyx_segmentation
# Clean up intermediate arrays
del cellpose_czyx, czyx_data_to_segment

# Clean up GPU memory
if gpu and device.type == 'cuda':
torch.cuda.empty_cache()

return np.stack(czyx_segmentation, axis=0).astype(np.uint32)


@click.command("segment")
Expand All @@ -121,6 +155,7 @@ def segment_cli(
-i ./input.zarr/*/*/* \
-c ./segment_params.yml \
-o ./output.zarr

"""

# Convert string paths to Path objects
Expand All @@ -144,29 +179,30 @@ def segment_cli(

# Load the segmentation models with their respective configurations
# TODO: implement logic for 2D segmentation. Have a slicing parameter
segment_args = settings.models
C_segment = len(segment_args)
for model_name, model_args in segment_args.items():

C_segment = len(settings.models)
for model_name, model_args in settings.models.items():
if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int):
Z = 1
# Ensure channel names exist in the dataset
if not all(channel in channel_names for channel in model_args.eval_args["channels"]):
if not all(channel in channel_names for channel in model_args.channels):
raise ValueError(
f"Channels {model_args.eval_args['channels']} not found in dataset {channel_names}"
f"Channels {model_args.channels} not found in dataset {channel_names}"
)
# Channel strings to indices with the cellpose offset of 1
model_args.eval_args["channels"] = [
channel_names.index(channel) + 1 for channel in model_args.eval_args["channels"]
]
# NOTE:List of channels, either of length 2 or of length number of images by 2.
# First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
# Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
if len(model_args.eval_args["channels"]) < 2:
model_args.eval_args["channels"].append(0)

click.echo(
f"Segmenting with model {model_name} using channels {model_args.eval_args['channels']}"
)
# Channel strings to indices to be used in cellpose. Hiding this from the
model_args.channels = [channel_names.index(channel) for channel in model_args.channels]
# NOTE: Cellpose requires 3 channels. If the channels list is less than 3, the first channel is repeated.

if len(model_args.channels) < 3:
model_args.channels.extend(
[model_args.channels[0]] * (3 - len(model_args.channels))
)
else:
warnings.warn(
f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used."
)

click.echo(f"Segmenting {model_name} using channels {model_args.channels}")
if (
"anisotropy" not in model_args.eval_args
or model_args.eval_args["anisotropy"] is None
Expand All @@ -193,27 +229,29 @@ def segment_cli(
segmentation_shape = (T, C_segment, Z, Y, X)

# Create a zarr store output to mirror the input
# Note, dtype is set to uint32. Change this if one envisions having more than 2^32 labels.
create_empty_plate(
store_path=output_dirpath,
position_keys=[path.parts[-3:] for path in input_position_dirpaths],
channel_names=[model_name + "_labels" for model_name in segment_args.keys()],
channel_names=[model_name + "_labels" for model_name in settings.models.keys()],
shape=segmentation_shape,
chunks=None,
scale=scale,
dtype=np.uint32,
)

# Estimate resources
num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20)
num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=10)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a lot of funny math to estimate the resources. Does estimate_resources not do everything you need? For example, it now allows for specifying the max number of CPUs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest leaving reasonable defaults in the code and using an sbatch file to tune these parameters as needed for specific reconstructions or depending on the cluster usage

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same as before. We just dont use the CPUs that much.

num_gpus = 1
slurm_time = np.ceil(np.max([80, T * 2.5])).astype(int)
slurm_array_parallelism = 100
slurm_time = np.ceil(np.max([600, T * Z * 10])).astype(int)
slurm_array_parallelism = 9
# Prepare SLURM arguments
slurm_args = {
"slurm_job_name": "segment",
"slurm_gres": f"gpu:{num_gpus}",
"slurm_mem_per_cpu": f"{gb_ram_request}G",
"slurm_cpus_per_task": np.max([int(20 * 1.3), num_cpus]),
"slurm_array_parallelism": slurm_array_parallelism, # process up to 20 positions at a time
"slurm_cpus_per_task": np.max([int(slurm_array_parallelism * 2), num_cpus]),
"slurm_array_parallelism": slurm_array_parallelism,
"slurm_time": slurm_time,
"slurm_partition": "gpu",
}
Expand All @@ -239,13 +277,13 @@ def segment_cli(
jobs.append(
executor.submit(
process_single_position,
segment_data,
input_position_path,
output_position_path,
func=segment_data,
input_position_path=input_position_path,
output_position_path=output_position_path,
input_channel_indices=[list(range(C))],
output_channel_indices=[list(range(C_segment))],
num_processes=np.min([20, int(num_cpus * 0.8)]),
segmentation_models=segment_args,
num_processes=np.min([5, int(num_cpus * 0.8)]),
segmentation_models=settings.models,
)
)

Expand Down
2 changes: 2 additions & 0 deletions biahub/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ class PreprocessingFunctions(BaseModel):
class SegmentationModel(BaseModel):
path_to_model: str
eval_args: Dict[str, Any]
channels: list[str]
z_slice_2D: Optional[int] = None
preprocessing: list[PreprocessingFunctions] = []

Expand Down Expand Up @@ -648,3 +649,4 @@ def check_z_slice_with_do_3D(cls, z_slice_2D, info: ValidationInfo):
class SegmentationSettings(BaseModel):
models: Dict[str, SegmentationModel]
model_config = {"extra": "forbid", "protected_namespaces": ()}
gpu: bool = True
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ track = [
]

segment = [
"cellpose",
"cellpose>=4.0.5",
]

build = ["build", "twine"]
Expand Down
32 changes: 19 additions & 13 deletions settings/example_segmentation_settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,41 @@
models:
# One can instantiate as many models
membrane:
path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)"
# These are the common model.CellposeModel().eval() arguments used, but one can add more.
# For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0
path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model.
eval_args:
diameter: 65
channels: ['mem', 'nuc'] #The channel count for Cellpose starts at 1
cellprob_threshold: 0.4
invert: false
do_3D: false # Optional, if false, 2D segmentation is performed.
do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided.
anisotropy: 3.26
min_size: 8000
z_slice_2D: 10 # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. Cellpose suggests 100-200 if one sees imhomogeneity
z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
channels: ['mem', 'nuc']
preprocessing:
- function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity
kwargs: {"out_range": [0, 1]}
channel: 'mem'
- function: skimage.exposure.equalize_adapthist
kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]}
channel: 'mem'
# One can instantiate as many models
nucleus:
path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)"
# These are the common model.CellposeModel().eval() arguments used, but one can add more.
# For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0
path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model.
eval_args:
diameter: 60
channels: ['nuc'] #For nucleus segmentation, only one channel is required. We populate the other channel with zero.
cellprob_threshold: 0.0
diameter: 55
cellprob_threshold: 0.4
invert: false
do_3D: true
do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided.
anisotropy: 3.26
min_size: 8000
normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. Cellpose suggests 100-200 if one sees imhomogeneity
z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
channels: ['nuc','mem']
preprocessing:
- function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity
kwargs: {"out_range": [0, 1]}
channel: 'nuc'
- function: skimage.exposure.equalize_adapthist
kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]}
channel: 'nuc'