From 16675fa798ed388be87c3036c50f106bf2813079 Mon Sep 17 00:00:00 2001
From: Bojidar Marinov
Date: Tue, 8 Jul 2025 13:48:14 +0300
Subject: [PATCH] Add refresh_attributes() and implement cache_attrs for Group, Array

Add refresh_attributes() to AsyncArray, Array, AsyncGroup and Group. It
re-reads the node metadata from the store and replaces the in-memory user
attributes in place, returning the same object.

Array and Group gain a cache_attrs field. When it is False, the `attrs`
property calls refresh_attributes() before building the Attributes view;
None and True leave the attributes cached. The synchronous API functions
now set cache_attrs on the Array/Group wrapper instead of forwarding it
to the async API.

Should resolve #3178
---
 src/zarr/api/synchronous.py | 24 +++++++++-----
 src/zarr/core/array.py      | 32 +++++++++++++++++++
 src/zarr/core/group.py      | 43 +++++++++++++++++++++++--
 tests/test_array.py         | 62 +++++++++++++++++++++++++++++++++++++
 tests/test_group.py         | 33 ++++++++++++++++++++
 5 files changed, 184 insertions(+), 10 deletions(-)

diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py
index b60f69a673..667323ddb7 100644
--- a/src/zarr/api/synchronous.py
+++ b/src/zarr/api/synchronous.py
@@ -441,7 +441,6 @@ def group(
                 store=store,
                 overwrite=overwrite,
                 chunk_store=chunk_store,
-                cache_attrs=cache_attrs,
                 synchronizer=synchronizer,
                 path=path,
                 zarr_version=zarr_version,
@@ -450,7 +449,8 @@ def group(
                 attributes=attributes,
                 storage_options=storage_options,
             )
-        )
+        ),
+        cache_attrs=cache_attrs,
     )
 
 
@@ -536,7 +536,6 @@ def open_group(
             async_api.open_group(
                 store=store,
                 mode=mode,
-                cache_attrs=cache_attrs,
                 synchronizer=synchronizer,
                 path=path,
                 chunk_store=chunk_store,
@@ -547,7 +546,8 @@ def open_group(
                 attributes=attributes,
                 use_consolidated=use_consolidated,
             )
-        )
+        ),
+        cache_attrs=cache_attrs,
     )
 
 
@@ -559,6 +559,7 @@ def create_group(
     overwrite: bool = False,
     attributes: dict[str, Any] | None = None,
     storage_options: dict[str, Any] | None = None,
+    cache_attrs: bool | None = None,
 ) -> Group:
     """Create a group.
 
@@ -595,7 +596,8 @@ def create_group(
                 zarr_format=zarr_format,
                 attributes=attributes,
             )
-        )
+        ),
+        cache_attrs=cache_attrs,
     )
 
 
@@ -730,7 +732,6 @@ def create(
                 chunk_store=chunk_store,
                 filters=filters,
                 cache_metadata=cache_metadata,
-                cache_attrs=cache_attrs,
                 read_only=read_only,
                 object_codec=object_codec,
                 dimension_separator=dimension_separator,
@@ -747,7 +748,8 @@ def create(
                 config=config,
                 **kwargs,
             )
-        )
+        ),
+        cache_attrs=cache_attrs,
     )
 
 
@@ -773,6 +775,7 @@ def create_array(
     overwrite: bool = False,
     config: ArrayConfigLike | None = None,
     write_data: bool = True,
+    cache_attrs: bool | None = None,
 ) -> Array:
     """Create an array.
 
@@ -872,6 +875,10 @@ def create_array(
         then ``write_data`` determines whether the values in that array-like object should be
         written to the Zarr array created by this function. If ``write_data`` is ``False``,
         then the array will be left empty.
+    cache_attrs : bool, optional
+        If True (default), user attributes will be cached for attribute read
+        operations. If False, user attributes are reloaded from the store prior
+        to all attribute read operations.
 
     Returns
     -------
@@ -914,7 +921,8 @@ def create_array(
                 config=config,
                 write_data=write_data,
             )
-        )
+        ),
+        cache_attrs=cache_attrs,
     )
 
 
diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index a44a4b55d1..aa1d007938 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -1677,6 +1677,22 @@ async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self:
 
         return self
 
+    async def refresh_attributes(self) -> Self:
+        """Reload the attributes of this array from the store.
+
+        Returns
+        -------
+        AsyncArray
+            The array updated with the newest attributes from storage.
+        """
+        metadata = await get_array_metadata(self.store_path, self.metadata.zarr_format)
+        reparsed_metadata = parse_array_metadata(metadata)
+
+        self.metadata.attributes.clear()
+        self.metadata.attributes.update(reparsed_metadata.attributes)
+
+        return self
+
     def __repr__(self) -> str:
         return f""
 
@@ -1768,6 +1784,7 @@ class Array:
     """
 
     _async_array: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
+    cache_attrs: bool | None = field(default=None)
 
     @classmethod
     @deprecated("Use zarr.create_array instead.")
@@ -2105,6 +2122,8 @@ def attrs(self) -> Attributes:
         -----
         Note that attribute values must be JSON serializable.
         """
+        if self.cache_attrs is False:
+            self.refresh_attributes()
         return Attributes(self)
 
     @property
@@ -3703,6 +3722,19 @@ def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
         _new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
         return type(self)(_new_array)
 
+    def refresh_attributes(self) -> Array:
+        """Reload the attributes of this array from the store.
+
+        Returns
+        -------
+        Array
+            This array, with its attributes reloaded in place.
+        """
+        # AsyncArray.refresh_attributes mutates the wrapped metadata in place, so
+        # return self; building a new wrapper here would reset cache_attrs to None.
+        sync(self._async_array.refresh_attributes())
+        return self
+
     def __repr__(self) -> str:
         return f""
 
diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index bad710ed43..9f11c46573 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -1273,6 +1273,26 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
 
         return self
 
+    async def refresh_attributes(self) -> AsyncGroup:
+        """Reload the attributes of this group from the store.
+
+        Returns
+        -------
+        self : AsyncGroup
+            The group updated with the newest attributes from storage.
+        """
+
+        reparsed_metadata = await _read_group_metadata(
+            store=self.store_path.store,
+            path=self.store_path.path,
+            zarr_format=self.metadata.zarr_format,
+        )
+
+        self.metadata.attributes.clear()
+        self.metadata.attributes.update(reparsed_metadata.attributes)
+
+        return self
+
     def __repr__(self) -> str:
         return f""
 
@@ -1774,6 +1794,7 @@ class Group(SyncMixin):
     """
 
     _async_group: AsyncGroup
+    cache_attrs: bool | None = field(default=None)
 
     @classmethod
     def from_store(
@@ -1783,6 +1804,7 @@ def from_store(
         attributes: dict[str, Any] | None = None,
         zarr_format: ZarrFormat = 3,
         overwrite: bool = False,
+        cache_attrs: bool | None = None,
     ) -> Group:
         """Instantiate a group from an initialized store.
 
@@ -1796,6 +1818,10 @@ def from_store(
             Zarr storage format version.
         overwrite : bool, optional
             If True, do not raise an error if the group already exists.
+        cache_attrs : bool, optional
+            If True (default), user attributes will be cached for attribute read
+            operations. If False, user attributes are reloaded from the store prior
+            to all attribute read operations.
 
         Returns
         -------
@@ -1816,13 +1842,15 @@ def from_store(
             ),
         )
 
-        return cls(obj)
+        return cls(obj, cache_attrs=cache_attrs)
 
     @classmethod
     def open(
         cls,
         store: StoreLike,
         zarr_format: ZarrFormat | None = 3,
+        *,
+        cache_attrs: bool | None = None,
     ) -> Group:
         """Open a group from an initialized store.
 
@@ -1832,6 +1860,10 @@ def open(
             Store containing the Group.
         zarr_format : {2, 3, None}, optional
             Zarr storage format version.
+        cache_attrs : bool, optional
+            If True (default), user attributes will be cached for attribute read
+            operations. If False, user attributes are reloaded from the store prior
+            to all attribute read operations.
 
         Returns
         -------
@@ -1839,7 +1871,7 @@ def open(
             Group instantiated from the store.
         """
         obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
-        return cls(obj)
+        return cls(obj, cache_attrs=cache_attrs)
 
     def __getitem__(self, path: str) -> Array | Group:
         """Obtain a group member.
@@ -2024,6 +2056,8 @@ def basename(self) -> str:
     @property
     def attrs(self) -> Attributes:
         """Attributes of this Group"""
+        if self.cache_attrs is False:
+            self.refresh_attributes()
         return Attributes(self)
 
     @property
@@ -2090,6 +2124,11 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
         self._sync(self._async_group.update_attributes(new_attributes))
         return self
 
+    def refresh_attributes(self) -> Group:
+        """Reload the attributes of this Group from the store."""
+        self._sync(self._async_group.refresh_attributes())
+        return self
+
     def nmembers(self, max_depth: int | None = 0) -> int:
         """Count the number of members in this group.
 
diff --git a/tests/test_array.py b/tests/test_array.py
index fe23bc1284..8ee6ff4c7b 100644
--- a/tests/test_array.py
+++ b/tests/test_array.py
@@ -494,6 +494,68 @@ def test_update_attrs(zarr_format: ZarrFormat) -> None:
     assert arr2.attrs["foo"] == "bar"
 
 
+@pytest.mark.parametrize("zarr_format", [2, 3])
+def test_refresh_attrs(zarr_format: ZarrFormat) -> None:
+    """
+    Test the behavior of `Array.refresh_attributes`
+    """
+    store = MemoryStore()
+    attrs: dict[str, JSON] = {"foo": 100}
+    arr = zarr.create_array(
+        store=store, shape=(5,), chunks=(5,), dtype="f8", attributes=attrs, zarr_format=zarr_format
+    )
+    assert arr.attrs.asdict() == attrs
+
+    new_attrs: dict[str, JSON] = {"bar": 50}
+    arr2 = zarr.create_array(
+        store=store,
+        shape=(5,),
+        chunks=(5,),
+        dtype="f8",
+        attributes=new_attrs,
+        zarr_format=zarr_format,
+        overwrite=True,
+    )
+    assert arr2.attrs.asdict() == new_attrs
+
+    assert arr.attrs.asdict() == attrs
+    arr.refresh_attributes()
+    assert arr.attrs.asdict() == new_attrs
+
+
+@pytest.mark.parametrize("zarr_format", [2, 3])
+def test_cache_attrs(zarr_format: ZarrFormat) -> None:
+    """
+    Test the behavior of `Array.cache_attrs`
+    """
+    store = MemoryStore()
+    attrs: dict[str, JSON] = {"foo": 100}
+    arr = zarr.create_array(
+        store=store,
+        shape=(5,),
+        chunks=(5,),
+        dtype="f8",
+        attributes=attrs,
+        zarr_format=zarr_format,
+        cache_attrs=False,
+    )
+    assert arr.attrs.asdict() == attrs
+
+    new_attrs: dict[str, JSON] = {"bar": 50}
+    arr2 = zarr.create_array(
+        store=store,
+        shape=(5,),
+        chunks=(5,),
+        dtype="f8",
+        attributes=new_attrs,
+        zarr_format=zarr_format,
+        overwrite=True,
+    )
+
+    assert arr2.attrs.asdict() == new_attrs
+    assert arr.attrs.asdict() == new_attrs
+
+
 @pytest.mark.parametrize(("chunks", "shards"), [((2, 2), None), ((2, 2), (4, 4))])
 class TestInfo:
     def test_info_v2(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
diff --git a/tests/test_group.py b/tests/test_group.py
index 60a1fcb9bf..8fbfd3ea32 100644
--- a/tests/test_group.py
+++ b/tests/test_group.py
@@ -627,6 +627,39 @@ async def test_group_update_attributes_async(store: Store, zarr_format: ZarrForm
     assert new_group.attrs == new_attrs
 
 
+def test_group_refresh_attributes(store: Store, zarr_format: ZarrFormat) -> None:
+    """
+    Test the behavior of `Group.refresh_attributes`
+    """
+    attrs = {"foo": 100}
+    group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs)
+    assert group.attrs == attrs
+    new_attrs = {"foo": 50}
+    group2 = Group.open(store, zarr_format=zarr_format)
+    group2.update_attributes(new_attrs)
+    assert group2.attrs == new_attrs
+
+    assert group.attrs == attrs
+    new_group = group.refresh_attributes()
+    assert new_group.attrs == new_attrs
+
+
+def test_group_cache_attrs(store: Store, zarr_format: ZarrFormat) -> None:
+    """
+    Test the behavior of `Group.cache_attrs`
+    """
+    attrs = {"foo": 100}
+    group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs, cache_attrs=False)
+    assert group.attrs == attrs
+
+    new_attrs = {"foo": 50}
+    group2 = Group.open(store, zarr_format=zarr_format)
+    group2.update_attributes(new_attrs)
+    assert group2.attrs == new_attrs
+
+    assert group.attrs == new_attrs
+
+
 @pytest.mark.parametrize("method", ["create_array", "array"])
 @pytest.mark.parametrize("name", ["a", "/a"])
 def test_group_create_array(