diff --git a/biahub/concatenate.py b/biahub/concatenate.py
index 6a732e3c..c331c2e4 100644
--- a/biahub/concatenate.py
+++ b/biahub/concatenate.py
@@ -238,22 +238,32 @@ def calculate_cropped_size(
 def concatenate(
     settings: ConcatenateSettings,
     output_dirpath: Path,
-    sbatch_filepath: str = None,
+    sbatch_filepath: str | None = None,
     local: bool = False,
+    block: bool = False,
     monitor: bool = True,
 ):
-    """
-    Concatenate datasets (with optional cropping)
-
-    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
+    """Concatenate datasets (with optional cropping).
+
+    Parameters
+    ----------
+    settings : ConcatenateSettings
+        Configuration settings for concatenation
+    output_dirpath : Path
+        Path to the output dataset
+    sbatch_filepath : str | None, optional
+        Path to the SLURM batch file, by default None
+    local : bool, optional
+        Whether to run locally or on a cluster, by default False
+    block : bool, optional
+        Whether to block until all the jobs are complete,
+        by default False
+    monitor : bool, optional
+        Whether to monitor the jobs, by default True
     """
     slurm_out_path = output_dirpath.parent / "slurm_output"
 
-    slicing_params = [
-        settings.Z_slice,
-        settings.Y_slice,
-        settings.X_slice,
-    ]
+    slicing_params = [settings.Z_slice, settings.Y_slice, settings.X_slice]
     (
         all_data_paths,
         all_channel_names,
@@ -334,11 +344,12 @@ def concatenate(
         chunk_size = [1] + list(settings.chunks_czyx)
     else:
         chunk_size = settings.chunks_czyx
-    # Logic for creation of zarr and metadata
     output_metadata = {
         "shape": (len(input_time_indices), len(all_channel_names))
         + tuple(cropped_shape_zyx),
         "chunks": chunk_size,
+        "shards_ratio": settings.shards_ratio,
+        "version": settings.output_ome_zarr_version,
         "scale": (1,) * 2 + tuple(output_voxel_size),
         "channel_names": all_channel_names,
         "dtype": dtype,
@@ -352,8 +363,9 @@ def concatenate(
     )
 
     # Estimate resources
+    batch_size = settings.shards_ratio[0] if settings.shards_ratio else 1
     num_cpus, gb_ram_per_cpu = estimate_resources(
-        shape=[T, C, Z, Y, X], ram_multiplier=16, max_num_cpus=16
+        shape=(T // batch_size, C, Z, Y, X), ram_multiplier=4 * batch_size, max_num_cpus=16
     )
     # Prepare SLURM arguments
     slurm_args = {
@@ -380,8 +392,9 @@ def concatenate(
     executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster)
     executor.update_parameters(**slurm_args)
 
-    click.echo("Submitting SLURM jobs...")
+    click.echo(f"Submitting {cluster} jobs...")
     jobs = []
+
     with submitit.helpers.clean_env(), executor.batch():
         for i, (
             input_position_path,
@@ -424,6 +437,9 @@ def concatenate(
     with log_path.open("w") as log_file:
         log_file.write("\n".join(job_ids))
 
+    if block:
+        _ = [job.result() for job in jobs]
+
     if monitor:
         monitor_jobs(jobs, all_data_paths)
 
@@ -437,14 +453,14 @@ def concatenate(
 def concatenate_cli(
     config_filepath: Path,
     output_dirpath: str,
-    sbatch_filepath: str = None,
+    sbatch_filepath: str | None = None,
     local: bool = False,
     monitor: bool = True,
 ):
     """
     Concatenate datasets (with optional cropping)
 
-    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
+    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr
     """
 
     concatenate(
@@ -452,6 +468,7 @@ def concatenate_cli(
         output_dirpath=Path(output_dirpath),
         sbatch_filepath=sbatch_filepath,
         local=local,
+        block=False,
         monitor=monitor,
     )
 
diff --git a/biahub/settings.py b/biahub/settings.py
index eb505a68..93f635d6 100644
--- a/biahub/settings.py
+++ b/biahub/settings.py
@@ -375,7 +375,9 @@ class ConcatenateSettings(MyBaseModel):
     Y_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
     Z_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
     chunks_czyx: Union[Literal[None], list[int]] = None
+    shards_ratio: list[int] | None = None
     ensure_unique_positions: Optional[bool] = False
+    output_ome_zarr_version: Literal["0.4", "0.5"] = "0.4"
 
     @field_validator("concat_data_paths")
     @classmethod
diff --git a/biahub/tests/test_concatenate.py b/biahub/tests/test_concatenate.py
index 03edfe63..19ad8522 100644
--- a/biahub/tests/test_concatenate.py
+++ b/biahub/tests/test_concatenate.py
@@ -1,3 +1,6 @@
+import numpy as np
+import pytest
+
 from iohub import open_ome_zarr
 
 from biahub.concatenate import concatenate
@@ -204,7 +207,13 @@ def test_concatenate_with_cropping(create_custom_plate, tmp_path, sbatch_file):
     assert output_X == x_end - x_start
 
 
-def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_file):
+@pytest.mark.parametrize(
+    ["version", "shards_ratio_time"],
+    [["0.4", 1], ["0.5", None], ["0.5", 1], ["0.5", 2], ["0.5", 5]],
+)
+def test_concatenate_with_custom_chunks(
+    create_custom_plate, tmp_path, sbatch_file, version, shards_ratio_time
+):
     """
     Test concatenating with custom chunk sizes
     """
@@ -227,13 +236,22 @@ def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_fi
     )
 
     # Define custom chunk sizes
-    custom_chunks = [1, 2, 4, 3]  # [C, Z, Y, X]
+    chunks = [1, 1, 2, 4, 3]  # [T, C, Z, Y, X]
+    if version == "0.5":
+        if shards_ratio_time is None:
+            shards_ratio = None
+        else:
+            shards_ratio = [shards_ratio_time, 1, 1, 2, 2]
+    elif version == "0.4":
+        shards_ratio = None
 
     settings = ConcatenateSettings(
         concat_data_paths=[str(plate_1_path) + "/*/*/*", str(plate_2_path) + "/*/*/*"],
         channel_names=['all', 'all'],
         time_indices='all',
-        chunks_czyx=custom_chunks,
+        chunks_czyx=chunks[1:],
+        shards_ratio=shards_ratio,
+        output_ome_zarr_version=version,
     )
 
     output_path = tmp_path / "output.zarr"
@@ -242,10 +260,21 @@ def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_fi
         output_dirpath=output_path,
         sbatch_filepath=sbatch_file,
         local=True,
+        monitor=False,
+        block=True,
     )
 
-    # We can't easily check the chunks directly, but we can verify the operation completed successfully
     output_plate = open_ome_zarr(output_path)
+    for pos_name, pos in output_plate.positions():
+        assert pos.data.chunks == tuple(chunks)
+        if version == "0.5" and shards_ratio is not None:
+            assert pos.data.shards == tuple(c * s for c, s in zip(chunks, shards_ratio))
+        np.testing.assert_array_equal(
+            pos.data.numpy(),
+            np.concatenate(
+                [plate_1[pos_name].data.numpy(), plate_2[pos_name].data.numpy()], axis=1
+            ),
+        )
 
     # Check that the output plate has all the channels from the input plates
     output_channels = output_plate.channel_names
diff --git a/biahub/track.py b/biahub/track.py
index e3bffb50..f56371be 100644
--- a/biahub/track.py
+++ b/biahub/track.py
@@ -36,6 +36,8 @@
 )
 from biahub.settings import ProcessingInputChannel, TrackingSettings
 
+# Lazy imports for ultrack - imported only when needed in specific functions
+
 
 def mem_nuc_contour(nuclei_prediction: ArrayLike, membrane_prediction: ArrayLike) -> ArrayLike:
     """
diff --git a/pyproject.toml b/pyproject.toml
index 05c435ff..cb26183f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 
 # list package dependencies here
 dependencies = [
-    "iohub>=0.2,<0.3",
+    "iohub>=0.3.0a2,<0.4",
    "matplotlib",
    "napari",
    "PyQt6",
@@ -53,9 +53,11 @@ segment = ["cellpose"]
 
 track = ["ultrack>=0.6.3"]
 
+shard = ["tensorstore"]
+
 build = ["build", "twine"]
 
-all = ["biahub[segment,track,build]"]
+all = ["biahub[segment,track,shard,build]"]
 
 dev = [
     "biahub[all]",
diff --git a/settings/example_concatenate_settings.yml b/settings/example_concatenate_settings.yml
index 045b35e3..476c27b3 100644
--- a/settings/example_concatenate_settings.yml
+++ b/settings/example_concatenate_settings.yml
@@ -4,8 +4,8 @@
 # List of paths to concatenate - can use glob patterns
 # Each path will be treated as a separate input dataset
 concat_data_paths:
-  - "/path/to/data1.zarr/*/*/*" # First dataset
-  - "/path/to/data2.zarr/*/*/*" # Second dataset
+  - "/path/to/data1.zarr/*/*/*"  # First dataset
+  - "/path/to/data2.zarr/*/*/*"  # Second dataset
   # - "/path/to/data3.zarr/A/1/0" # You can also specify exact positions
 
 # Time indices to include in the output
@@ -22,8 +22,8 @@ time_indices: "all"
 # - For multiple datasets, specify channels for each:
 #   [["DAPI"], ["GFP", "RFP"]] - Take DAPI from first dataset, GFP and RFP from second
 channel_names:
-  - "all" # Include all channels from first dataset
-  - "all" # Include all channels from second dataset
+  - "all"  # Include all channels from first dataset
+  - "all"  # Include all channels from second dataset
 
 # Spatial cropping options for X dimension
 # Options:
@@ -55,12 +55,23 @@ Z_slice: "all"
 # - [1, 10, 100, 100]: Specify custom chunk sizes
 chunks_czyx: null
 
+# Number of chunks in a shard for each dimension [T, C, Z, Y, X]
+# Options:
+# - null: No sharding
+# - [1, 1, 4, 8, 8]: Specify custom shards ratio
+shards_ratio: null
+
+# Version of the OME-Zarr format to use for the output
+# Options:
+# - "0.4" (default)
+# - "0.5"
+output_ome_zarr_version: "0.4"
+
 # Whether to ensure unique position names in the output
 # Options:
 # - false or null: Positions with the same name will overwrite each other
 # - true: Ensure unique position names by adding suffixes (e.g., A/1d1/0)
 ensure_unique_positions: null
-
 # EXAMPLE USE CASES:
 # 1. Basic concatenation of all data:
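
Usage sketch (not part of the patch): a minimal example of driving the options introduced above from Python. The dataset paths and shard values are placeholders, and pairing shards_ratio with output_ome_zarr_version="0.5" mirrors the test parametrization in test_concatenate_with_custom_chunks.

# Minimal sketch: concatenate two hypothetical plates into a sharded OME-Zarr 0.5 store.
from pathlib import Path

from biahub.concatenate import concatenate
from biahub.settings import ConcatenateSettings

settings = ConcatenateSettings(
    concat_data_paths=["/path/to/data1.zarr/*/*/*", "/path/to/data2.zarr/*/*/*"],  # placeholder paths
    channel_names=["all", "all"],
    time_indices="all",
    chunks_czyx=[1, 2, 4, 3],  # [C, Z, Y, X] chunk sizes
    shards_ratio=[1, 1, 4, 8, 8],  # chunks per shard along [T, C, Z, Y, X]
    output_ome_zarr_version="0.5",  # shards are exercised with the 0.5 format in the tests
)

concatenate(
    settings=settings,
    output_dirpath=Path("./output_concat.zarr"),
    local=True,  # run locally instead of submitting to SLURM
    block=True,  # new flag: wait for all submitted jobs to finish
    monitor=False,
)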