Skip to content

Commit 406c942

Browse files
committed
Add is_concurrency_safe function
- Introduced the `is_concurrency_safe` function to determine whether the underlying file system supports concurrent access.
- Modified `get_array_metadata` to read metadata files sequentially for file systems that are not concurrency-safe.
- Improved compatibility with synchronous file systems by avoiding deadlocks when metadata is accessed concurrently.
- Updated the logic to use `asyncio.gather` only when the underlying file system is concurrency-safe.
1 parent 9d97b24 commit 406c942

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

src/zarr/core/array.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -196,26 +196,42 @@ def create_codec_pipeline(metadata: ArrayMetadata) -> CodecPipeline:
196196
raise TypeError # pragma: no cover
197197

198198

199+
def is_concurrency_safe(store: Store) -> bool:
    """
    Report whether requests against ``store`` may be issued concurrently.

    The decision is based on the ``asynchronous`` attribute of the store's
    ``fs`` attribute. Callers use the result to choose between awaiting
    several store requests together (``gather``) and awaiting them one
    after another.

    Parameters
    ----------
    store : Store
        The store to inspect. It does not need to have an ``fs`` attribute.

    Returns
    -------
    bool
        The truthiness of ``store.fs.asynchronous``. ``True`` when the store
        has no ``fs`` attribute, or when ``fs`` has no ``asynchronous``
        attribute.
    """
    fs = getattr(store, "fs", None)
    # Default to True so stores that do not expose the flag keep the
    # concurrent code path. Coerce so the annotated ``bool`` return type
    # holds even if ``asynchronous`` is some other truthy/falsy object.
    return bool(getattr(fs, "asynchronous", True))
202+
203+
199204
async def get_array_metadata(
200205
store_path: StorePath, zarr_format: ZarrFormat | None = 3
201206
) -> dict[str, JSON]:
207+
concurrency_safe = is_concurrency_safe(store_path.store)
208+
202209
if zarr_format == 2:
203-
zarray_bytes, zattrs_bytes = await gather(
204-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
205-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
206-
)
210+
if concurrency_safe:
211+
zarray_bytes, zattrs_bytes = await gather(
212+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
213+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
214+
)
215+
else:
216+
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
217+
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
207218
if zarray_bytes is None:
208219
raise FileNotFoundError(store_path)
209220
elif zarr_format == 3:
210221
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
211222
if zarr_json_bytes is None:
212223
raise FileNotFoundError(store_path)
213224
elif zarr_format is None:
214-
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
215-
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
216-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
217-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
218-
)
225+
if concurrency_safe:
226+
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
227+
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
228+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
229+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
230+
)
231+
else:
232+
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
233+
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
234+
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
219235
if zarr_json_bytes is not None and zarray_bytes is not None:
220236
# warn and favor v3
221237
msg = f"Both zarr.json (Zarr format 3) and .zarray (Zarr format 2) metadata objects exist at {store_path}. Zarr v3 will be used."
@@ -1430,7 +1446,11 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14301446
]
14311447
)
14321448

1433-
await gather(*awaitables)
1449+
if is_concurrency_safe(self.store_path.store):
1450+
await gather(*awaitables)
1451+
else:
1452+
for awaitable in awaitables:
1453+
await awaitable
14341454

14351455
async def _set_selection(
14361456
self,

src/zarr/core/group.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_build_parents,
3232
_parse_deprecated_compressor,
3333
create_array,
34+
is_concurrency_safe,
3435
)
3536
from zarr.core.attributes import Attributes
3637
from zarr.core.buffer import default_buffer_prototype
@@ -535,17 +536,23 @@ async def open(
535536
if zarr_json_bytes is None:
536537
raise FileNotFoundError(store_path)
537538
elif zarr_format is None:
539+
paths = [
540+
(store_path / ZARR_JSON).get(),
541+
(store_path / ZGROUP_JSON).get(),
542+
(store_path / ZATTRS_JSON).get(),
543+
(store_path / consolidated_key).get(),
544+
]
538545
(
539546
zarr_json_bytes,
540547
zgroup_bytes,
541548
zattrs_bytes,
542549
maybe_consolidated_metadata_bytes,
543-
) = await asyncio.gather(
544-
(store_path / ZARR_JSON).get(),
545-
(store_path / ZGROUP_JSON).get(),
546-
(store_path / ZATTRS_JSON).get(),
547-
(store_path / str(consolidated_key)).get(),
550+
) = (
551+
await asyncio.gather(*paths)
552+
if is_concurrency_safe(store_path.store)
553+
else [await path for path in paths]
548554
)
555+
549556
if zarr_json_bytes is not None and zgroup_bytes is not None:
550557
# warn and favor v3
551558
msg = f"Both zarr.json (Zarr format 3) and .zgroup (Zarr format 2) metadata objects exist at {store_path}. Zarr format 3 will be used."
@@ -3476,11 +3483,23 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
34763483
"""
34773484
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
34783485
# find an array
3479-
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3480-
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3481-
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3482-
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3483-
)
3486+
print(f"Reading metadata from {path} in store {store}", file=sys.stderr)
3487+
if is_concurrency_safe(store):
3488+
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3489+
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3490+
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3491+
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3492+
)
3493+
else:
3494+
zarray_bytes = await store.get(
3495+
_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()
3496+
)
3497+
zgroup_bytes = await store.get(
3498+
_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()
3499+
)
3500+
zattrs_bytes = await store.get(
3501+
_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()
3502+
)
34843503

34853504
if zattrs_bytes is None:
34863505
zattrs = {}

0 commit comments

Comments
 (0)