diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ba692e4..40abd8c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -65,6 +65,10 @@ jobs: python=${{matrix.python-version}} conda + - name: Install nightly xarray + run: | + python -m pip install --upgrade --pre -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple xarray + - name: Install xarray-array-testing run: | python -m pip install --no-deps -e . diff --git a/xarray_array_testing/reduction.py b/xarray_array_testing/reduction.py index fb4b158..38b37e9 100644 --- a/xarray_array_testing/reduction.py +++ b/xarray_array_testing/reduction.py @@ -1,6 +1,7 @@ from contextlib import nullcontext import hypothesis.strategies as st +import numpy as np import pytest import xarray.testing.strategies as xrst from hypothesis import given @@ -24,4 +25,76 @@ def test_variable_numerical_reduce(self, op, data): # compute using xp.(array) expected = getattr(self.xp, op)(variable.data) + assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}" + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("op", ["all", "any"]) + @given(st.data()) + def test_variable_boolean_reduce(self, op, data): + variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + + with self.expected_errors(op, variable=variable): + # compute using xr.Variable.() + actual = getattr(variable, op)().data + # compute using xp.(array) + expected = getattr(self.xp, op)(variable.data) + + assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}" + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("op", ["max", "min"]) + @given(st.data()) + def test_variable_order_reduce(self, op, data): + variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + + with self.expected_errors(op, variable=variable): + # compute using xr.Variable.() + actual = getattr(variable, op)().data + # compute using xp.(array) + expected = getattr(self.xp, op)(variable.data) + + assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}" + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("op", ["argmax", "argmin"]) + @given(st.data()) + def test_variable_order_reduce_index(self, op, data): + variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + + with self.expected_errors(op, variable=variable): + # compute using xr.Variable.() + actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()} + + # compute using xp.(array) + index = getattr(self.xp, op)(variable.data) + unraveled = np.unravel_index(index, variable.shape) + expected = dict(zip(variable.dims, unraveled)) + + self.assert_equal(actual, expected) + + @pytest.mark.parametrize( + "op", + [ + "cumsum", + pytest.param( + "cumprod", + marks=pytest.mark.skip(reason="not yet included in the array api"), + ), + ], + ) + @given(st.data()) + def test_variable_cumulative_reduce(self, op, data): + array_api_names = {"cumsum": "cumulative_sum", "cumprod": "cumulative_prod"} + variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn)) + + with self.expected_errors(op, variable=variable): + # compute using xr.Variable.() + actual = getattr(variable, op)().data + # compute using xp.(array) + # Variable implements n-d cumulative ops by iterating over dims + expected = variable.data + for axis in range(variable.ndim): + expected = getattr(self.xp, array_api_names[op])(expected, axis=axis) + + assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}" self.assert_equal(actual, expected)