Skip to content

Commit 4270159

Browse files
committed
Add
1 parent 406c942 commit 4270159

File tree

3 files changed

+45
-60
lines changed

3 files changed

+45
-60
lines changed

src/zarr/core/array.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -196,42 +196,28 @@ def create_codec_pipeline(metadata: ArrayMetadata) -> CodecPipeline:
196196
raise TypeError # pragma: no cover
197197

198198

199-
def is_concurrency_safe(store: Store) -> bool:
200-
fs = getattr(store, "fs", None)
201-
return getattr(fs, "asynchronous", True)
202-
203-
204199
async def get_array_metadata(
205200
store_path: StorePath, zarr_format: ZarrFormat | None = 3
206201
) -> dict[str, JSON]:
207-
concurrency_safe = is_concurrency_safe(store_path.store)
208-
209202
if zarr_format == 2:
210-
if concurrency_safe:
211-
zarray_bytes, zattrs_bytes = await gather(
212-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
213-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
214-
)
215-
else:
216-
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
217-
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
203+
zarray_bytes, zattrs_bytes = await store_path.get_many(
204+
ZARRAY_JSON,
205+
ZATTRS_JSON,
206+
prototype=cpu_buffer_prototype,
207+
)
218208
if zarray_bytes is None:
219209
raise FileNotFoundError(store_path)
220210
elif zarr_format == 3:
221211
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
222212
if zarr_json_bytes is None:
223213
raise FileNotFoundError(store_path)
224214
elif zarr_format is None:
225-
if concurrency_safe:
226-
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
227-
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
228-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
229-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
230-
)
231-
else:
232-
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
233-
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
234-
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
215+
zarr_json_bytes, zarray_bytes, zattrs_bytes = await store_path.get_many(
216+
ZARR_JSON,
217+
ZARRAY_JSON,
218+
ZATTRS_JSON,
219+
prototype=cpu_buffer_prototype,
220+
)
235221
if zarr_json_bytes is not None and zarray_bytes is not None:
236222
# warn and favor v3
237223
msg = f"Both zarr.json (Zarr format 3) and .zarray (Zarr format 2) metadata objects exist at {store_path}. Zarr v3 will be used."
@@ -1445,12 +1431,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14451431
).items()
14461432
]
14471433
)
1448-
1449-
if is_concurrency_safe(self.store_path.store):
1450-
await gather(*awaitables)
1451-
else:
1452-
for awaitable in awaitables:
1453-
await awaitable
1434+
await gather(*awaitables)
14541435

14551436
async def _set_selection(
14561437
self,

src/zarr/core/group.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
_build_parents,
3232
_parse_deprecated_compressor,
3333
create_array,
34-
is_concurrency_safe,
3534
)
3635
from zarr.core.attributes import Attributes
3736
from zarr.core.buffer import default_buffer_prototype
@@ -536,23 +535,17 @@ async def open(
536535
if zarr_json_bytes is None:
537536
raise FileNotFoundError(store_path)
538537
elif zarr_format is None:
539-
paths = [
540-
(store_path / ZARR_JSON).get(),
541-
(store_path / ZGROUP_JSON).get(),
542-
(store_path / ZATTRS_JSON).get(),
543-
(store_path / consolidated_key).get(),
544-
]
545538
(
546539
zarr_json_bytes,
547540
zgroup_bytes,
548541
zattrs_bytes,
549542
maybe_consolidated_metadata_bytes,
550-
) = (
551-
await asyncio.gather(*paths)
552-
if is_concurrency_safe(store_path.store)
553-
else [await path for path in paths]
543+
) = await store_path.get_many(
544+
ZARR_JSON,
545+
ZGROUP_JSON,
546+
ZATTRS_JSON,
547+
consolidated_key,
554548
)
555-
556549
if zarr_json_bytes is not None and zgroup_bytes is not None:
557550
# warn and favor v3
558551
msg = f"Both zarr.json (Zarr format 3) and .zgroup (Zarr format 2) metadata objects exist at {store_path}. Zarr format 3 will be used."
@@ -3484,23 +3477,11 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
34843477
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
34853478
# find an array
34863479
print(f"Reading metadata from {path} in store {store}", file=sys.stderr)
3487-
if is_concurrency_safe(store):
3488-
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3489-
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3490-
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3491-
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3492-
)
3493-
else:
3494-
zarray_bytes = await store.get(
3495-
_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()
3496-
)
3497-
zgroup_bytes = await store.get(
3498-
_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()
3499-
)
3500-
zattrs_bytes = await store.get(
3501-
_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()
3502-
)
3503-
3480+
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3481+
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3482+
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3483+
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3484+
)
35043485
if zattrs_bytes is None:
35053486
zattrs = {}
35063487
else:

src/zarr/storage/_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import importlib.util
44
import json
5+
from asyncio import gather
56
from pathlib import Path
67
from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias
78

@@ -163,6 +164,24 @@ async def get(
163164
prototype = default_buffer_prototype()
164165
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)
165166

167+
async def get_many(
168+
self,
169+
*suffixes : str,
170+
prototype: BufferPrototype | None = None,
171+
byte_range: ByteRequest | None = None,
172+
):
173+
tasks = [
174+
(self / suffix).get(prototype=prototype, byte_range=byte_range) for suffix in suffixes
175+
]
176+
if await self._is_concurrency_save():
177+
return await gather(*tasks)
178+
else:
179+
results = []
180+
for task in tasks:
181+
result = await task
182+
results.append(result)
183+
return results
184+
166185
async def set(self, value: Buffer, byte_range: ByteRequest | None = None) -> None:
167186
"""
168187
Write bytes to the store.
@@ -263,6 +282,10 @@ def __eq__(self, other: object) -> bool:
263282
pass
264283
return False
265284

285+
async def _is_concurrency_save(self):
286+
fs = getattr(self.store, "fs", None)
287+
return getattr(fs, "asynchronous", True)
288+
266289

267290
StoreLike: TypeAlias = Store | StorePath | FSMap | Path | str | dict[str, Buffer]
268291

0 commit comments

Comments
 (0)