Commit 604bb6d

tokenize() should ignore difference between None and {} attrs (#8797)
1 parent a241845 commit 604bb6d

File tree: 7 files changed, +38 -28 lines

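The change in one picture: an object whose attrs were never set (`_attrs is None`) and the same object after `attrs = {}` must produce the same dask token. A minimal sketch of the fixed behavior (assumes xarray and dask are installed; a plain Dataset stands in for the test suite's `make_ds()`):

```python
import dask.base
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3])})
assert ds._attrs is None  # a fresh object allocates no attrs dict

t1 = dask.base.tokenize(ds)
ds.attrs = {}  # assigning empty attrs must not change the token
t2 = dask.base.tokenize(ds)

assert t1 == t2  # failed before this commit: {} and None tokenized differently
```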

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1070,7 +1070,7 @@ def reset_coords(
         dataset[self.name] = self.variable
         return dataset
 
-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token
 
         return normalize_token((type(self), self._variable, self._coords, self._name))
```

xarray/core/dataset.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -694,7 +694,7 @@ def __init__(
             data_vars, coords
         )
 
-        self._attrs = dict(attrs) if attrs is not None else None
+        self._attrs = dict(attrs) if attrs else None
         self._close = None
         self._encoding = None
         self._variables = variables
@@ -739,7 +739,7 @@ def attrs(self) -> dict[Any, Any]:
 
     @attrs.setter
     def attrs(self, value: Mapping[Any, Any]) -> None:
-        self._attrs = dict(value)
+        self._attrs = dict(value) if value else None
 
     @property
     def encoding(self) -> dict[Any, Any]:
@@ -856,11 +856,11 @@ def load(self, **kwargs) -> Self:
 
         return self
 
-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token
 
         return normalize_token(
-            (type(self), self._variables, self._coord_names, self._attrs)
+            (type(self), self._variables, self._coord_names, self._attrs or None)
         )
 
     def __dask_graph__(self):
```
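Both halves of this file's change enforce one rule: an empty attrs mapping is interchangeable with `None`. The setter collapses `{}` to `None` on write, and the `_attrs or None` guard repeats the collapse at tokenize time, which matters because the `attrs` getter lazily materializes `_attrs` from `None` into `{}` on first read. A minimal sketch of the two idioms, using a hypothetical helper name:

```python
from typing import Any, Mapping, Optional

def normalize_attrs(value: Optional[Mapping[Any, Any]]) -> Optional[dict[Any, Any]]:
    # Mirrors the setter: {} and None both collapse to None.
    return dict(value) if value else None

assert normalize_attrs(None) is None
assert normalize_attrs({}) is None
assert normalize_attrs({"units": "m"}) == {"units": "m"}

# Mirrors the tokenize-time guard: even after the getter has turned
# _attrs into a materialized {}, the token is still computed over None.
for _attrs in (None, {}):
    assert (_attrs or None) is None
```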

xarray/core/variable.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -2592,11 +2592,13 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
         if not isinstance(self._data, PandasIndexingAdapter):
             self._data = PandasIndexingAdapter(self._data)
 
-    def __dask_tokenize__(self):
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token
 
         # Don't waste time converting pd.Index to np.ndarray
-        return normalize_token((type(self), self._dims, self._data.array, self._attrs))
+        return normalize_token(
+            (type(self), self._dims, self._data.array, self._attrs or None)
+        )
 
     def load(self):
         # data is already loaded into memory for IndexVariable
```

xarray/namedarray/core.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -511,7 +511,7 @@ def attrs(self) -> dict[Any, Any]:
 
     @attrs.setter
     def attrs(self, value: Mapping[Any, Any]) -> None:
-        self._attrs = dict(value)
+        self._attrs = dict(value) if value else None
 
     def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None:
         if new_data.shape != self.shape:
@@ -570,13 +570,12 @@ def real(
             return real(self)
         return self._new(data=self._data.real)
 
-    def __dask_tokenize__(self) -> Hashable:
+    def __dask_tokenize__(self) -> object:
         # Use v.data, instead of v._data, in order to cope with the wrappers
         # around NetCDF and the like
         from dask.base import normalize_token
 
-        s, d, a, attrs = type(self), self._dims, self.data, self.attrs
-        return normalize_token((s, d, a, attrs))  # type: ignore[no-any-return]
+        return normalize_token((type(self), self._dims, self.data, self._attrs or None))
 
     def __dask_graph__(self) -> Graph | None:
         if is_duck_dask_array(self._data):
```

xarray/namedarray/utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -218,7 +218,7 @@ def __eq__(self, other: ReprObject | Any) -> bool:
     def __hash__(self) -> int:
         return hash((type(self), self._value))
 
-    def __dask_tokenize__(self) -> Hashable:
+    def __dask_tokenize__(self) -> object:
         from dask.base import normalize_token
 
-        return normalize_token((type(self), self._value))  # type: ignore[no-any-return]
+        return normalize_token((type(self), self._value))
```
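The same annotation change (`-> Hashable` to `-> object`) recurs in every `__dask_tokenize__` above: dask's `normalize_token` dispatch has no precise static return type, so returning its result from a method annotated `-> Hashable` needed a `# type: ignore[no-any-return]`; widening to `-> object` drops the ignores. A self-contained sketch of the pattern, with a hypothetical class standing in for `ReprObject`:

```python
from dask.base import normalize_token, tokenize

class ReprLike:
    """Hypothetical stand-in for ReprObject."""

    def __init__(self, value: str) -> None:
        self._value = value

    def __dask_tokenize__(self) -> object:
        # object, not Hashable: matches normalize_token's untyped
        # return, so no "no-any-return" ignore is required.
        return normalize_token((type(self), self._value))

# dask.base.tokenize picks up __dask_tokenize__, and equal values
# yield equal, deterministic tokens.
assert tokenize(ReprLike("x")) == tokenize(ReprLike("x"))
```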

xarray/tests/test_dask.py

Lines changed: 24 additions & 11 deletions
```diff
@@ -299,17 +299,6 @@ def test_persist(self):
         self.assertLazyAndAllClose(u + 1, v)
         self.assertLazyAndAllClose(u + 1, v2)
 
-    def test_tokenize_empty_attrs(self) -> None:
-        # Issue #6970
-        assert self.eager_var._attrs is None
-        expected = dask.base.tokenize(self.eager_var)
-        assert self.eager_var.attrs == self.eager_var._attrs == {}
-        assert (
-            expected
-            == dask.base.tokenize(self.eager_var)
-            == dask.base.tokenize(self.lazy_var.compute())
-        )
-
     @requires_pint
     def test_tokenize_duck_dask_array(self):
         import pint
@@ -1573,6 +1562,30 @@ def test_token_identical(obj, transform):
     )
 
 
+@pytest.mark.parametrize(
+    "obj",
+    [
+        make_ds(),  # Dataset
+        make_ds().variables["c2"],  # Variable
+        make_ds().variables["x"],  # IndexVariable
+    ],
+)
+def test_tokenize_empty_attrs(obj):
+    """Issues #6970 and #8788"""
+    obj.attrs = {}
+    assert obj._attrs is None
+    a = dask.base.tokenize(obj)
+
+    assert obj.attrs == {}
+    assert obj._attrs == {}  # attrs getter changed None to dict
+    b = dask.base.tokenize(obj)
+    assert a == b
+
+    obj2 = obj.copy()
+    c = dask.base.tokenize(obj2)
+    assert a == c
+
+
 def test_recursive_token():
     """Test that tokenization is invoked recursively, and doesn't just rely on the
     output of str()
```

xarray/tests/test_sparse.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -878,10 +878,6 @@ def test_dask_token():
     import dask
 
     s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
-
-    # https://github.com/pydata/sparse/issues/300
-    s.__dask_tokenize__ = lambda: dask.base.normalize_token(s.__dict__)
-
     a = DataArray(s)
     t1 = dask.base.tokenize(a)
     t2 = dask.base.tokenize(a)
```
