diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6839bdab347..1ebdd7c27cc 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,6 +59,15 @@ jobs: - env: "flaky" python-version: "3.10" os: ubuntu-latest + - env: "duckarrays" + python-version: "3.10" + os: "ubuntu-latest" + - env: "duckarrays" + python-version: "3.10" + os: "windows-latest" + - env: "duckarrays" + python-version: "3.10" + os: "macos-latest" steps: - uses: actions/checkout@v3 with: @@ -70,17 +79,20 @@ jobs: if [[ ${{ matrix.os }} == windows* ]] ; then echo "CONDA_ENV_FILE=ci/requirements/environment-windows.yml" >> $GITHUB_ENV - elif [[ "${{ matrix.env }}" != "" ]] ; + else + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + fi + + if [[ "${{ matrix.env }}" != "" ]] ; then - if [[ "${{ matrix.env }}" == "flaky" ]] ; - then + if [[ "${{ matrix.env }}" == "flaky" ]] ; then echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV echo "PYTEST_EXTRA_FLAGS=--run-flaky --run-network-tests" >> $GITHUB_ENV + elif [[ "${{ matrix.env }}" == "duckarrays" ]] ; then + echo "PYTEST_EXTRA_FLAGS=--run-duckarray-tests xarray/tests/duckarrays/" >> $GITHUB_ENV else echo "CONDA_ENV_FILE=ci/requirements/${{ matrix.env }}.yml" >> $GITHUB_ENV fi - else - echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV fi echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV @@ -96,7 +108,7 @@ jobs: # We only want to install this on one run, because otherwise we'll have # duplicate annotations. - name: Install error reporter - if: ${{ matrix.os }} == 'ubuntu-latest' and ${{ matrix.python-version }} == '3.10' + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' && matrix.env == '' run: | python -m pip install pytest-github-actions-annotate-failures diff --git a/conftest.py b/conftest.py index 862a1a1d0bc..2391cb047c0 100644 --- a/conftest.py +++ b/conftest.py @@ -11,6 +11,11 @@ def pytest_addoption(parser): action="store_true", help="runs tests requiring a network connection", ) + parser.addoption( + "--run-duckarray-tests", + action="store_true", + help="runs the duckarray hypothesis tests", + ) def pytest_runtest_setup(item): @@ -21,6 +26,10 @@ def pytest_runtest_setup(item): pytest.skip( "set --run-network-tests to run test requiring an internet connection" ) + if "duckarrays" in item.keywords and not item.config.getoption( + "--run-duckarray-tests" + ): + pytest.skip("set --run-duckarray-tests option to run duckarray tests") @pytest.fixture(autouse=True) diff --git a/setup.cfg b/setup.cfg index af7d47c2b79..ff89f0a8f7f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -144,6 +144,7 @@ markers = flaky: flaky tests network: tests requiring a network connection slow: slow tests + duckarrays: duckarray tests [flake8] ignore = diff --git a/xarray/tests/conftest.py b/xarray/tests/conftest.py index 658c349cd74..31b59a85e25 100644 --- a/xarray/tests/conftest.py +++ b/xarray/tests/conftest.py @@ -12,6 +12,48 @@ def backend(request): return request.param +def pytest_configure(config): + config.addinivalue_line( + "markers", + "apply_marks(marks): function to attach marks to tests and test variants", + ) + + +def always_sequence(obj): + if not isinstance(obj, (list, tuple)): + obj = [obj] + + return obj + + +def pytest_collection_modifyitems(session, config, items): + for item in items: + mark = item.get_closest_marker("apply_marks") + if mark is None: + continue + + marks = mark.args[0] + if not isinstance(marks, dict): + continue + + possible_marks = marks.get(item.originalname) + if possible_marks is None: + continue + + if not isinstance(possible_marks, dict): + for mark in always_sequence(possible_marks): + item.add_marker(mark) + continue + + variant = item.name[len(item.originalname) :] + to_attach = possible_marks.get(variant) + if to_attach is None: + continue + + for mark in always_sequence(to_attach): + item.add_marker(mark) + + @pytest.fixture(params=[1]) def ds(request, backend): if request.param == 1: diff --git a/xarray/tests/duckarrays/__init__.py b/xarray/tests/duckarrays/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/tests/duckarrays/base/__init__.py b/xarray/tests/duckarrays/base/__init__.py new file mode 100644 index 00000000000..5437e73b515 --- /dev/null +++ b/xarray/tests/duckarrays/base/__init__.py @@ -0,0 +1,7 @@ +from .reduce import DataArrayReduceTests, DatasetReduceTests, VariableReduceTests + +__all__ = [ + "VariableReduceTests", + "DataArrayReduceTests", + "DatasetReduceTests", +] diff --git a/xarray/tests/duckarrays/base/reduce.py b/xarray/tests/duckarrays/base/reduce.py new file mode 100644 index 00000000000..4e4a7409a85 --- /dev/null +++ b/xarray/tests/duckarrays/base/reduce.py @@ -0,0 +1,144 @@ +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given, note, settings + +from ... import assert_identical +from . import strategies + + +class VariableReduceTests: + def check_reduce(self, obj, op, *args, **kwargs): + actual = getattr(obj, op)(*args, **kwargs) + + data = np.asarray(obj.data) + expected = getattr(obj.copy(data=data), op)(*args, **kwargs) + + note(f"actual:\n{actual}") + note(f"expected:\n{expected}") + + assert_identical(actual, expected) + + @staticmethod + def create(shape, dtypes): + return strategies.numpy_array(shape) + + @pytest.mark.parametrize( + "method", + ( + "all", + "any", + "cumprod", + "cumsum", + "max", + "mean", + "median", + "min", + "prod", + "std", + "sum", + "var", + ), + ) + @given(st.data()) + @settings(deadline=None) + def test_reduce(self, method, data): + var = data.draw( + strategies.variable(lambda shape, dtypes: self.create(shape, dtypes)) + ) + + reduce_dims = data.draw(strategies.valid_dims(var.dims)) + + self.check_reduce(var, method, dim=reduce_dims) + + +class DataArrayReduceTests: + def check_reduce(self, obj, op, *args, **kwargs): + actual = getattr(obj, op)(*args, **kwargs) + + data = np.asarray(obj.data) + expected = getattr(obj.copy(data=data), op)(*args, **kwargs) + + note(f"actual:\n{actual}") + note(f"expected:\n{expected}") + + assert_identical(actual, expected) + + @staticmethod + def create(op, shape, dtypes): + return strategies.numpy_array(shape, dtypes) + + @pytest.mark.parametrize( + "method", + ( + "all", + "any", + "cumprod", + "cumsum", + "max", + "mean", + "median", + "min", + "prod", + "std", + "sum", + "var", + ), + ) + @given(st.data()) + @settings(deadline=None) + def test_reduce(self, method, data): + arr = data.draw( + strategies.data_array(lambda shape, dtypes: self.create(shape, dtypes)) + ) + + reduce_dims = data.draw(strategies.valid_dims(arr.dims)) + + self.check_reduce(arr, method, dim=reduce_dims) + + +class DatasetReduceTests: + def check_reduce(self, obj, op, *args, **kwargs): + actual = getattr(obj, op)(*args, **kwargs) + + data = {name: np.asarray(obj.data) for name, obj in obj.variables.items()} + expected = getattr(obj.copy(data=data), op)(*args, **kwargs) + + note(f"actual:\n{actual}") + note(f"expected:\n{expected}") + + assert_identical(actual, expected) + + @staticmethod + def create(shape, dtypes): + return strategies.numpy_array(shape, dtypes) + + @pytest.mark.parametrize( + "method", + ( + "all", + "any", + "cumprod", + "cumsum", + "max", + "mean", + "median", + "min", + "prod", + "std", + "sum", + "var", + ), + ) + @given(st.data()) + @settings(deadline=None) + def test_reduce(self, method, data): + ds = data.draw( + strategies.dataset( + lambda shape, dtypes: self.create(shape, dtypes), max_size=5 + ) + ) + + reduce_dims = data.draw(strategies.valid_dims(ds.dims)) + + self.check_reduce(ds, method, dim=reduce_dims) diff --git a/xarray/tests/duckarrays/base/strategies.py b/xarray/tests/duckarrays/base/strategies.py new file mode 100644 index 00000000000..42eee29b554 --- /dev/null +++ b/xarray/tests/duckarrays/base/strategies.py @@ -0,0 +1,161 @@ +import hypothesis.extra.numpy as npst +import hypothesis.strategies as st + +import xarray as xr +from xarray.core.utils import is_dict_like + +from . import utils + +all_dtypes = ( + npst.integer_dtypes() + | npst.unsigned_integer_dtypes() + | npst.floating_dtypes() + | npst.complex_number_dtypes() +) + + +def numpy_array(shape, dtypes=None): + if dtypes is None: + dtypes = all_dtypes + + def elements(dtype): + max_value = 100 + min_value = 0 if dtype.kind == "u" else -max_value + + return npst.from_dtype( + dtype, allow_infinity=False, min_value=min_value, max_value=max_value + ) + + return dtypes.flatmap( + lambda dtype: npst.arrays(dtype=dtype, shape=shape, elements=elements(dtype)) + ) + + +def dimension_sizes(min_dims, max_dims, min_size, max_size): + sizes = st.lists( + elements=st.tuples(st.text(min_size=1), st.integers(min_size, max_size)), + min_size=min_dims, + max_size=max_dims, + unique_by=lambda x: x[0], + ) + return sizes + + +@st.composite +def variable( + draw, + create_data, + *, + sizes=None, + min_size=1, + max_size=3, + min_dims=1, + max_dims=3, + dtypes=None, +): + if sizes is None: + sizes = draw( + dimension_sizes( + min_size=min_size, + max_size=max_size, + min_dims=min_dims, + max_dims=max_dims, + ) + ) + + if not sizes: + dims = () + shape = () + else: + dims, shape = zip(*sizes) + data = create_data(shape, dtypes) + + return xr.Variable(dims, draw(data)) + + +@st.composite +def data_array( + draw, create_data, *, min_dims=1, max_dims=3, min_size=1, max_size=3, dtypes=None +): + name = draw(st.none() | st.text(min_size=1)) + if dtypes is None: + dtypes = all_dtypes + + sizes = st.lists( + elements=st.tuples(st.text(min_size=1), st.integers(min_size, max_size)), + min_size=min_dims, + max_size=max_dims, + unique_by=lambda x: x[0], + ) + drawn_sizes = draw(sizes) + dims, shape = zip(*drawn_sizes) + + data = draw(create_data(shape, dtypes)) + + return xr.DataArray( + data=data, + name=name, + dims=dims, + ) + + +@st.composite +def dataset( + draw, + create_data, + *, + min_dims=1, + max_dims=3, + min_size=1, + max_size=3, + min_vars=1, + max_vars=3, +): + dtypes = st.just(draw(all_dtypes)) + names = st.text(min_size=1) + sizes = dimension_sizes( + min_size=min_size, max_size=max_size, min_dims=min_dims, max_dims=max_dims + ) + + data_vars = sizes.flatmap( + lambda s: st.dictionaries( + keys=names.filter(lambda n: n not in dict(s)), + values=variable(create_data, sizes=s, dtypes=dtypes), + min_size=min_vars, + max_size=max_vars, + ) + ) + + return xr.Dataset(data_vars=draw(data_vars)) + + +def valid_axis(ndim): + if ndim == 0: + return st.none() | st.just(0) + return st.none() | st.integers(-ndim, ndim - 1) + + +def valid_axes(ndim): + return valid_axis(ndim) | npst.valid_tuple_axes(ndim, min_size=1) + + +def valid_dim(dims): + if not isinstance(dims, list): + dims = [dims] + + ndim = len(dims) + axis = valid_axis(ndim) + return axis.map(lambda axes: utils.valid_dims_from_axes(dims, axes)) + + +def valid_dims(dims): + if is_dict_like(dims): + dims = list(dims.keys()) + elif isinstance(dims, tuple): + dims = list(dims) + elif not isinstance(dims, list): + dims = [dims] + + ndim = len(dims) + axes = valid_axes(ndim) + return axes.map(lambda axes: utils.valid_dims_from_axes(dims, axes)) diff --git a/xarray/tests/duckarrays/base/utils.py b/xarray/tests/duckarrays/base/utils.py new file mode 100644 index 00000000000..2bd353e2116 --- /dev/null +++ b/xarray/tests/duckarrays/base/utils.py @@ -0,0 +1,36 @@ +import warnings +from contextlib import contextmanager + + +@contextmanager +def suppress_warning(category, message=""): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=category, message=message) + + yield + + +def create_dimension_names(ndim): + return [f"dim_{n}" for n in range(ndim)] + + +def valid_dims_from_axes(dims, axes): + if axes is None: + return None + + if axes == 0 and len(dims) == 0: + return None + + if isinstance(axes, int): + return dims[axes] + + return [dims[axis] for axis in axes] + + +def valid_axes_from_dims(all_dims, dims): + if dims is None: + return None + elif isinstance(dims, list): + return [all_dims.index(dim) for dim in dims] + else: + return all_dims.index(dims) diff --git a/xarray/tests/duckarrays/test_sparse.py b/xarray/tests/duckarrays/test_sparse.py new file mode 100644 index 00000000000..f4c425fff86 --- /dev/null +++ b/xarray/tests/duckarrays/test_sparse.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest + +from xarray import DataArray, Dataset, Variable + +# isort: off +# needs to stay here to avoid ImportError for the strategy imports +pytest.importorskip("hypothesis") +# isort: on + +from .. import assert_allclose +from . import base +from .base import strategies + +sparse = pytest.importorskip("sparse") + +pytestmark = [pytest.mark.duckarrays()] + + +@pytest.fixture(autouse=True) +def disable_bottleneck(): + from xarray import set_options + + with set_options(use_bottleneck=False): + yield + + +def create(shape, dtypes): + def convert(arr): + if arr.ndim == 0: + return arr + + return sparse.COO.from_numpy(arr) + + if dtypes is None: + dtypes = strategies.all_dtypes + + # sparse does not support float16, and there's a bug with complex64 (pydata/sparse#553) + sparse_dtypes = dtypes.filter( + lambda dtype: ( + not np.issubdtype(dtype, np.float16) + and not np.issubdtype(dtype, np.complexfloating) + ) + ) + return strategies.numpy_array(shape, sparse_dtypes).map(convert) + + +def as_dense(obj): + if isinstance(obj, (Variable, DataArray, Dataset)): + new_obj = obj.as_numpy() + else: + new_obj = obj + + return new_obj + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[cumprod]": pytest.mark.skip(reason="cumprod not implemented by sparse"), + "[cumsum]": pytest.mark.skip(reason="cumsum not implemented by sparse"), + "[median]": pytest.mark.skip(reason="median not implemented by sparse"), + "[std]": pytest.mark.skip(reason="nanstd not implemented by sparse"), + "[var]": pytest.mark.skip(reason="nanvar not implemented by sparse"), + } + } +) +class TestSparseVariableReduceMethods(base.VariableReduceTests): + @staticmethod + def create(shape, dtypes): + return create(shape, dtypes) + + def check_reduce(self, obj, op, *args, **kwargs): + actual = as_dense(getattr(obj, op)(*args, **kwargs)) + expected = getattr(as_dense(obj), op)(*args, **kwargs) + + assert_allclose(actual, expected) + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[cumprod]": pytest.mark.skip(reason="cumprod not implemented by sparse"), + "[cumsum]": pytest.mark.skip(reason="cumsum not implemented by sparse"), + "[median]": pytest.mark.skip(reason="median not implemented by sparse"), + "[std]": pytest.mark.skip(reason="nanstd not implemented by sparse"), + "[var]": pytest.mark.skip(reason="nanvar not implemented by sparse"), + } + } +) +class TestSparseDataArrayReduceMethods(base.DataArrayReduceTests): + @staticmethod + def create(shape, dtypes): + return create(shape, dtypes) + + def check_reduce(self, obj, op, *args, **kwargs): + actual = as_dense(getattr(obj, op)(*args, **kwargs)) + expected = getattr(as_dense(obj), op)(*args, **kwargs) + + assert_allclose(actual, expected) + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[cumprod]": pytest.mark.skip(reason="cumprod not implemented by sparse"), + "[cumsum]": pytest.mark.skip(reason="cumsum not implemented by sparse"), + "[median]": pytest.mark.skip(reason="median not implemented by sparse"), + "[std]": pytest.mark.skip(reason="nanstd not implemented by sparse"), + "[var]": pytest.mark.skip(reason="nanvar not implemented by sparse"), + } + } +) +class TestSparseDatasetReduceMethods(base.DatasetReduceTests): + @staticmethod + def create(shape, dtypes): + return create(shape, dtypes) + + def check_reduce(self, obj, op, *args, **kwargs): + actual = as_dense(getattr(obj, op)(*args, **kwargs)) + expected = getattr(as_dense(obj), op)(*args, **kwargs) + + assert_allclose(actual, expected) diff --git a/xarray/tests/duckarrays/test_units.py b/xarray/tests/duckarrays/test_units.py new file mode 100644 index 00000000000..3f68d5fdaea --- /dev/null +++ b/xarray/tests/duckarrays/test_units.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest + +# isort: off +# needs to stay here to avoid ImportError for the hypothesis imports +pytest.importorskip("hypothesis") +# isort: on + +import hypothesis.strategies as st +from hypothesis import note + +from .. import assert_allclose +from ..test_units import assert_units_equal, attach_units, strip_units +from . import base +from .base import strategies, utils + +pint = pytest.importorskip("pint") +unit_registry = pint.UnitRegistry(force_ndarray_like=True) +Quantity = unit_registry.Quantity + +pytestmark = [ + pytest.mark.duckarrays(), + pytest.mark.filterwarnings("error::pint.UnitStrippedWarning"), +] + + +@pytest.fixture(autouse=True) +def disable_bottleneck(): + from xarray import set_options + + with set_options(use_bottleneck=False): + yield + + +all_units = st.sampled_from(["m", "mm", "s", "dimensionless"]) + +tolerances = { + np.float64: 1e-8, + np.float32: 1e-4, + np.float16: 1e-2, + np.complex128: 1e-8, + np.complex64: 1e-4, +} + + +def apply_func(op, var, *args, **kwargs): + dim = kwargs.pop("dim", None) + if dim in var.dims: + axis = utils.valid_axes_from_dims(var.dims, dim) + else: + axis = None + kwargs["axis"] = axis + + arr = var.data + func_name = f"nan{op}" if arr.dtype.kind in "fc" else op + func = getattr(np, func_name, getattr(np, op)) + with utils.suppress_warning(RuntimeWarning): + result = func(arr, *args, **kwargs) + + return getattr(result, "units", None) + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[prod]": pytest.mark.skip(reason="inconsistent implementation in pint"), + } + } +) +class TestPintVariableReduceMethods(base.VariableReduceTests): + @st.composite + @staticmethod + def create(draw, shape, dtypes): + return Quantity(draw(strategies.numpy_array(shape, dtypes)), draw(all_units)) + + def compute_expected(self, obj, op, *args, **kwargs): + without_units = strip_units(obj) + expected = getattr(without_units, op)(*args, **kwargs) + + units = apply_func(op, obj, *args, **kwargs) + return attach_units(expected, {None: units}) + + def check_reduce(self, obj, op, *args, **kwargs): + if ( + op in ("cumprod",) + and getattr(obj.data, "units", None) != unit_registry.dimensionless + ): + with pytest.raises(pint.DimensionalityError): + getattr(obj, op)(*args, **kwargs) + else: + actual = getattr(obj, op)(*args, **kwargs) + + note(f"actual:\n{actual}") + + expected = self.compute_expected(obj, op, *args, **kwargs) + + note(f"expected:\n{expected}") + + assert_units_equal(actual, expected) + assert_allclose(actual, expected) + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[prod]": pytest.mark.skip(reason="inconsistent implementation in pint"), + } + } +) +class TestPintDataArrayReduceMethods(base.DataArrayReduceTests): + @st.composite + @staticmethod + def create(draw, shape, dtypes): + return Quantity(draw(strategies.numpy_array(shape, dtypes)), draw(all_units)) + + def compute_expected(self, obj, op, *args, **kwargs): + without_units = strip_units(obj) + expected = getattr(without_units, op)(*args, **kwargs) + units = apply_func(op, obj.variable, *args, **kwargs) + + return attach_units(expected, {obj.name: units}) + + def check_reduce(self, obj, op, *args, **kwargs): + if ( + op in ("cumprod",) + and getattr(obj.data, "units", None) != unit_registry.dimensionless + ): + with pytest.raises(pint.DimensionalityError): + getattr(obj, op)(*args, **kwargs) + else: + actual = getattr(obj, op)(*args, **kwargs) + + note(f"actual:\n{actual}") + + expected = self.compute_expected(obj, op, *args, **kwargs) + + note(f"expected:\n{expected}") + + assert_units_equal(actual, expected) + tol = tolerances.get(obj.dtype.name, 1e-8) + assert_allclose(actual, expected, atol=tol) + + +@pytest.mark.apply_marks( + { + "test_reduce": { + "[prod]": pytest.mark.skip(reason="inconsistent implementation in pint"), + } + } +) +class TestPintDatasetReduceMethods(base.DatasetReduceTests): + @st.composite + @staticmethod + def create(draw, shape, dtypes): + return Quantity(draw(strategies.numpy_array(shape, dtypes)), draw(all_units)) + + def compute_expected(self, obj, op, *args, **kwargs): + without_units = strip_units(obj) + result_without_units = getattr(without_units, op)(*args, **kwargs) + units = { + name: apply_func(op, var, *args, **kwargs) + for name, var in obj.variables.items() + } + attached = attach_units(result_without_units, units) + return attached + + def check_reduce(self, obj, op, *args, **kwargs): + if op in ("cumprod",) and any( + getattr(var.data, "units", None) != unit_registry.dimensionless + for var in obj.data_vars.values() + ): + with pytest.raises(pint.DimensionalityError): + getattr(obj, op)(*args, **kwargs) + else: + actual = getattr(obj, op)(*args, **kwargs) + + note(f"actual:\n{actual}") + + expected = self.compute_expected(obj, op, *args, **kwargs) + + note(f"expected:\n{expected}") + + assert_units_equal(actual, expected) + assert_allclose(actual, expected) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 52c50e28931..2a076b97578 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -93,11 +93,14 @@ def array_strip_units(array): def array_attach_units(data, unit): + if unit is None or (isinstance(unit, int) and unit == 1): + return data + if isinstance(data, Quantity): raise ValueError(f"cannot attach unit {unit} to quantity {data}") try: - quantity = data * unit + quantity = unit._REGISTRY.Quantity(data, unit) except np.core._exceptions.UFuncTypeError: if isinstance(unit, unit_registry.Unit): raise @@ -182,36 +185,40 @@ def attach_units(obj, units): return array_attach_units(obj, units) if isinstance(obj, xr.Dataset): - data_vars = { - name: attach_units(value, units) for name, value in obj.data_vars.items() + variables = { + name: attach_units(value, {None: units.get(name)}) + for name, value in obj.variables.items() } - coords = { - name: attach_units(value, units) for name, value in obj.coords.items() + name: var for name, var in variables.items() if name in obj._coord_names + } + data_vars = { + name: var for name, var in variables.items() if name not in obj._coord_names } - new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) elif isinstance(obj, xr.DataArray): # try the array name, "data" and None, then fall back to dimensionless - data_units = units.get(obj.name, None) or units.get(None, None) or 1 - - data = array_attach_units(obj.data, data_units) + units = units.copy() + THIS_ARRAY = xr.core.dataarray._THIS_ARRAY + unset = object() + if obj.name in units: + name = obj.name + elif None in units: + name = None + else: + name = unset - coords = { - name: ( - (value.dims, array_attach_units(value.data, units.get(name) or 1)) - if name in units - else (value.dims, value.data) - ) - for name, value in obj.coords.items() - } - dims = obj.dims - attrs = obj.attrs + if name is not unset: + units[THIS_ARRAY] = units.pop(name) - new_obj = xr.DataArray( - name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims - ) + ds = obj._to_temp_dataset() + attached = attach_units(ds, units) + new_obj = obj._from_temp_dataset(attached, name=obj.name) else: + if isinstance(obj, xr.IndexVariable): + # no units for index variables + return obj + data_units = units.get("data", None) or units.get(None, None) or 1 data = array_attach_units(obj.data, data_units)