diff --git a/biahub/segment.py b/biahub/segment.py index 5ceb4c73..1d7e6d8c 100644 --- a/biahub/segment.py +++ b/biahub/segment.py @@ -1,3 +1,5 @@ +import warnings + from pathlib import Path import click @@ -28,10 +30,8 @@ def segment_data( segmentation_models: dict, gpu: bool = True, ) -> np.ndarray: - from cellpose import models - """ - Segment a CZYX image using a Cellpose segmentation model + Segment a CZYX image using Cellpose segmentation models. Parameters ---------- @@ -47,32 +47,50 @@ def segment_data( np.ndarray A CZYX segmentation image """ + from cellpose import models - # Segmenetation in cpu or gpu if gpu: try: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + click.echo(f"Using GPU: {device}") except torch.cuda.CudaError: click.echo("No GPU available. Using CPU") device = torch.device("cpu") else: device = torch.device("cpu") - click.echo(f"Using device: {device}") + # Pre-load unique models to avoid redundant loading + unique_models = {} + for model_name, model_args in segmentation_models.items(): + model_path = model_args.path_to_model + if model_path not in unique_models: + click.echo(f"Loading model: {model_path}") + unique_models[model_path] = models.CellposeModel( + gpu=True if device.type == 'cuda' else False, + device=device, + pretrained_model=model_path, + ) + czyx_segmentation = [] - # Process each model in a loop - for i, (model_name, model_args) in enumerate(segmentation_models.items()): - click.echo(f"Segmenting with model {model_name}") + + # Process each model + for model_name, model_args in segmentation_models.items(): + click.echo(f"Starting segmentation with model {model_name}") + + # Extract the data we need for this model 2D or 3D z_slice_2D = model_args.z_slice_2D - czyx_data_to_segment = ( - czyx_data[:, z_slice_2D : z_slice_2D + 1] if z_slice_2D is not None else czyx_data - ) - # Apply preprocessing functions - preprocessing_functions = model_args.preprocessing - for preproc in preprocessing_functions: + 
if z_slice_2D is not None: + czyx_data_to_segment = czyx_data[:, z_slice_2D].copy() + z_axis = None + else: + czyx_data_to_segment = czyx_data.copy() + z_axis = 1 + click.echo(f"Segmenting {model_name} with z_axis {z_axis}") + # Apply preprocessing specific to this model + for preproc in model_args.preprocessing: func = preproc.function - kwargs = preproc.kwargs + kwargs = preproc.kwargs.copy() c_idx = preproc.channel # Convert list to tuple for out_range if needed @@ -82,21 +100,37 @@ def segment_data( click.echo( f"Processing with {func.__name__} with kwargs {kwargs} to channel {c_idx}" ) - czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs) + czyx_data_to_segment[c_idx] = func(czyx_data_to_segment[c_idx], **kwargs) - # Apply the segmentation - model = models.CellposeModel( - model_type=model_args.path_to_model, gpu=gpu, device=device + # Prepare cellpose input + cellpose_czyx = np.zeros( + (3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype ) + for i, channel in enumerate(model_args.channels): + if channel is not None: + cellpose_czyx[i] = czyx_data_to_segment[channel] + + # Get pre-loaded model and run segmentation + click.echo(f"Running segmentation for {model_args.path_to_model}") + model = unique_models[model_args.path_to_model] segmentation, _, _ = model.eval( - czyx_data_to_segment, channel_axis=0, z_axis=1, **model_args.eval_args - ) # noqa: python-no-eval + cellpose_czyx, channel_axis=0, z_axis=z_axis, **model_args.eval_args + ) + + # Handle 2D output formatting if z_slice_2D is not None and isinstance(z_slice_2D, int): segmentation = segmentation[np.newaxis, ...] 
+ czyx_segmentation.append(segmentation) - czyx_segmentation = np.stack(czyx_segmentation, axis=0) - return czyx_segmentation + # Clean up intermediate arrays + del cellpose_czyx, czyx_data_to_segment + + # Clean up GPU memory + if gpu and device.type == 'cuda': + torch.cuda.empty_cache() + + return np.stack(czyx_segmentation, axis=0).astype(np.uint32) @click.command("segment") @@ -121,6 +155,7 @@ def segment_cli( -i ./input.zarr/*/*/* \ -c ./segment_params.yml \ -o ./output.zarr + """ # Convert string paths to Path objects @@ -144,29 +179,30 @@ def segment_cli( # Load the segmentation models with their respective configurations # TODO: implement logic for 2D segmentation. Have a slicing parameter - segment_args = settings.models - C_segment = len(segment_args) - for model_name, model_args in segment_args.items(): + + C_segment = len(settings.models) + for model_name, model_args in settings.models.items(): if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int): Z = 1 # Ensure channel names exist in the dataset - if not all(channel in channel_names for channel in model_args.eval_args["channels"]): + if not all(channel in channel_names for channel in model_args.channels): raise ValueError( - f"Channels {model_args.eval_args['channels']} not found in dataset {channel_names}" + f"Channels {model_args.channels} not found in dataset {channel_names}" ) - # Channel strings to indices with the cellpose offset of 1 - model_args.eval_args["channels"] = [ - channel_names.index(channel) + 1 for channel in model_args.eval_args["channels"] - ] - # NOTE:List of channels, either of length 2 or of length number of images by 2. - # First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). - # Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). 
- if len(model_args.eval_args["channels"]) < 2: - model_args.eval_args["channels"].append(0) - - click.echo( - f"Segmenting with model {model_name} using channels {model_args.eval_args['channels']}" - ) + # Convert channel names to dataset indices for Cellpose; this conversion is hidden from the user. + model_args.channels = [channel_names.index(channel) for channel in model_args.channels] + # NOTE: Cellpose requires 3 channels. If fewer than 3 channels are listed, the first channel is repeated to fill the remainder. + + if len(model_args.channels) < 3: + model_args.channels.extend( + [model_args.channels[0]] * (3 - len(model_args.channels)) + ) + else: + warnings.warn( + f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used." + ) + + click.echo(f"Segmenting {model_name} using channels {model_args.channels}") if ( "anisotropy" not in model_args.eval_args or model_args.eval_args["anisotropy"] is None @@ -193,27 +229,29 @@ def segment_cli( segmentation_shape = (T, C_segment, Z, Y, X) # Create a zarr store output to mirror the input + # Note, dtype is set to uint32. Change this if one envisions having more than 2^32 labels. 
create_empty_plate( store_path=output_dirpath, position_keys=[path.parts[-3:] for path in input_position_dirpaths], - channel_names=[model_name + "_labels" for model_name in segment_args.keys()], + channel_names=[model_name + "_labels" for model_name in settings.models.keys()], shape=segmentation_shape, chunks=None, scale=scale, + dtype=np.uint32, ) # Estimate resources - num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20) + num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=10) num_gpus = 1 - slurm_time = np.ceil(np.max([80, T * 2.5])).astype(int) - slurm_array_parallelism = 100 + slurm_time = np.ceil(np.max([600, T * Z * 10])).astype(int) + slurm_array_parallelism = 9 # Prepare SLURM arguments slurm_args = { "slurm_job_name": "segment", "slurm_gres": f"gpu:{num_gpus}", "slurm_mem_per_cpu": f"{gb_ram_request}G", - "slurm_cpus_per_task": np.max([int(20 * 1.3), num_cpus]), - "slurm_array_parallelism": slurm_array_parallelism, # process up to 20 positions at a time + "slurm_cpus_per_task": np.max([int(slurm_array_parallelism * 2), num_cpus]), + "slurm_array_parallelism": slurm_array_parallelism, "slurm_time": slurm_time, "slurm_partition": "gpu", } @@ -239,13 +277,13 @@ def segment_cli( jobs.append( executor.submit( process_single_position, - segment_data, - input_position_path, - output_position_path, + func=segment_data, + input_position_path=input_position_path, + output_position_path=output_position_path, input_channel_indices=[list(range(C))], output_channel_indices=[list(range(C_segment))], - num_processes=np.min([20, int(num_cpus * 0.8)]), - segmentation_models=segment_args, + num_processes=np.min([5, int(num_cpus * 0.8)]), + segmentation_models=settings.models, ) ) diff --git a/biahub/settings.py b/biahub/settings.py index eb505a68..1ae4e287 100644 --- a/biahub/settings.py +++ b/biahub/settings.py @@ -613,6 +613,7 @@ class PreprocessingFunctions(BaseModel): class 
SegmentationModel(BaseModel): path_to_model: str eval_args: Dict[str, Any] + channels: list[str] z_slice_2D: Optional[int] = None preprocessing: list[PreprocessingFunctions] = [] @@ -648,3 +649,4 @@ def check_z_slice_with_do_3D(cls, z_slice_2D, info: ValidationInfo): class SegmentationSettings(BaseModel): models: Dict[str, SegmentationModel] model_config = {"extra": "forbid", "protected_namespaces": ()} + gpu: bool = True diff --git a/pyproject.toml b/pyproject.toml index b3337a2a..1ce3bdfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ track = [ ] segment = [ - "cellpose", + "cellpose>=4.0.5", ] build = ["build", "twine"] diff --git a/settings/example_segmentation_settings.yml b/settings/example_segmentation_settings.yml index a9fa6bf0..37dddfb7 100644 --- a/settings/example_segmentation_settings.yml +++ b/settings/example_segmentation_settings.yml @@ -3,18 +3,17 @@ models: # One can instantiate as many models membrane: - path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)" - # These are the common model.CellposeModel().eval() arguments used, but one can add more. - # For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0 + path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model. eval_args: diameter: 65 - channels: ['mem', 'nuc'] #The channel count for Cellpose starts at 1 cellprob_threshold: 0.4 invert: false - do_3D: false # Optional, if false, 2D segmentation is performed. + do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided. anisotropy: 3.26 min_size: 8000 - z_slice_2D: 10 # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. 
Cellpose suggests 100-200 if one sees inhomogeneity + z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + channels: ['mem', 'nuc'] preprocessing: - function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity kwargs: {"out_range": [0, 1]} @@ -22,16 +21,23 @@ models: - function: skimage.exposure.equalize_adapthist kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]} channel: 'mem' + # One can instantiate as many models nucleus: - path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)" - # These are the common model.CellposeModel().eval() arguments used, but one can add more. - # For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0 + path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model. eval_args: - diameter: 60 - channels: ['nuc'] #For nucleus segmentation, only one channel is required. We populate the other channel with zero. - cellprob_threshold: 0.0 + diameter: 55 + cellprob_threshold: 0.4 invert: false - do_3D: true + do_3D: true # Optional, if false, 2D segmentation is performed. if true, z_slice and channel_axis must be provided. anisotropy: 3.26 min_size: 8000 + normalize: {"tile_norm_blocksize": 0} # Optional, if 0, the whole image is used. Cellpose suggests 100-200 if one sees inhomogeneity z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True + channels: ['nuc','mem'] + preprocessing: + - function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity + kwargs: {"out_range": [0, 1]} + channel: 'nuc' + - function: skimage.exposure.equalize_adapthist + kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]} + channel: 'nuc'