Skip to content

Commit 7cfcf5e

Browse files
committed
Add is_concurrency_safe function
- Introduced the `is_concurrency_safe` function to determine whether the file system supports concurrent access.
- Modified `get_array_metadata` to read metadata files sequentially for non-concurrency-safe file systems.
- Enhanced compatibility with synchronous file systems by avoiding deadlocks when accessing metadata concurrently.
- Updated logic to conditionally use `asyncio.gather` based on the concurrency safety of the underlying file system.
1 parent 6c40629 commit 7cfcf5e

File tree

2 files changed

+59
-20
lines changed

2 files changed

+59
-20
lines changed

src/zarr/core/array.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -196,26 +196,42 @@ def create_codec_pipeline(metadata: ArrayMetadata) -> CodecPipeline:
196196
raise TypeError # pragma: no cover
197197

198198

199+
def is_concurrency_safe(store: Store) -> bool:
    """
    Report whether requests against ``store`` may be issued concurrently.

    The check reads the ``asynchronous`` flag of the object found on the
    store's ``fs`` attribute. Callers use the result to choose between
    awaiting several store requests together and awaiting them one at a time.

    Parameters
    ----------
    store : Store
        The store to inspect. It does not need to have an ``fs`` attribute.

    Returns
    -------
    bool
        ``False`` only when ``store.fs.asynchronous`` exists and is falsy.
        ``True`` when the flag is truthy, when the store has no ``fs``
        attribute (or it is ``None``), or when ``fs`` has no ``asynchronous``
        attribute.
    """
    fs = getattr(store, "fs", None)
    # A missing ``fs`` or missing ``asynchronous`` flag defaults to "safe".
    # Coerce to ``bool`` so the declared return type holds even when the flag
    # is stored as some other truthy/falsy value.
    return bool(getattr(fs, "asynchronous", True))
202+
203+
199204
async def get_array_metadata(
200205
store_path: StorePath, zarr_format: ZarrFormat | None = 3
201206
) -> dict[str, JSON]:
207+
concurrency_safe = is_concurrency_safe(store_path.store)
208+
202209
if zarr_format == 2:
203-
zarray_bytes, zattrs_bytes = await gather(
204-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
205-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
206-
)
210+
if concurrency_safe:
211+
zarray_bytes, zattrs_bytes = await gather(
212+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
213+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
214+
)
215+
else:
216+
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
217+
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
207218
if zarray_bytes is None:
208219
raise FileNotFoundError(store_path)
209220
elif zarr_format == 3:
210221
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
211222
if zarr_json_bytes is None:
212223
raise FileNotFoundError(store_path)
213224
elif zarr_format is None:
214-
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
215-
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
216-
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
217-
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
218-
)
225+
if concurrency_safe:
226+
zarr_json_bytes, zarray_bytes, zattrs_bytes = await gather(
227+
(store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype),
228+
(store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype),
229+
(store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype),
230+
)
231+
else:
232+
zarr_json_bytes = await (store_path / ZARR_JSON).get(prototype=cpu_buffer_prototype)
233+
zarray_bytes = await (store_path / ZARRAY_JSON).get(prototype=cpu_buffer_prototype)
234+
zattrs_bytes = await (store_path / ZATTRS_JSON).get(prototype=cpu_buffer_prototype)
219235
if zarr_json_bytes is not None and zarray_bytes is not None:
220236
# warn and favor v3
221237
msg = f"Both zarr.json (Zarr format 3) and .zarray (Zarr format 2) metadata objects exist at {store_path}. Zarr v3 will be used."
@@ -1426,7 +1442,11 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F
14261442
]
14271443
)
14281444

1429-
await gather(*awaitables)
1445+
if is_concurrency_safe(self.store_path.store):
1446+
await gather(*awaitables)
1447+
else:
1448+
for awaitable in awaitables:
1449+
await awaitable
14301450

14311451
async def _set_selection(
14321452
self,

src/zarr/core/group.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_build_parents,
3131
_parse_deprecated_compressor,
3232
create_array,
33+
is_concurrency_safe,
3334
)
3435
from zarr.core.attributes import Attributes
3536
from zarr.core.buffer import default_buffer_prototype
@@ -533,17 +534,23 @@ async def open(
533534
if zarr_json_bytes is None:
534535
raise FileNotFoundError(store_path)
535536
elif zarr_format is None:
537+
paths = [
538+
(store_path / ZARR_JSON).get(),
539+
(store_path / ZGROUP_JSON).get(),
540+
(store_path / ZATTRS_JSON).get(),
541+
(store_path / consolidated_key).get(),
542+
]
536543
(
537544
zarr_json_bytes,
538545
zgroup_bytes,
539546
zattrs_bytes,
540547
maybe_consolidated_metadata_bytes,
541-
) = await asyncio.gather(
542-
(store_path / ZARR_JSON).get(),
543-
(store_path / ZGROUP_JSON).get(),
544-
(store_path / ZATTRS_JSON).get(),
545-
(store_path / str(consolidated_key)).get(),
548+
) = (
549+
await asyncio.gather(*paths)
550+
if is_concurrency_safe(store_path.store)
551+
else [await path for path in paths]
546552
)
553+
547554
if zarr_json_bytes is not None and zgroup_bytes is not None:
548555
# warn and favor v3
549556
msg = f"Both zarr.json (Zarr format 3) and .zgroup (Zarr format 2) metadata objects exist at {store_path}. Zarr format 3 will be used."
@@ -3453,11 +3460,23 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
34533460
"""
34543461
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
34553462
# find an array
3456-
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3457-
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3458-
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3459-
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3460-
)
3463+
print(f"Reading metadata from {path} in store {store}", file=sys.stderr)
3464+
if is_concurrency_safe(store):
3465+
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3466+
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3467+
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3468+
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3469+
)
3470+
else:
3471+
zarray_bytes = await store.get(
3472+
_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()
3473+
)
3474+
zgroup_bytes = await store.get(
3475+
_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()
3476+
)
3477+
zattrs_bytes = await store.get(
3478+
_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()
3479+
)
34613480

34623481
if zattrs_bytes is None:
34633482
zattrs = {}

0 commit comments

Comments
 (0)