Coalesce and parallelize partial shard reads #3004

Draft · wants to merge 6 commits into base: main
3 changes: 3 additions & 0 deletions changes/3004.feature.rst
@@ -0,0 +1,3 @@
Optimizes reading more than one, but not all, of the chunks in a shard. Chunks are now read in parallel,
and reads of nearby chunks within the same shard are coalesced to reduce the number of calls to the store.
See :ref:`user-guide-config` for more details.
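
For context, a read pattern that benefits from this change looks like the following sketch. It is illustrative only and not part of the PR; the store path, shapes, and chunk/shard sizes are made up, and it assumes the zarr-python 3 create_array API with the shards argument:

    import numpy as np
    import zarr

    # Each 1000x1000 shard holds a 10x10 grid of 100x100 chunks.
    z = zarr.create_array(
        store="data/example.zarr",  # hypothetical local path
        shape=(2_000, 2_000),
        chunks=(100, 100),
        shards=(1_000, 1_000),
        dtype="float32",
    )
    z[:] = np.random.default_rng(0).random((2_000, 2_000), dtype=np.float32)

    # This selection touches 16 of the 100 chunks in the first shard.
    # With this PR, those chunk reads are issued concurrently and nearby
    # byte ranges are coalesced into fewer store requests.
    block = z[100:450, 100:450]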
8 changes: 7 additions & 1 deletion docs/user-guide/config.rst
@@ -33,6 +33,10 @@ Configuration options include the following:
- Async and threading options, e.g. ``async.concurrency`` and ``threading.max_workers``
- Selections of implementations of codecs, codec pipelines and buffers
- Enabling GPU support with ``zarr.config.enable_gpu()``. See :ref:`user-guide-gpu` for more.
- Tuning reads from sharded arrays. When reading only part of a shard, reads of nearby chunks
  within the same shard are combined into a single request if they are less than
  ``sharding.read.coalesce_max_gap_bytes`` apart and the combined request is smaller than
  ``sharding.read.coalesce_max_bytes``.

For selecting custom implementations of codecs, pipelines, buffers and ndbuffers,
first register the implementations in the registry and then select them in the config.
@@ -61,4 +65,6 @@ This is the current default configuration::
'default_zarr_format': 3,
'json_indent': 2,
'ndbuffer': 'zarr.buffer.cpu.NDBuffer',
'threading': {'max_workers': None}}
'sharding': {'read': {'coalesce_max_bytes': 104857600,
'coalesce_max_gap_bytes': 1048576}},
'threading': {'max_workers': None}}
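
The coalescing thresholds documented above are ordinary config entries and can be tuned at runtime. A minimal sketch, assuming zarr's donfig-based zarr.config object and the defaults added by this PR; the override values shown are arbitrary:

    import zarr

    # Defaults added by this PR: 100 MiB max coalesced request, 1 MiB max gap.
    print(zarr.config.get("sharding.read.coalesce_max_bytes"))
    print(zarr.config.get("sharding.read.coalesce_max_gap_bytes"))

    # Allow larger gaps (e.g. for high-latency object stores) and cap request size lower.
    zarr.config.set(
        {
            "sharding.read.coalesce_max_gap_bytes": 4 * 2**20,  # 4 MiB
            "sharding.read.coalesce_max_bytes": 64 * 2**20,     # 64 MiB
        }
    )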
1 change: 1 addition & 0 deletions pyproject.toml
@@ -352,6 +352,7 @@ ignore = [
[tool.mypy]
python_version = "3.11"
ignore_missing_imports = true
mypy_path = "src"
Contributor Author:
I needed to do this for the mypy run in pre-commit to succeed when it was running on tests/test_config.py. Not sure if we want this at all, or if it should go in its own PR.

namespace_packages = false

strict = true
176 changes: 147 additions & 29 deletions src/zarr/codecs/sharding.py
@@ -38,11 +38,13 @@
from zarr.core.common import (
ChunkCoords,
ChunkCoordsLike,
concurrent_map,
parse_enum,
parse_named_configuration,
parse_shapelike,
product,
)
from zarr.core.config import config
from zarr.core.dtype.npy.int import UInt64
from zarr.core.indexing import (
BasicIndexer,
@@ -198,7 +200,9 @@ async def from_bytes(

@classmethod
def create_empty(
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
cls,
chunks_per_shard: ChunkCoords,
buffer_prototype: BufferPrototype | None = None,
Contributor Author:
Was there a ruff format version upgrade? I'm leaving these formatting changes in since they were produced by the pre-commit run.

) -> _ShardReader:
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
@@ -248,7 +252,9 @@ def merge_with_morton_order(

@classmethod
def create_empty(
cls, chunks_per_shard: ChunkCoords, buffer_prototype: BufferPrototype | None = None
cls,
chunks_per_shard: ChunkCoords,
buffer_prototype: BufferPrototype | None = None,
) -> _ShardBuilder:
if buffer_prototype is None:
buffer_prototype = default_buffer_prototype()
@@ -329,9 +335,18 @@ async def finalize(
return await shard_builder.finalize(index_location, index_encoder)


class _ChunkCoordsByteSlice(NamedTuple):
"""Holds a chunk's coordinates and its byte range in a serialized shard."""

coords: ChunkCoords
byte_slice: slice


@dataclass(frozen=True)
class ShardingCodec(
ArrayBytesCodec, ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
ArrayBytesCodec,
ArrayBytesCodecPartialDecodeMixin,
ArrayBytesCodecPartialEncodeMixin,
):
chunk_shape: ChunkCoords
codecs: tuple[Codec, ...]
@@ -454,7 +469,7 @@ async def _decode_single(
shape=shard_shape,
dtype=shard_spec.dtype.to_native_dtype(),
order=shard_spec.order,
fill_value=0,
fill_value=shard_spec.fill_value,
)
shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)

@@ -501,39 +516,28 @@ async def _decode_partial_single(
shape=indexer.shape,
dtype=shard_spec.dtype.to_native_dtype(),
order=shard_spec.order,
fill_value=0,
fill_value=shard_spec.fill_value,
)

indexed_chunks = list(indexer)
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}

# reading bytes of all requested chunks
shard_dict: ShardMapping = {}
shard_dict_maybe: ShardMapping | None = {}
if self._is_total_shard(all_chunk_coords, chunks_per_shard):
# read entire shard
shard_dict_maybe = await self._load_full_shard_maybe(
byte_getter=byte_getter,
prototype=chunk_spec.prototype,
chunks_per_shard=chunks_per_shard,
byte_getter, chunk_spec.prototype, chunks_per_shard
)
if shard_dict_maybe is None:
return None
shard_dict = shard_dict_maybe
else:
# read some chunks within the shard
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
if shard_index is None:
return None
shard_dict = {}
for chunk_coords in all_chunk_coords:
chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
if chunk_byte_slice:
chunk_bytes = await byte_getter.get(
prototype=chunk_spec.prototype,
byte_range=RangeByteRequest(chunk_byte_slice[0], chunk_byte_slice[1]),
)
if chunk_bytes:
shard_dict[chunk_coords] = chunk_bytes
shard_dict_maybe = await self._load_partial_shard_maybe(
byte_getter, chunk_spec.prototype, chunks_per_shard, all_chunk_coords
)

if shard_dict_maybe is None:
return None
shard_dict = shard_dict_maybe
Comment on lines -519 to +540
Contributor Author:
Here's where the non-formatting changes start in this file


# decoding chunks and writing them into the output buffer
await self.codec_pipeline.read(
@@ -615,7 +619,9 @@ async def _encode_partial_single(

indexer = list(
get_indexer(
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
selection,
shape=shard_shape,
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
)
)

@@ -689,7 +695,8 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
get_pipeline_class()
.from_codecs(self.index_codecs)
.compute_encoded_size(
16 * product(chunks_per_shard), self._get_index_chunk_spec(chunks_per_shard)
16 * product(chunks_per_shard),
self._get_index_chunk_spec(chunks_per_shard),
)
)

@@ -734,7 +741,8 @@ async def _load_shard_index_maybe(
)
else:
index_bytes = await byte_getter.get(
prototype=numpy_buffer_prototype(), byte_range=SuffixByteRequest(shard_index_size)
prototype=numpy_buffer_prototype(),
byte_range=SuffixByteRequest(shard_index_size),
)
if index_bytes is not None:
return await self._decode_shard_index(index_bytes, chunks_per_shard)
@@ -748,7 +756,10 @@ async def _load_shard_index(
) or _ShardIndex.create_empty(chunks_per_shard)

async def _load_full_shard_maybe(
self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: ChunkCoords,
) -> _ShardReader | None:
shard_bytes = await byte_getter.get(prototype=prototype)

@@ -758,6 +769,113 @@ async def _load_full_shard_maybe(
else None
)

async def _load_partial_shard_maybe(
self,
byte_getter: ByteGetter,
prototype: BufferPrototype,
chunks_per_shard: ChunkCoords,
all_chunk_coords: set[ChunkCoords],
) -> ShardMapping | None:
"""
Read chunks from `byte_getter` for the case where the read is less than a full shard.
Returns a mapping of chunk coordinates to bytes.
"""
shard_index = await self._load_shard_index_maybe(byte_getter, chunks_per_shard)
if shard_index is None:
return None

chunks = [
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
for chunk_coords in all_chunk_coords
# Drop chunks where index lookup fails
# e.g. when write_empty_chunks = False and the chunk is empty
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
]
if len(chunks) == 0:
return None

groups = self._coalesce_chunks(chunks)

shard_dicts = await concurrent_map(
[(group, byte_getter, prototype) for group in groups],
self._get_group_bytes,
config.get("async.concurrency"),
)

shard_dict: ShardMutableMapping = {}
for d in shard_dicts:
if d is None:
return None
shard_dict.update(d)

return shard_dict

def _coalesce_chunks(
self,
chunks: list[_ChunkCoordsByteSlice],
) -> list[list[_ChunkCoordsByteSlice]]:
"""
Combine chunks from a single shard into groups that should be read together
in a single request.

Respects the following configuration options:
- `sharding.read.coalesce_max_gap_bytes`: The maximum gap between
chunks to coalesce into a single group.
- `sharding.read.coalesce_max_bytes`: The maximum number of bytes in a group.
"""
max_gap_bytes = config.get("sharding.read.coalesce_max_gap_bytes")
coalesce_max_bytes = config.get("sharding.read.coalesce_max_bytes")

sorted_chunks = sorted(chunks, key=lambda c: c.byte_slice.start)

groups = []
current_group = [sorted_chunks[0]]

for chunk in sorted_chunks[1:]:
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
current_group.append(chunk)
else:
groups.append(current_group)
current_group = [chunk]

groups.append(current_group)

return groups

async def _get_group_bytes(
self,
group: list[_ChunkCoordsByteSlice],
byte_getter: ByteGetter,
prototype: BufferPrototype,
) -> ShardMapping | None:
"""
Reads a possibly coalesced group of one or more chunks from a shard.
Returns a mapping of chunk coordinates to bytes.
"""
group_start = group[0].byte_slice.start
group_end = group[-1].byte_slice.stop

# A single call to retrieve the bytes for the entire group.
group_bytes = await byte_getter.get(
prototype=prototype,
byte_range=RangeByteRequest(group_start, group_end),
)
if group_bytes is None:
return None

# Extract the bytes corresponding to each chunk in group from group_bytes.
shard_dict = {}
for chunk in group:
chunk_slice = slice(
chunk.byte_slice.start - group_start,
chunk.byte_slice.stop - group_start,
)
shard_dict[chunk.coords] = group_bytes[chunk_slice]

return shard_dict

def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
return input_byte_length + self._shard_index_size(chunks_per_shard)
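To make the grouping rule in _coalesce_chunks concrete, here is a small self-contained sketch of the same greedy pass. It is not part of the diff; the Chunk type, byte offsets, and thresholds are made up for illustration:

    from typing import NamedTuple

    class Chunk(NamedTuple):
        coords: tuple[int, ...]
        byte_slice: slice  # start/stop offsets within the serialized shard

    def coalesce(chunks: list[Chunk], max_gap: int, max_bytes: int) -> list[list[Chunk]]:
        # Sort by start offset, then greedily extend the current group while the
        # gap to the next chunk and the coalesced request size stay under the limits.
        ordered = sorted(chunks, key=lambda c: c.byte_slice.start)
        groups = [[ordered[0]]]
        for chunk in ordered[1:]:
            gap = chunk.byte_slice.start - groups[-1][-1].byte_slice.stop
            size = chunk.byte_slice.stop - groups[-1][0].byte_slice.start
            if gap < max_gap and size < max_bytes:
                groups[-1].append(chunk)
            else:
                groups.append([chunk])
        return groups

    # Three chunks separated by 100-byte gaps coalesce into one range request;
    # the distant fourth chunk is fetched on its own.
    chunks = [
        Chunk((0, 0), slice(0, 1_000)),
        Chunk((0, 1), slice(1_100, 2_100)),
        Chunk((0, 2), slice(2_200, 3_200)),
        Chunk((5, 5), slice(9_000_000, 9_001_000)),
    ]
    for group in coalesce(chunks, max_gap=2**20, max_bytes=100 * 2**20):
        print([c.coords for c in group])

Each resulting group then corresponds to a single RangeByteRequest spanning from the first chunk's start to the last chunk's stop, as in _get_group_bytes above.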
6 changes: 6 additions & 0 deletions src/zarr/core/config.py
@@ -109,6 +109,12 @@ def enable_gpu(self) -> ConfigSet:
},
"async": {"concurrency": 10, "timeout": None},
"threading": {"max_workers": None},
"sharding": {
"read": {
"coalesce_max_bytes": 100 * 2**20, # 100MiB
"coalesce_max_gap_bytes": 2**20, # 1MiB
}
},
"json_indent": 2,
"codec_pipeline": {
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",
2 changes: 1 addition & 1 deletion src/zarr/core/indexing.py
@@ -1193,7 +1193,7 @@ def __iter__(self) -> Iterator[ChunkProjection]:
stop = self.chunk_nitems_cumsum[chunk_rix]
out_selection: slice | npt.NDArray[np.intp]
if self.sel_sort is None:
out_selection = slice(start, stop)
out_selection = np.arange(start, stop)
else:
out_selection = self.sel_sort[start:stop]
