Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions biahub/segment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from pathlib import Path

import click
Expand Down Expand Up @@ -83,13 +85,21 @@ def segment_data(
)
czyx_data[c_idx] = func(czyx_data[c_idx], **kwargs)

# Reorder channels based on channels_for_segmentation
cellpose_czyx = np.zeros(
(3, *czyx_data_to_segment.shape[1:]), dtype=czyx_data_to_segment.dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the output datatype? If the output is not stored in the same store as the input, it should always be integers whose width is determined automatically by cellpose. For example, storing uint32 in a float32 container is lossy.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The real solution here would be to save integer labels in a separate labels array in the zarr store. For now we save everything in one float32 array, which is not ideal. I'd suggest casting cellpose_czyx to float32 and leaving a note that we should fix that later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that if the output store is not the same as the input store, it should hold integers. There is no reason for a store that only has integer arrays to be initialized with a float32 data type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And since we do a lot of 2D segmentation, they wouldn't be stored with the 3D input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can default it to uint32.

)
for i, channel in enumerate(model_args.channels):
if channel is not None:
cellpose_czyx[i] = czyx_data_to_segment[channel]

# Apply the segmentation
model = models.CellposeModel(
model_type=model_args.path_to_model, gpu=gpu, device=device
gpu=gpu, device=device, pretrained_model=model_args.path_to_model
)
segmentation, _, _ = model.eval(
czyx_data_to_segment, channel_axis=0, z_axis=1, **model_args.eval_args
) # noqa: python-no-eval
cellpose_czyx, channel_axis=0, z_axis=1, **model_args.eval_args
)
if z_slice_2D is not None and isinstance(z_slice_2D, int):
segmentation = segmentation[np.newaxis, ...]
czyx_segmentation.append(segmentation)
Expand Down Expand Up @@ -147,23 +157,24 @@ def segment_cli(
if model_args.z_slice_2D is not None and isinstance(model_args.z_slice_2D, int):
Z = 1
# Ensure channel names exist in the dataset
if not all(channel in channel_names for channel in model_args.eval_args["channels"]):
if not all(channel in channel_names for channel in model_args.channels):
raise ValueError(
f"Channels {model_args.eval_args['channels']} not found in dataset {channel_names}"
f"Channels {model_args.channels} not found in dataset {channel_names}"
)
# Channel strings to indices with the cellpose offset of 1
model_args.eval_args["channels"] = [
channel_names.index(channel) + 1 for channel in model_args.eval_args["channels"]
]
# NOTE:List of channels, either of length 2 or of length number of images by 2.
# First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
# Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
if len(model_args.eval_args["channels"]) < 2:
model_args.eval_args["channels"].append(0)

click.echo(
f"Segmenting with model {model_name} using channels {model_args.eval_args['channels']}"
)
# Convert channel names to indices for use in cellpose; this conversion is hidden from the user.
model_args.channels = [channel_names.index(channel) for channel in model_args.channels]
# NOTE: Cellpose requires 3 channels. If fewer than 3 channels are listed, the first channel is repeated to fill the remaining slots.

if len(model_args.channels) < 3:
model_args.channels.extend(
[model_args.channels[0]] * (3 - len(model_args.channels))
)
else:
warnings.warn(
f"Model {model_name} has more than 3 channels. Only the first 3 channels will be used."
)

click.echo(f"Segmenting with model {model_name} using channels {model_args.channels}")
if (
"anisotropy" not in model_args.eval_args
or model_args.eval_args["anisotropy"] is None
Expand Down Expand Up @@ -202,7 +213,7 @@ def segment_cli(
# Estimate resources
num_cpus, gb_ram_request = estimate_resources(shape=segmentation_shape, ram_multiplier=20)
num_gpus = 1
slurm_time = np.ceil(np.max([80, T * 2.5])).astype(int)
slurm_time = np.ceil(np.max([120, T * Z * 5])).astype(int)
slurm_array_parallelism = 100
# Prepare SLURM arguments
slurm_args = {
Expand Down Expand Up @@ -236,9 +247,9 @@ def segment_cli(
jobs.append(
executor.submit(
process_single_position,
segment_data,
input_position_path,
output_position_path,
func=segment_data,
input_position_path=input_position_path,
output_position_path=output_position_path,
input_channel_indices=[list(range(C))],
output_channel_indices=[list(range(C_segment))],
num_processes=np.min([20, int(num_cpus * 0.8)]),
Expand Down
9 changes: 6 additions & 3 deletions biahub/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PositiveInt,
field_validator,
model_validator,
validator,
)


Expand Down Expand Up @@ -393,10 +392,12 @@ class PreprocessingFunctions(BaseModel):
class SegmentationModel(BaseModel):
path_to_model: str
eval_args: Dict[str, Any]
channels: list[str]
z_slice_2D: Optional[int] = None
preprocessing: list[PreprocessingFunctions] = []

@validator("eval_args", pre=True)
@field_validator("eval_args")
@classmethod
def validate_eval_args(cls, value):
# Retrieve valid arguments dynamically if cellpose is required
valid_args = get_valid_eval_args()
Expand All @@ -410,7 +411,8 @@ def validate_eval_args(cls, value):

return value

@validator("z_slice_2D")
@field_validator("z_slice_2D")
@classmethod
def check_z_slice_with_do_3D(cls, z_slice_2D, values):
# Only run this check if z_slice is provided (not None) and do_3D exists in eval_args
if z_slice_2D is not None:
Expand All @@ -421,6 +423,7 @@ def check_z_slice_with_do_3D(cls, z_slice_2D, values):
"If 'z_slice_2D' is provided, 'do_3D' in 'eval_args' must be set to False."
)
z_slice_2D = 0

return z_slice_2D


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ dev = [
]

segmentation = [
"cellpose",
"cellpose>=4.0.4",
]

build = ["build", "twine"]
Expand Down
32 changes: 19 additions & 13 deletions settings/example_segmentation_settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,41 @@
models:
# One can instantiate as many models
membrane:
path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)"
# These are the common model.CellposeModel().eval() arguments used, but one can add more.
# For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0
path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model.
eval_args:
diameter: 65
channels: ['mem', 'nuc'] #The channel count for Cellpose starts at 1
cellprob_threshold: 0.4
invert: false
do_3D: false # Optional, if false, 2D segmentation is performed.
do_3D: true # Optional. If false, 2D segmentation is performed; if true, z_slice and channel_axis must be provided.
anisotropy: 3.26
min_size: 8000
z_slice_2D: 10 # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
normalize: {"tile_norm_blocksize": 0} # Optional. If 0, the whole image is used. Cellpose suggests 100-200 if one sees inhomogeneity.
z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
channels: ['mem', 'nuc']
preprocessing:
- function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity
kwargs: {"out_range": [0, 1]}
channel: 'mem'
- function: skimage.exposure.equalize_adapthist
kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]}
channel: 'mem'
# One can instantiate as many models
nucleus:
path_to_model: "/path/to/nucleus/model or name of built-in cellpose model (e.g. cyto3 or nuclei)"
# These are the common model.CellposeModel().eval() arguments used, but one can add more.
# For more information, see https://cellpose.readthedocs.io/en/latest/api.html#id0
path_to_model: 'cpsam' # Default: Cellpose-SAM model. Path to the pretrained model.
eval_args:
diameter: 60
channels: ['nuc'] #For nucleus segmentation, only one channel is required. We populate the other channel with zero.
cellprob_threshold: 0.0
diameter: 55
cellprob_threshold: 0.4
invert: false
do_3D: true
do_3D: true # Optional. If false, 2D segmentation is performed; if true, z_slice and channel_axis must be provided.
anisotropy: 3.26
min_size: 8000
normalize: {"tile_norm_blocksize": 0} # Optional. If 0, the whole image is used. Cellpose suggests 100-200 if one sees inhomogeneity.
z_slice_2D: null # Optional, if null, 3D segmentation is performed and checks eval_args.do_3D=True
channels: ['nuc','mem']
preprocessing:
- function: skimage.exposure.rescale_intensity #configurable callables like rescaling intensity
kwargs: {"out_range": [0, 1]}
channel: 'nuc'
- function: skimage.exposure.equalize_adapthist
kwargs: {"clip_limit": 0.01,"kernel_size":[5, 32, 32]}
channel: 'nuc'