Skip to content

Add refresh_attributes() and implement cache_attrs for Group and Array #3215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/zarr/api/synchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ def group(
store=store,
overwrite=overwrite,
chunk_store=chunk_store,
cache_attrs=cache_attrs,
synchronizer=synchronizer,
path=path,
zarr_version=zarr_version,
Expand All @@ -450,7 +449,8 @@ def group(
attributes=attributes,
storage_options=storage_options,
)
)
),
cache_attrs=cache_attrs,
)


Expand Down Expand Up @@ -536,7 +536,6 @@ def open_group(
async_api.open_group(
store=store,
mode=mode,
cache_attrs=cache_attrs,
synchronizer=synchronizer,
path=path,
chunk_store=chunk_store,
Expand All @@ -547,7 +546,8 @@ def open_group(
attributes=attributes,
use_consolidated=use_consolidated,
)
)
),
cache_attrs=cache_attrs,
)


Expand All @@ -559,6 +559,7 @@ def create_group(
overwrite: bool = False,
attributes: dict[str, Any] | None = None,
storage_options: dict[str, Any] | None = None,
cache_attrs: bool | None = None,
) -> Group:
"""Create a group.

Expand Down Expand Up @@ -595,7 +596,8 @@ def create_group(
zarr_format=zarr_format,
attributes=attributes,
)
)
),
cache_attrs=cache_attrs,
)


Expand Down Expand Up @@ -730,7 +732,6 @@ def create(
chunk_store=chunk_store,
filters=filters,
cache_metadata=cache_metadata,
cache_attrs=cache_attrs,
read_only=read_only,
object_codec=object_codec,
dimension_separator=dimension_separator,
Expand All @@ -747,7 +748,8 @@ def create(
config=config,
**kwargs,
)
)
),
cache_attrs=cache_attrs,
)


Expand All @@ -773,6 +775,7 @@ def create_array(
overwrite: bool = False,
config: ArrayConfigLike | None = None,
write_data: bool = True,
cache_attrs: bool | None = None,
) -> Array:
"""Create an array.

Expand Down Expand Up @@ -872,6 +875,10 @@ def create_array(
then ``write_data`` determines whether the values in that array-like object should be
written to the Zarr array created by this function. If ``write_data`` is ``False``, then the
array will be left empty.
cache_attrs : bool, optional
If True (default), user attributes will be cached for attribute read
operations. If False, user attributes are reloaded from the store prior
to all attribute read operations.
Comment on lines +878 to +881
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this docstring says true is the default, but the parameter itself has a default of None

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other places defining a cache_attrs argument did the same, so I decided to follow suit. (:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets break with the past and ensure that the docstring matches the annotation here


Returns
-------
Expand Down Expand Up @@ -914,7 +921,8 @@ def create_array(
config=config,
write_data=write_data,
)
)
),
cache_attrs=cache_attrs,
)


Expand Down
32 changes: 32 additions & 0 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,22 @@ async def update_attributes(self, new_attributes: dict[str, JSON]) -> Self:

return self

async def refresh_attributes(self) -> Self:
"""Reload the attributes of this array from the store.

Returns
-------
AsyncArray
The array updated with the newest attributes from storage."""

metadata = await get_array_metadata(self.store_path, self.metadata.zarr_format)
reparsed_metadata = parse_array_metadata(metadata)

self.metadata.attributes.clear()
self.metadata.attributes.update(reparsed_metadata.attributes)

return self

def __repr__(self) -> str:
return f"<AsyncArray {self.store_path} shape={self.shape} dtype={self.dtype}>"

Expand Down Expand Up @@ -1768,6 +1784,7 @@ class Array:
"""

_async_array: AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata]
cache_attrs: bool | None = field(default=None)

@classmethod
@deprecated("Use zarr.create_array instead.")
Expand Down Expand Up @@ -2105,6 +2122,8 @@ def attrs(self) -> Attributes:
-----
Note that attribute values must be JSON serializable.
"""
if self.cache_attrs is False:
self.refresh_attributes()
return Attributes(self)

@property
Expand Down Expand Up @@ -3703,6 +3722,19 @@ def update_attributes(self, new_attributes: dict[str, JSON]) -> Array:
_new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
return type(self)(_new_array)

def refresh_attributes(self) -> Array:
"""Reload the attributes of this array from the store.

Returns
-------
Array
The array with the updated attributes."""
# TODO: remove this cast when type inference improves
new_array = sync(self._async_array.refresh_attributes())
# TODO: remove this cast when type inference improves
_new_array = cast("AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]", new_array)
return type(self)(_new_array)

def __repr__(self) -> str:
return f"<Array {self.store_path} shape={self.shape} dtype={self.dtype}>"

Expand Down
43 changes: 41 additions & 2 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,26 @@ async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:

return self

async def refresh_attributes(self) -> AsyncGroup:
"""Reload the attributes of this group from the store.

Returns
-------
self : AsyncGroup
The group updated with the newest attributes from storage.
"""

reparsed_metadata = await _read_group_metadata(
store=self.store_path.store,
path=self.store_path.path,
zarr_format=self.metadata.zarr_format,
)

self.metadata.attributes.clear()
self.metadata.attributes.update(reparsed_metadata.attributes)

return self

def __repr__(self) -> str:
return f"<AsyncGroup {self.store_path}>"

Expand Down Expand Up @@ -1774,6 +1794,7 @@ class Group(SyncMixin):
"""

_async_group: AsyncGroup
cache_attrs: bool | None = field(default=None)

@classmethod
def from_store(
Expand All @@ -1783,6 +1804,7 @@ def from_store(
attributes: dict[str, Any] | None = None,
zarr_format: ZarrFormat = 3,
overwrite: bool = False,
cache_attrs: bool | None = None,
) -> Group:
"""Instantiate a group from an initialized store.

Expand All @@ -1796,6 +1818,10 @@ def from_store(
Zarr storage format version.
overwrite : bool, optional
If True, do not raise an error if the group already exists.
cache_attrs : bool, optional
If True (default), user attributes will be cached for attribute read
operations. If False, user attributes are reloaded from the store prior
to all attribute read operations.

Returns
-------
Expand All @@ -1816,13 +1842,15 @@ def from_store(
),
)

return cls(obj)
return cls(obj, cache_attrs=cache_attrs)

@classmethod
def open(
cls,
store: StoreLike,
zarr_format: ZarrFormat | None = 3,
*,
cache_attrs: bool | None = None,
) -> Group:
"""Open a group from an initialized store.

Expand All @@ -1832,14 +1860,18 @@ def open(
Store containing the Group.
zarr_format : {2, 3, None}, optional
Zarr storage format version.
cache_attrs : bool, optional
If True (default), user attributes will be cached for attribute read
operations. If False, user attributes are reloaded from the store prior
to all attribute read operations.

Returns
-------
Group
Group instantiated from the store.
"""
obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
return cls(obj)
return cls(obj, cache_attrs=cache_attrs)

def __getitem__(self, path: str) -> Array | Group:
"""Obtain a group member.
Expand Down Expand Up @@ -2024,6 +2056,8 @@ def basename(self) -> str:
@property
def attrs(self) -> Attributes:
"""Attributes of this Group"""
if self.cache_attrs is False:
self.refresh_attributes()
return Attributes(self)

@property
Expand Down Expand Up @@ -2090,6 +2124,11 @@ def update_attributes(self, new_attributes: dict[str, Any]) -> Group:
self._sync(self._async_group.update_attributes(new_attributes))
return self

def refresh_attributes(self) -> Group:
"""Reload the attributes of this Group from the store."""
self._sync(self._async_group.refresh_attributes())
return self

def nmembers(self, max_depth: int | None = 0) -> int:
"""Count the number of members in this group.

Expand Down
62 changes: 62 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,68 @@ def test_update_attrs(zarr_format: ZarrFormat) -> None:
assert arr2.attrs["foo"] == "bar"


@pytest.mark.parametrize("zarr_format", [2, 3])
def test_refresh_attrs(zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Array.refresh_attributes`
"""
store = MemoryStore()
attrs: dict[str, JSON] = {"foo": 100}
arr = zarr.create_array(
store=store, shape=(5,), chunks=(5,), dtype="f8", attributes=attrs, zarr_format=zarr_format
)
assert arr.attrs.asdict() == attrs

new_attrs: dict[str, JSON] = {"bar": 50}
arr2 = zarr.create_array(
store=store,
shape=(5,),
chunks=(5,),
dtype="f8",
attributes=new_attrs,
zarr_format=zarr_format,
overwrite=True,
)
assert arr2.attrs.asdict() == new_attrs

assert arr.attrs.asdict() == attrs
arr.refresh_attributes()
assert arr.attrs.asdict() == new_attrs


@pytest.mark.parametrize("zarr_format", [2, 3])
def test_cache_attrs(zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Array.cache_attrs`
"""
store = MemoryStore()
attrs: dict[str, JSON] = {"foo": 100}
arr = zarr.create_array(
store=store,
shape=(5,),
chunks=(5,),
dtype="f8",
attributes=attrs,
zarr_format=zarr_format,
cache_attrs=False,
)
assert arr.attrs.asdict() == attrs

new_attrs: dict[str, JSON] = {"bar": 50}
arr2 = zarr.create_array(
store=store,
shape=(5,),
chunks=(5,),
dtype="f8",
attributes=new_attrs,
zarr_format=zarr_format,
overwrite=True,
)

assert arr2.attrs.asdict() == new_attrs
assert arr.attrs.asdict() == new_attrs


@pytest.mark.parametrize(("chunks", "shards"), [((2, 2), None), ((2, 2), (4, 4))])
class TestInfo:
def test_info_v2(self, chunks: tuple[int, int], shards: tuple[int, int] | None) -> None:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,39 @@ async def test_group_update_attributes_async(store: Store, zarr_format: ZarrForm
assert new_group.attrs == new_attrs


def test_group_refresh_attributes(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group.refresh_attributes`
"""
attrs = {"foo": 100}
group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs)
assert group.attrs == attrs
new_attrs = {"foo": 50}
group2 = Group.open(store, zarr_format=zarr_format)
group2.update_attributes(new_attrs)
assert group2.attrs == new_attrs

assert group.attrs == attrs
new_group = group.refresh_attributes()
assert new_group.attrs == new_attrs


def test_group_cache_attrs(store: Store, zarr_format: ZarrFormat) -> None:
"""
Test the behavior of `Group.cache_attrs`
"""
attrs = {"foo": 100}
group = Group.from_store(store, zarr_format=zarr_format, attributes=attrs, cache_attrs=False)
assert group.attrs == attrs

new_attrs = {"foo": 50}
group2 = Group.open(store, zarr_format=zarr_format)
group2.update_attributes(new_attrs)
assert group2.attrs == new_attrs

assert group.attrs == new_attrs


@pytest.mark.parametrize("method", ["create_array", "array"])
@pytest.mark.parametrize("name", ["a", "/a"])
def test_group_create_array(
Expand Down