47 changes: 32 additions & 15 deletions biahub/concatenate.py
@@ -238,22 +238,32 @@ def calculate_cropped_size(
def concatenate(
settings: ConcatenateSettings,
output_dirpath: Path,
sbatch_filepath: str = None,
sbatch_filepath: str | None = None,
local: bool = False,
block: bool = False,
Collaborator:
What was the motivation for including the `block` parameter? Was it useful during testing?

Contributor Author:
When running locally there isn't a good way to check if the jobs (processes) have finished. It is also useful for testing.

monitor: bool = True,
):
"""
Concatenate datasets (with optional cropping)

>> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
"""Concatenate datasets (with optional cropping).

Parameters
----------
settings : ConcatenateSettings
Configuration settings for concatenation
output_dirpath : Path
Path to the output dataset
sbatch_filepath : str | None, optional
Path to the SLURM batch file, by default None
local : bool, optional
Whether to run locally or on a cluster, by default False
block : bool, optional
Whether to block until all the jobs are complete,
by default False
monitor : bool, optional
Whether to monitor the jobs, by default True
"""
slurm_out_path = output_dirpath.parent / "slurm_output"

slicing_params = [
settings.Z_slice,
settings.Y_slice,
settings.X_slice,
]
slicing_params = [settings.Z_slice, settings.Y_slice, settings.X_slice]
(
all_data_paths,
all_channel_names,
@@ -334,11 +344,12 @@ def concatenate(
chunk_size = [1] + list(settings.chunks_czyx)
else:
chunk_size = settings.chunks_czyx

# Logic for creation of zarr and metadata
output_metadata = {
"shape": (len(input_time_indices), len(all_channel_names)) + tuple(cropped_shape_zyx),
"chunks": chunk_size,
"shards_ratio": settings.shards_ratio,
"version": settings.output_ome_zarr_version,
"scale": (1,) * 2 + tuple(output_voxel_size),
"channel_names": all_channel_names,
"dtype": dtype,
@@ -352,8 +363,9 @@ def concatenate(
)

# Estimate resources
batch_size = settings.shards_ratio[0] if settings.shards_ratio else 1
num_cpus, gb_ram_per_cpu = estimate_resources(
shape=[T, C, Z, Y, X], ram_multiplier=16, max_num_cpus=16
shape=(T // batch_size, C, Z, Y, X), ram_multiplier=4 * batch_size, max_num_cpus=16
)
# Prepare SLURM arguments
slurm_args = {
@@ -380,8 +392,9 @@ def concatenate(
executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster)
executor.update_parameters(**slurm_args)

click.echo("Submitting SLURM jobs...")
click.echo(f"Submitting {cluster} jobs...")
jobs = []

with submitit.helpers.clean_env(), executor.batch():
for i, (
input_position_path,
Expand Down Expand Up @@ -424,6 +437,9 @@ def concatenate(
with log_path.open("w") as log_file:
log_file.write("\n".join(job_ids))

if block:
_ = [job.result() for job in jobs]

if monitor:
monitor_jobs(jobs, all_data_paths)
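
For context on the new `block` flag: submitit's `Job.result()` blocks until the underlying job (a local process when `cluster="local"`) has finished, which is what makes the wait work for local runs. A minimal sketch with a hypothetical toy task, not this PR's code:

```python
import submitit

def square(x: int) -> int:  # hypothetical toy task
    return x * x

# With cluster="local", each task runs as a separate local process,
# mirroring the local=True path in concatenate().
executor = submitit.AutoExecutor(folder="slurm_output", cluster="local")
executor.update_parameters(timeout_min=5)

jobs = [executor.submit(square, i) for i in range(4)]

# result() blocks until each job has finished; this is what block=True relies on.
print([job.result() for job in jobs])  # [0, 1, 4, 9]
```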

@@ -437,21 +453,22 @@ def concatenate(
def concatenate_cli(
config_filepath: Path,
output_dirpath: str,
sbatch_filepath: str = None,
sbatch_filepath: str | None = None,
local: bool = False,
monitor: bool = True,
):
"""
Concatenate datasets (with optional cropping)

>> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
>> biahub concatenate -c ./concat.yml -o ./output_concat.zarr
"""

concatenate(
settings=yaml_to_model(config_filepath, ConcatenateSettings),
output_dirpath=Path(output_dirpath),
sbatch_filepath=sbatch_filepath,
local=local,
block=False,
monitor=monitor,
)

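A sketch of driving the updated function directly from Python, mirroring the new test; the paths, chunk sizes, and shard ratio below are illustrative, and `ConcatenateSettings` is assumed to be importable from `biahub.settings` as in the settings diff below:

```python
from pathlib import Path

from biahub.concatenate import concatenate
from biahub.settings import ConcatenateSettings

settings = ConcatenateSettings(
    concat_data_paths=["/path/to/data1.zarr/*/*/*", "/path/to/data2.zarr/*/*/*"],
    channel_names=["all", "all"],
    time_indices="all",
    chunks_czyx=[1, 2, 4, 3],       # C, Z, Y, X
    shards_ratio=[1, 1, 1, 2, 2],   # chunks per shard along T, C, Z, Y, X
    output_ome_zarr_version="0.5",
)

concatenate(
    settings=settings,
    output_dirpath=Path("./output_concat.zarr"),
    local=True,     # run as local processes instead of submitting to SLURM
    block=True,     # wait for all jobs to finish before returning
    monitor=False,  # skip job monitoring
)
```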
2 changes: 2 additions & 0 deletions biahub/settings.py
@@ -375,7 +375,9 @@ class ConcatenateSettings(MyBaseModel):
Y_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
Z_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
chunks_czyx: Union[Literal[None], list[int]] = None
shards_ratio: list[int] | None = None
ensure_unique_positions: Optional[bool] = False
output_ome_zarr_version: Literal["0.4", "0.5"] = "0.4"

@field_validator("concat_data_paths")
@classmethod
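The new `shards_ratio` field gives the number of chunks per shard along [T, C, Z, Y, X]; the shard shape is the elementwise product of chunk shape and ratio, as the updated test asserts. A small illustration with made-up values:

```python
chunks = (1, 1, 2, 4, 3)        # chunk shape (T, C, Z, Y, X)
shards_ratio = (2, 1, 1, 2, 2)  # chunks per shard along each axis

# Shard shape = chunk shape * shards_ratio, elementwise
shards = tuple(c * s for c, s in zip(chunks, shards_ratio))
print(shards)  # (2, 1, 2, 8, 6)
```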
37 changes: 33 additions & 4 deletions biahub/tests/test_concatenate.py
@@ -1,3 +1,6 @@
import numpy as np
import pytest

from iohub import open_ome_zarr

from biahub.concatenate import concatenate
@@ -204,7 +207,13 @@ def test_concatenate_with_cropping(create_custom_plate, tmp_path, sbatch_file):
assert output_X == x_end - x_start


def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_file):
@pytest.mark.parametrize(
["version", "shards_ratio_time"],
[["0.4", 1], ["0.5", None], ["0.5", 1], ["0.5", 2], ["0.5", 5]],
)
def test_concatenate_with_custom_chunks(
create_custom_plate, tmp_path, sbatch_file, version, shards_ratio_time
):
"""
Test concatenating with custom chunk sizes
"""
@@ -227,13 +236,22 @@ def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_fi
)

# Define custom chunk sizes
custom_chunks = [1, 2, 4, 3] # [C, Z, Y, X]
chunks = [1, 1, 2, 4, 3]  # [T, C, Z, Y, X]
if version == "0.5":
if shards_ratio_time is None:
shards_ratio = None
else:
shards_ratio = [shards_ratio_time, 1, 1, 2, 2]
elif version == "0.4":
shards_ratio = None

settings = ConcatenateSettings(
concat_data_paths=[str(plate_1_path) + "/*/*/*", str(plate_2_path) + "/*/*/*"],
channel_names=['all', 'all'],
time_indices='all',
chunks_czyx=custom_chunks,
chunks_czyx=chunks[1:],
shards_ratio=shards_ratio,
output_ome_zarr_version=version,
)

output_path = tmp_path / "output.zarr"
@@ -242,10 +260,21 @@ def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_fi
output_dirpath=output_path,
sbatch_filepath=sbatch_file,
local=True,
monitor=False,
block=True,
)

# We can't easily check the chunks directly, but we can verify the operation completed successfully
output_plate = open_ome_zarr(output_path)
for pos_name, pos in output_plate.positions():
assert pos.data.chunks == tuple(chunks)
if version == "0.5" and shards_ratio is not None:
assert pos.data.shards == tuple(c * s for c, s in zip(chunks, shards_ratio))
np.testing.assert_array_equal(
pos.data.numpy(),
np.concatenate(
[plate_1[pos_name].data.numpy(), plate_2[pos_name].data.numpy()], axis=1
),
)

# Check that the output plate has all the channels from the input plates
output_channels = output_plate.channel_names
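The same verification pattern from the test can be run standalone against any output store; the path is illustrative, and the `.shards` attribute is assumed to be exposed by iohub >= 0.3, as the test relies on:

```python
from iohub import open_ome_zarr

with open_ome_zarr("./output_concat.zarr", mode="r") as plate:
    for name, pos in plate.positions():
        # Chunk and shard layout of each position's TCZYX array
        print(name, pos.data.shape, pos.data.chunks, pos.data.shards)
```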
2 changes: 2 additions & 0 deletions biahub/track.py
@@ -36,6 +36,8 @@
)
from biahub.settings import ProcessingInputChannel, TrackingSettings

# Lazy imports for ultrack - imported only when needed in specific functions


def mem_nuc_contour(nuclei_prediction: ArrayLike, membrane_prediction: ArrayLike) -> ArrayLike:
"""
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [

# list package dependencies here
dependencies = [
"iohub>=0.2,<0.3",
"iohub>=0.3.0a2,<0.4",
"matplotlib",
"napari",
"PyQt6",
@@ -53,9 +53,11 @@ segment = ["cellpose"]

track = ["ultrack>=0.6.3"]

shard = ["tensorstore"]

build = ["build", "twine"]

all = ["biahub[segment,track,build]"]
all = ["biahub[segment,track,shard,build]"]

dev = [
"biahub[all]",
21 changes: 16 additions & 5 deletions settings/example_concatenate_settings.yml
@@ -4,8 +4,8 @@
# List of paths to concatenate - can use glob patterns
# Each path will be treated as a separate input dataset
concat_data_paths:
- "/path/to/data1.zarr/*/*/*" # First dataset
- "/path/to/data2.zarr/*/*/*" # Second dataset
- "/path/to/data1.zarr/*/*/*" # First dataset
- "/path/to/data2.zarr/*/*/*" # Second dataset
# - "/path/to/data3.zarr/A/1/0" # You can also specify exact positions

# Time indices to include in the output
@@ -22,8 +22,8 @@ time_indices: "all"
# - For multiple datasets, specify channels for each:
# [["DAPI"], ["GFP", "RFP"]] - Take DAPI from first dataset, GFP and RFP from second
channel_names:
- "all" # Include all channels from first dataset
- "all" # Include all channels from second dataset
- "all" # Include all channels from first dataset
- "all" # Include all channels from second dataset

# Spatial cropping options for X dimension
# Options:
@@ -55,12 +55,23 @@ Z_slice: "all"
# - [1, 10, 100, 100]: Specify custom chunk sizes
chunks_czyx: null

# Number of chunks in a shard for each dimension [T, C, Z, Y, X]
# Options:
# - null: No sharding
# - [1, 1, 4, 8, 8]: Specify custom shards ratio
shards_ratio: null

# Version of the OME-Zarr format to use for the output
# Options:
# - "0.4" (default)
# - "0.5"
output_ome_zarr_version: "0.4"

# Whether to ensure unique position names in the output
# Options:
# - false or null: Positions with the same name will overwrite each other
# - true: Ensure unique position names by adding suffixes (e.g., A/1d1/0)
ensure_unique_positions: null

# EXAMPLE USE CASES:

# 1. Basic concatenation of all data: