Skip to content

Commit 3a75f6f

Browse files
authored
assert_equal: ensure check_dim_order=False works for DataTree (#10442)
* assert_equal: ensure check_dim_order=False works for DataTree * typing * changelog
1 parent 4071232 commit 3a75f6f

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Bug fixes
3434
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
3535
- Fix the error message of :py:func:`testing.assert_equal` when two different :py:class:`DataTree` objects
3636
are passed (:pull:`10440`). By `Mathias Hauser <https://github.com/mathause>`_.
37+
- Fix :py:func:`testing.assert_equal` with ``check_dim_order=False`` for :py:class:`DataTree` objects
38+
(:pull:`10442`). By `Mathias Hauser <https://github.com/mathause>`_.
3739

3840

3941
Documentation

xarray/testing/assertions.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from xarray.core.dataarray import DataArray
1313
from xarray.core.dataset import Dataset
1414
from xarray.core.datatree import DataTree
15+
from xarray.core.datatree_mapping import map_over_datasets
1516
from xarray.core.formatting import diff_datatree_repr
1617
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
1718
from xarray.core.variable import IndexVariable, Variable
@@ -85,14 +86,25 @@ def assert_isomorphic(a: DataTree, b: DataTree):
8586

8687
def maybe_transpose_dims(a, b, check_dim_order: bool):
8788
"""Helper for assert_equal/allclose/identical"""
89+
8890
__tracebackhide__ = True
89-
if not isinstance(a, Variable | DataArray | Dataset):
91+
92+
def _maybe_transpose_dims(a, b):
93+
if not isinstance(a, Variable | DataArray | Dataset):
94+
return b
95+
if set(a.dims) == set(b.dims):
96+
# Ensure transpose won't fail if a dimension is missing
97+
# If this is the case, the difference will be caught by the caller
98+
return b.transpose(*a.dims)
99+
return b
100+
101+
if check_dim_order:
90102
return b
91-
if not check_dim_order and set(a.dims) == set(b.dims):
92-
# Ensure transpose won't fail if a dimension is missing
93-
# If this is the case, the difference will be caught by the caller
94-
return b.transpose(*a.dims)
95-
return b
103+
104+
if isinstance(a, DataTree):
105+
return map_over_datasets(_maybe_transpose_dims, a, b)
106+
107+
return _maybe_transpose_dims(a, b)
96108

97109

98110
@ensure_warnings

xarray/tests/test_assertions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def test_assert_allclose_equal_transpose(func) -> None:
8888
getattr(xr.testing, func)(ds1, ds2, check_dim_order=False)
8989

9090

91+
def test_assert_equal_transpose_datatree() -> None:
92+
"""Ensure `check_dim_order=False` works for transposed DataTree"""
93+
ds = xr.Dataset(data_vars={"data": (("x", "y"), [[1, 2]])})
94+
95+
a = xr.DataTree.from_dict({"node": ds})
96+
b = xr.DataTree.from_dict({"node": ds.transpose("y", "x")})
97+
98+
with pytest.raises(AssertionError):
99+
xr.testing.assert_equal(a, b)
100+
101+
xr.testing.assert_equal(a, b, check_dim_order=False)
102+
103+
91104
@pytest.mark.filterwarnings("error")
92105
@pytest.mark.parametrize(
93106
"duckarray",

0 commit comments

Comments
 (0)