Skip to content

Commit de7434a

Browse files
committed
Add
1 parent 7cfcf5e commit de7434a

File tree

3 files changed

+45
-60
lines changed

3 files changed

+45
-60
lines changed

src/zarr/core/array.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -196,42 +196,28 @@ def create_codec_pipeline(metadata: ArrayMetadata) -> CodecPipeline:
196196
raise TypeError # pragma: no cover
197197

198198

199-
def is_concurrency_safe(store: Store) -> bool:
200-
fs = getattr(store, "fs", None)
201-
return getattr(fs, "asynchronous", True)
202-
203-
204199
async def get_array_metadata(
205200
store_path: StorePath, zarr_format: ZarrFormat | None = 3
206201
) -> dict[str, JSON]:
207-
concurrency_safe = is_concurrency_safe(store_path.store)
208-
209202
if zarr_format == 2:
210-
if concurrency_safe:
211-
zarray_bytes, zattrs_bytes = await gather(
212-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
213-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
214-
)
215-
else:
216-
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
217-
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
203+
zarray_bytes, zattrs_bytes = await store_path.get_many(
204+
ZARRAY_JSON,
205+
ZATTRS_JSON,
206+
prototype=cpu_buffer_prototype,
207+
)
218208
if zarray_bytes is None:
219209
raise FileNotFoundError(store_path)
220210
elif zarr_format == 3:
221211
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
222212
if zarr_json_bytes is None:
223213
raise FileNotFoundError(store_path)
224214
elif zarr_format is None:
225-
if concurrency_safe:
226-
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
227-
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
228-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
229-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
230-
)
231-
else:
232-
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
233-
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
234-
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
215+
zarr_json_bytes, zarray_bytes, zattrs_bytes = await store_path.get_many(
216+
ZARR_JSON,
217+
ZARRAY_JSON,
218+
ZATTRS_JSON,
219+
prototype=cpu_buffer_prototype,
220+
)
235221
if zarr_json_bytes is not None and zarray_bytes is not None:
236222
# warn and favor v3
237223
msg = f"Both zarr.json (Zarr format 3) and .zarray (Zarr format 2) metadata objects exist at {store_path}. Zarr v3 will be used."
@@ -1441,12 +1427,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14411427
).items()
14421428
]
14431429
)
1444-
1445-
if is_concurrency_safe(self.store_path.store):
1446-
await gather(*awaitables)
1447-
else:
1448-
for awaitable in awaitables:
1449-
await awaitable
1430+
await gather(*awaitables)
14501431

14511432
async def _set_selection(
14521433
self,

src/zarr/core/group.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
_build_parents,
3131
_parse_deprecated_compressor,
3232
create_array,
33-
is_concurrency_safe,
3433
)
3534
from zarr.core.attributes import Attributes
3635
from zarr.core.buffer import default_buffer_prototype
@@ -534,23 +533,17 @@ async def open(
534533
if zarr_json_bytes is None:
535534
raise FileNotFoundError(store_path)
536535
elif zarr_format is None:
537-
paths = [
538-
(store_path / ZARR_JSON).get(),
539-
(store_path / ZGROUP_JSON).get(),
540-
(store_path / ZATTRS_JSON).get(),
541-
(store_path / consolidated_key).get(),
542-
]
543536
(
544537
zarr_json_bytes,
545538
zgroup_bytes,
546539
zattrs_bytes,
547540
maybe_consolidated_metadata_bytes,
548-
) = (
549-
await asyncio.gather(*paths)
550-
if is_concurrency_safe(store_path.store)
551-
else [await path for path in paths]
541+
) = await store_path.get_many(
542+
ZARR_JSON,
543+
ZGROUP_JSON,
544+
ZATTRS_JSON,
545+
consolidated_key,
552546
)
553-
554547
if zarr_json_bytes is not None and zgroup_bytes is not None:
555548
# warn and favor v3
556549
msg = f"Both zarr.json (Zarr format 3) and .zgroup (Zarr format 2) metadata objects exist at {store_path}. Zarr format 3 will be used."
@@ -3461,23 +3454,11 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
34613454
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
34623455
# find an array
34633456
print(f"Reading metadata from {path} in store {store}", file=sys.stderr)
3464-
if is_concurrency_safe(store):
3465-
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3466-
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3467-
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3468-
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3469-
)
3470-
else:
3471-
zarray_bytes = await store.get(
3472-
_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()
3473-
)
3474-
zgroup_bytes = await store.get(
3475-
_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()
3476-
)
3477-
zattrs_bytes = await store.get(
3478-
_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()
3479-
)
3480-
3457+
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3458+
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3459+
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3460+
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3461+
)
34813462
if zattrs_bytes is None:
34823463
zattrs = {}
34833464
else:

src/zarr/storage/_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import importlib.util
44
import json
5+
from asyncio import gather
56
from pathlib import Path
67
from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias
78

@@ -163,6 +164,24 @@ async def get(
163164
prototype = default_buffer_prototype()
164165
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)
165166

167+
async def get_many(
168+
self,
169+
*suffixes : str,
170+
prototype: BufferPrototype | None = None,
171+
byte_range: ByteRequest | None = None,
172+
):
173+
tasks = [
174+
(self / suffix).get(prototype=prototype, byte_range=byte_range) for suffix in suffixes
175+
]
176+
if await self._is_concurrency_save():
177+
return await gather(*tasks)
178+
else:
179+
results = []
180+
for task in tasks:
181+
result = await task
182+
results.append(result)
183+
return results
184+
166185
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
167186
"""
168187
Write bytes to the store.
@@ -263,6 +282,10 @@ def __eq__(self, other: object) -> bool:
263282
pass
264283
return False
265284

285+
async def _is_concurrency_save(self):
286+
fs = getattr(self.store, "fs", None)
287+
return getattr(fs, "asynchronous", True)
288+
266289

267290
StoreLike: TypeAlias = Store | StorePath | FSMap | Path | str | dict[str, Buffer]
268291

0 commit comments

Comments
 (0)