Updating the segmentation to use Cellpose SAM #101
base: main
Changes from all commits: 4dbf169, a234a41, 6d9245c, 8101134, 95cb565, 7443759, 202f943, 0f3ed50
Diff 1: segmentation CLI module
@@ -1,3 +1,5 @@
+import warnings
+
 from pathlib import Path
 
 import click
@@ -28,10 +30,8 @@ def segment_data(
     segmentation_models: dict,
     gpu: bool = True,
 ) -> np.ndarray:
-    from cellpose import models
-
     """
-    Segment a CZYX image using a Cellpose segmentation model
+    Segment a CZYX image using Cellpose segmentation models.
 
     Parameters
     ----------
@@ -47,32 +47,50 @@ def segment_data(
     np.ndarray
         A CZYX segmentation image
     """
+    from cellpose import models
+
     # Segmentation on CPU or GPU
     if gpu:
         try:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             click.echo(f"Using GPU: {device}")
         except torch.cuda.CudaError:
             click.echo("No GPU available. Using CPU")
             device = torch.device("cpu")
     else:
         device = torch.device("cpu")
 
     click.echo(f"Using device: {device}")
 
+    # Pre-load unique models to avoid redundant loading
+    unique_models = {}
+    for model_name, model_args in segmentation_models.items():
+        model_path = model_args.path_to_model
+        if model_path not in unique_models:
+            click.echo(f"Loading model: {model_path}")
+            unique_models[model_path] = models.CellposeModel(
+                gpu=True if device.type == 'cuda' else False,
+                device=device,
+                pretrained_model=model_path,
+            )
+
     czyx_segmentation = []
-    # Process each model in a loop
-    for i, (model_name, model_args) in enumerate(segmentation_models.items()):
-        click.echo(f"Segmenting with model {model_name}")
+    # Process each model
+    for model_name, model_args in segmentation_models.items():
+        click.echo(f"Starting segmentation with model {model_name}")
 
-        # Extract the data we need for this model 2D or 3D
         z_slice_2D = model_args.z_slice_2D
-        czyx_data_to_segment = (
-            czyx_data[:, z_slice_2D : z_slice_2D + 1] if z_slice_2D is not None else czyx_data
-        )
-        # Apply preprocessing functions
-        preprocessing_functions = model_args.preprocessing
-        for preproc in preprocessing_functions:
+        if z_slice_2D is not None:
+            czyx_data_to_segment = czyx_data[:, z_slice_2D].copy()
+            z_axis = None
+        else:
+            czyx_data_to_segment = czyx_data.copy()
+            z_axis = 1
+        click.echo(f"Segmenting {model_name} with z_axis {z_axis}")
+        # Apply preprocessing specific to this model
+        for preproc in model_args.preprocessing:
             func = preproc.function
-            kwargs = preproc.kwargs
+            kwargs = preproc.kwargs.copy()
            c_idx = preproc.channel
 
             # Convert list to tuple for out_range if needed
@@ -82,21 +100,37 @@ def segment_data(
             click.echo(
                 f"Processing with {func.__name__} with kwargs {kwargs} to channel {c_idx}"
             )
-            czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs)
+            czyx_data_to_segment[c_idx] = func(czyx_data_to_segment[c_idx], **kwargs)
 
-        # Apply the segmentation
-        model = models.CellposeModel(
-            model_type=model_args.path_to_model, gpu=gpu, device=device
+        # Prepare cellpose input
+        cellpose_czyx = np.zeros(
+            (3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype
         )
+        for i, channel in enumerate(model_args.channels):
+            if channel is not None:
+                cellpose_czyx[i] = czyx_data_to_segment[channel]
+
+        # Get pre-loaded model and run segmentation
+        click.echo(f"Running segmentation for {model_args.path_to_model}")
+        model = unique_models[model_args.path_to_model]
         segmentation, _, _ = model.eval(
-            czyx_data_to_segment, channel_axis=0, z_axis=1, **model_args.eval_args
-        )  # noqa: python-no-eval
+            cellpose_czyx, channel_axis=0, z_axis=z_axis, **model_args.eval_args
+        )
+
+        # Handle 2D output formatting
+        if z_slice_2D is not None and isinstance(z_slice_2D, int):
+            segmentation = segmentation[np.newaxis, ...]
 
         czyx_segmentation.append(segmentation)
-    czyx_segmentation = np.stack(czyx_segmentation, axis=0)
 
-    return czyx_segmentation
+        # Clean up intermediate arrays
+        del cellpose_czyx, czyx_data_to_segment
+
+    # Clean up GPU memory
+    if gpu and device.type == 'cuda':
+        torch.cuda.empty_cache()
+
+    return np.stack(czyx_segmentation, axis=0).astype(np.uint32)
 
 
 @click.command("segment")
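
For reference, here is a minimal, standalone sketch of the Cellpose 4 (Cellpose-SAM) call pattern the new code relies on. Only `CellposeModel`, `eval`, `channel_axis`, and `z_axis` come from the diff above; the array sizes, the dummy data, and the `do_3D` flag are illustrative assumptions standing in for `eval_args`.

```python
# Minimal sketch of the Cellpose-SAM eval call used in segment_data.
# Assumes cellpose>=4.0.5 and torch; data shapes here are arbitrary.
import numpy as np
import torch
from cellpose import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.CellposeModel(gpu=(device.type == "cuda"), device=device)

# Cellpose-SAM expects a 3-channel input; unused channels can stay zero.
cellpose_czyx = np.zeros((3, 5, 256, 256), dtype=np.float32)
cellpose_czyx[0] = np.random.rand(5, 256, 256)  # channel to segment

# channel_axis marks the C axis and z_axis the Z axis of the CZYX stack;
# do_3D stands in for whatever the user would place in eval_args.
masks, flows, styles = model.eval(cellpose_czyx, channel_axis=0, z_axis=1, do_3D=True)
print(masks.shape, masks.dtype)  # ZYX label image with integer labels
```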
@@ -121,6 +155,7 @@ def segment_cli(
     -i ./input.zarr/*/*/* \
     -c ./segment_params.yml \
     -o ./output.zarr
+
     """
 
     # Convert string paths to Path objects
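
The per-model settings consumed in `segment_data` are only visible in this diff through attribute access (`path_to_model`, `z_slice_2D`, `preprocessing`, `channels`, `eval_args`). As a rough mental model only, and not the project's actual settings classes, an equivalent structure could look like this:

```python
# Hypothetical stand-in for the project's model settings; the field names mirror
# the attributes accessed in segment_data, everything else is illustrative.
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class PreprocessingStep:
    function: Callable   # e.g. an intensity-rescaling function
    kwargs: dict         # keyword arguments forwarded to `function`
    channel: int         # channel index the step is applied to


@dataclass
class ModelSettings:
    path_to_model: str                                  # Cellpose model name or path to weights
    channels: list                                      # channel names, later mapped to indices
    eval_args: dict = field(default_factory=dict)       # forwarded to CellposeModel.eval
    z_slice_2D: Optional[int] = None                    # if set, segment a single Z slice in 2D
    preprocessing: list = field(default_factory=list)   # list of PreprocessingStep


# segmentation_models would then map each model name to one such entry.
```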
@@ -144,29 +179,30 @@ def segment_cli(
 
     # Load the segmentation models with their respective configurations
-    # TODO: implement logic for 2D segmentation. Have a slicing parameter
-    segment_args = settings.models
-    C_segment = len(segment_args)
-    for model_name, model_args in segment_args.items():
+    C_segment = len(settings.models)
+    for model_name, model_args in settings.models.items():
+        if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int):
+            Z = 1
         # Ensure channel names exist in the dataset
-        if not all(channel in channel_names for channel in model_args.eval_args["channels"]):
+        if not all(channel in channel_names for channel in model_args.channels):
             raise ValueError(
-                f"Channels {model_args.eval_args['channels']} not found in dataset {channel_names}"
+                f"Channels {model_args.channels} not found in dataset {channel_names}"
             )
-        # Channel strings to indices with the cellpose offset of 1
-        model_args.eval_args["channels"] = [
-            channel_names.index(channel) + 1 for channel in model_args.eval_args["channels"]
-        ]
-        # NOTE: List of channels, either of length 2 or of length number of images by 2.
-        # First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
-        # Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
-        if len(model_args.eval_args["channels"]) < 2:
-            model_args.eval_args["channels"].append(0)
-
-        click.echo(
-            f"Segmenting with model {model_name} using channels {model_args.eval_args['channels']}"
-        )
+        # Channel strings to indices to be used in cellpose. Hiding this from the
+        model_args.channels = [channel_names.index(channel) for channel in model_args.channels]
+        # NOTE: Cellpose requires 3 channels. If the channels list is less than 3, the first channel is repeated.
+
+        if len(model_args.channels) < 3:
+            model_args.channels.extend(
+                [model_args.channels[0]] * (3 - len(model_args.channels))
+            )
+        else:
+            warnings.warn(
+                f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used."
+            )
+
+        click.echo(f"Segmenting {model_name} using channels {model_args.channels}")
         if (
             "anisotropy" not in model_args.eval_args
             or model_args.eval_args["anisotropy"] is None
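
To make the channel bookkeeping above concrete, here is the name-to-index mapping and the pad-to-three-channels step in isolation; the channel names are made up for illustration.

```python
# Standalone illustration of the channel handling in the loop above.
# The channel names are hypothetical examples, not taken from the PR.
channel_names = ["Phase3D", "Nuclei", "Membrane"]   # channels present in the dataset
model_channels = ["Nuclei"]                          # channels requested for one model

# Map names to dataset indices (note: no +1 offset, unlike the old Cellpose channel API)
indices = [channel_names.index(c) for c in model_channels]

# Cellpose-SAM takes 3 channels, so repeat the first index until there are three
if len(indices) < 3:
    indices.extend([indices[0]] * (3 - len(indices)))

print(indices)  # [1, 1, 1]
```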
@@ -193,27 +229,29 @@ def segment_cli(
     segmentation_shape = (T, C_segment, Z, Y, X)
 
     # Create a zarr store output to mirror the input
+    # Note, dtype is set to uint32. Change this if one envisions having more than 2^32 labels.
     create_empty_plate(
         store_path=output_dirpath,
         position_keys=[path.parts[-3:] for path in input_position_dirpaths],
-        channel_names=[model_name + "_labels" for model_name in segment_args.keys()],
+        channel_names=[model_name + "_labels" for model_name in settings.models.keys()],
         shape=segmentation_shape,
         chunks=None,
         scale=scale,
+        dtype=np.uint32,
     )
 
     # Estimate resources
-    num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20)
+    num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=10)

Review comments on the resource estimate:
- "This seems like a lot of funny math to estimate the resources. Does"
- "I'd suggest leaving reasonable defaults in the code and using an sbatch file to tune these parameters as needed for specific reconstructions or depending on the cluster usage."
- "This is the same as before. We just don't use the CPUs that much."

     num_gpus = 1
-    slurm_time = np.ceil(np.max([80, T * 2.5])).astype(int)
-    slurm_array_parallelism = 100
+    slurm_time = np.ceil(np.max([600, T * Z * 10])).astype(int)
+    slurm_array_parallelism = 9
     # Prepare SLURM arguments
     slurm_args = {
         "slurm_job_name": "segment",
         "slurm_gres": f"gpu:{num_gpus}",
         "slurm_mem_per_cpu": f"{gb_ram_request}G",
-        "slurm_cpus_per_task": np.max([int(20 * 1.3), num_cpus]),
-        "slurm_array_parallelism": slurm_array_parallelism,  # process up to 20 positions at a time
+        "slurm_cpus_per_task": np.max([int(slurm_array_parallelism * 2), num_cpus]),
+        "slurm_array_parallelism": slurm_array_parallelism,
         "slurm_time": slurm_time,
         "slurm_partition": "gpu",
     }
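
Regarding the review thread above: the arithmetic inside `estimate_resources` is not shown in this diff, so the following back-of-envelope sketch is only an assumption about what such an estimate typically looks like (output size per time point times a safety multiplier), not the project's implementation.

```python
# Hypothetical back-of-envelope RAM estimate; the real estimate_resources may differ.
import numpy as np

T, C_segment, Z, Y, X = 10, 2, 60, 2048, 2048   # example output shape
segmentation_shape = (T, C_segment, Z, Y, X)
ram_multiplier = 10

bytes_per_voxel = np.dtype(np.uint32).itemsize
gb_per_timepoint = np.prod(segmentation_shape[1:]) * bytes_per_voxel / 1e9
gb_ram_request = int(np.ceil(gb_per_timepoint * ram_multiplier))
print(f"requesting ~{gb_ram_request} GB per task")  # ~21 GB for this example
```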
@@ -239,13 +277,13 @@ def segment_cli(
         jobs.append(
             executor.submit(
                 process_single_position,
-                segment_data,
-                input_position_path,
-                output_position_path,
+                func=segment_data,
+                input_position_path=input_position_path,
+                output_position_path=output_position_path,
                 input_channel_indices=[list(range(C))],
                 output_channel_indices=[list(range(C_segment))],
-                num_processes=np.min([20, int(num_cpus * 0.8)]),
-                segmentation_models=segment_args,
+                num_processes=np.min([5, int(num_cpus * 0.8)]),
+                segmentation_models=settings.models,
             )
         )
 
Diff 2: package dependency specification
@@ -61,7 +61,7 @@ track = [
 ]
 
 segment = [
-    "cellpose",
+    "cellpose>=4.0.5",
 ]
 
 build = ["build", "twine"]
Review comments on the output datatype:
- "Is this the output datatype? If it's not stored in the same store as the input, it should always be integers whose width is determined automatically by cellpose. For example, storing uint32 in a float32 container is lossy."
- "The real solution here would be to save integer labels in a separate `labels` array in the zarr store. For now we save everything in one float32 array, which is not ideal. I'd suggest casting cellpose_czyx to float32 and leaving a note that we should fix that later."
- "I meant that if the output store is not the same as the input, it should be integers. There is no reason for a store that only has integer arrays to be initialized with a float32 data type."
- "And since we do a lot of 2D segmentation, they wouldn't be stored with the 3D input."
- "I can default it to uint32."
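
The fix the thread converges on, keeping integer labels in their own array instead of a float32 image container, could look roughly like the sketch below using plain zarr. The store path, array name, and shapes are illustrative, and this uses the zarr-python group API directly rather than the project's OME-Zarr plate helpers.

```python
# Sketch: store uint32 labels in a dedicated array instead of a float32 image container.
# Assumes a zarr-python v2-style group API; names and shapes are illustrative.
import numpy as np
import zarr

root = zarr.open_group("output.zarr", mode="a")

T, C_segment, Z, Y, X = 1, 1, 5, 256, 256
labels = root.create_dataset(
    "labels",
    shape=(T, C_segment, Z, Y, X),
    chunks=(1, 1, Z, Y, X),
    dtype=np.uint32,   # integer labels; no lossy round-trip through float32
    overwrite=True,
)

# Write one segmented CZYX volume (random labels here, purely for illustration)
labels[0, 0] = np.random.randint(0, 100, size=(Z, Y, X), dtype=np.uint32)
```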