5 changes: 5 additions & 0 deletions biahub/cli/main.py
@@ -79,6 +79,11 @@ def format_options(self, ctx, formatter):
"import_path": "biahub.optimize_registration.optimize_registration_cli",
"help": "Optimize transform based on match filtering",
},
{
"name": "pyramid",
"import_path": "biahub.pyramid.pyramid_cli",
"help": "Create pyramid levels for a dataset",
},
{
"name": "register",
"import_path": "biahub.register.register_cli",
149 changes: 149 additions & 0 deletions biahub/pyramid.py
@@ -0,0 +1,149 @@
import datetime

from pathlib import Path
from typing import List, Optional

import click
import submitit
import tensorstore as ts

from iohub.ngff import open_ome_zarr

from biahub.cli.parsing import (
input_position_dirpaths,
local,
sbatch_filepath,
sbatch_to_submitit,
)


def pyramid(fov_path: Path, levels: int, method: str) -> None:
"""
Create pyramid levels for a single field of view using tensorstore downsampling.

This function uses cascade downsampling, where each level is downsampled from
the previous level rather than from level 0. This avoids aliasing artifacts
and chunk boundary issues that occur with large downsample factors.
Collaborator Author:
Could this function move into iohub, say as a method of Position?

Collaborator:
@ieivanov Yes we can do that.


Parameters
----------
fov_path : Path
Path to the FOV position directory
levels : int
Number of downsampling levels to create
method : str
Downsampling method (e.g., 'mean', 'max', 'min')
"""
with open_ome_zarr(fov_path, mode="r+") as dataset:
dataset.initialize_pyramid(levels=levels)

for level in range(1, levels):
previous_level = dataset[str(level - 1)].tensorstore()

current_scale = dataset.get_effective_scale(str(level))
previous_scale = dataset.get_effective_scale(str(level - 1))
downsample_factors = [
int(round(current_scale[i] / previous_scale[i]))
for i in range(len(current_scale))
]

downsampled = ts.downsample(
previous_level, downsample_factors=downsample_factors, method=method
)

target_store = dataset[str(level)].tensorstore()
target_store[:].write(downsampled[:].read().result()).result()

@ivirshup:
Just as a heads up, I was having some pretty bad memory issues when I was generating the pyramid like this. Memory would spike massively near the end of writing. Just putting some sort of loop on the writes seemed to cap it. Here's what my code looks like here:

        step = dst_ts.chunk_layout.write_chunk.shape[0]
        for start in range(0, downsampled_ts.shape[0], step):
            stop = min(start + step, downsampled_ts.shape[0])
            dst_ts[start:stop].write(downsampled_ts[start:stop]).result()

Also, would downsampled[:].read().result() read the entire source level into memory?

@ziw-liu (Contributor):
I agree, tensorstore is pretty bad with memory efficiency. You probably need some looping that is chunk-aligned here, as @ivirshup suggested. Also putting them in a transaction will help tensorstore schedule/consolidate the I/O.
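
A minimal sketch of what the chunk-aligned loop plus a transaction could look like at this spot (the names src, dst, and chunks_per_txn are illustrative, and the exact Transaction usage is an assumption, not code from this PR):

    import tensorstore as ts

    def write_in_chunk_aligned_batches(
        src: ts.TensorStore, dst: ts.TensorStore, chunks_per_txn: int = 1
    ) -> None:
        # Copy src into dst one group of write-chunk-sized slabs at a time
        # along the first axis, so only that slab of the (lazy) downsampled
        # view is materialized before each commit.
        step = dst.chunk_layout.write_chunk.shape[0] * chunks_per_txn
        for start in range(0, dst.shape[0], step):
            stop = min(start + step, dst.shape[0])
            txn = ts.Transaction()
            dst.with_transaction(txn)[start:stop].write(src[start:stop]).result()
            txn.commit_sync()

A transaction buffers its writes until commit, so a larger chunks_per_txn lets tensorstore consolidate I/O per commit at the cost of holding more data in memory, which is consistent with the batch-size observation further down.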


Collaborator:
@ivirshup @ziw-liu Interesting, I'll take a look at chunk-aligned looping and transactions.

@ivirshup:
@ziw-liu I get the impression tensorstore can be great with memory, just not from the highest-level API, and there are a lot of foot guns 😆

> Also putting them in a transaction will help tensorstore schedule/consolidate the I/O.

I had tried this, but did not see an improvement. Does this work well in your hands? I would love a pointer to some code where this worked so I can give it another shot.

@ziw-liu (Contributor), Oct 20, 2025:
I haven't done benchmarks on transactions myself, but @JoOkuma did at some point.

@srivarra (Collaborator), Oct 20, 2025:
@ivirshup @ziw-liu I've combined writing with TensorStore transactions and looping over each chunk when writing. It looks like we can batch groups of chunks together and have the transaction optimize each batch. But even as we add more transactions, the peak memory usage always seems to increase.

[attached figure: batch_size_comparison]
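
A rough way to put numbers on that kind of comparison, assuming the Linux resource module (the helper name and the sweep are illustrative; ru_maxrss is per-process and only ever grows, so each batch size would need a fresh process):

    import resource

    def peak_rss_gib() -> float:
        # ru_maxrss is reported in KiB on Linux; convert to GiB.
        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 2**20

    # Hypothetical sweep over transaction batch sizes:
    # for chunks_per_txn in (1, 2, 4, 8):
    #     write_in_chunk_aligned_batches(src, dst, chunks_per_txn)
    #     print(chunks_per_txn, peak_rss_gib())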


click.echo(f"Completed pyramid for FOV: {fov_path}")


@click.command("pyramid")
@input_position_dirpaths()
@sbatch_filepath()
@local()
@click.option(
"--levels",
"-lv",
type=int,
default=4,
show_default=True,
help="Number of downsampling levels to create.",
)
@click.option(
"--method",
"-m",
type=click.Choice(
[
"stride",
"median",
"mode",
"mean",
"min",
"max",
]
),
default="mean",
show_default=True,
help="Downsampling method to use.",
)
def pyramid_cli(
input_position_dirpaths: List[Path],
levels: int,
method: str,
sbatch_filepath: Optional[Path],
local: bool,
) -> None:
"""
Creates additional levels of multi-scale pyramids for OME-Zarr datasets.

Uses tensorstore-based cascade downsampling to create pyramid levels
in parallel. Each field-of-view (FOV) is processed as a separate SLURM job,
downsampling all timepoints and channels. The pyramids are created in-place
within the input zarr store using the specified downsampling method (default: 'mean').

Example:
biahub pyramid -i ./data.zarr/0/0/0 -lv 4 --method max
biahub pyramid -i ./data.zarr/*/*/* --levels 5 --local
"""
cluster = "local" if local else "slurm"

slurm_args = {
"slurm_job_name": "pyramid",
"slurm_partition": "preempted",
"slurm_cpus_per_task": 16,
"slurm_mem_per_cpu": "8G",
"slurm_time": 30,
"slurm_array_parallelism": 100,
}

# Override with sbatch file parameters if provided
if sbatch_filepath:
slurm_args.update(sbatch_to_submitit(sbatch_filepath))

slurm_out_path = Path("slurm_output")
slurm_out_path.mkdir(exist_ok=True)

executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster)
executor.update_parameters(**slurm_args)

click.echo(
f"Submitting {len(input_position_dirpaths)} pyramid jobs with resources: {slurm_args}"
)

jobs = []
with submitit.helpers.clean_env(), executor.batch():
for fov_path in input_position_dirpaths:
job = executor.submit(pyramid, fov_path=fov_path, levels=levels, method=method)
jobs.append(job)

job_ids = [job.job_id for job in jobs]
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_path = slurm_out_path / f"pyramid-jobs_{timestamp}.log"
with log_path.open("w") as log_file:
log_file.write("\n".join(job_ids))

# wait_for_jobs_to_finish(jobs)


if __name__ == "__main__":
pyramid_cli()