Skip to content

Commit 54946eb

Browse files
Fix dataarray drop attrs (#10030)
* Fix DataArray().drop_attrs(deep=False) * Add DataArray().drop_attrs() tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply small cosmetics * Add support for attrs to DataArray()._replace * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove testing relict * Fix (try) incompatible types mypy error * Fix (2.try) incompatible types mypy error * Update whats-new * Fix replacing simultaneously passed variable * Add DataArray()._replace() tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent df2ecf4 commit 54946eb

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ Bug fixes
4242
- Use mean of min/max years as offset in calculation of datetime64 mean
4343
(:issue:`10019`, :pull:`10035`).
4444
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
45+
- Fix DataArray().drop_attrs(deep=False) and add support for attrs to
46+
DataArray()._replace(). (:issue:`10027`, :pull:`10030`). By `Jan
47+
Haacker <https://github.com/j-haacker>`_.
4548

4649

4750
Documentation

xarray/core/dataarray.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
import datetime
45
import warnings
56
from collections.abc import (
@@ -522,6 +523,7 @@ def _replace(
522523
variable: Variable | None = None,
523524
coords=None,
524525
name: Hashable | None | Default = _default,
526+
attrs=_default,
525527
indexes=None,
526528
) -> Self:
527529
if variable is None:
@@ -532,6 +534,11 @@ def _replace(
532534
indexes = self._indexes
533535
if name is _default:
534536
name = self.name
537+
if attrs is _default:
538+
attrs = copy.copy(self.attrs)
539+
else:
540+
variable = variable.copy()
541+
variable.attrs = attrs
535542
return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True)
536543

537544
def _replace_maybe_drop_dims(
@@ -7575,6 +7582,11 @@ def drop_attrs(self, *, deep: bool = True) -> Self:
75757582
-------
75767583
DataArray
75777584
"""
7578-
return (
7579-
self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset)
7580-
)
7585+
if not deep:
7586+
return self._replace(attrs={})
7587+
else:
7588+
return (
7589+
self._to_temp_dataset()
7590+
.drop_attrs(deep=deep)
7591+
.pipe(self._from_temp_dataset)
7592+
)

xarray/tests/test_dataarray.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1908,6 +1908,21 @@ def test_rename_dimension_coord_warnings(self) -> None:
19081908
warnings.simplefilter("error")
19091909
da.rename(x="x")
19101910

1911+
def test_replace(self) -> None:
1912+
# Tests the `attrs` replacement and whether it interferes with a
1913+
# `variable` replacement
1914+
da = self.mda
1915+
attrs1 = {"a1": "val1", "a2": 161}
1916+
x = np.ones((10, 20))
1917+
v = Variable(["x", "y"], x)
1918+
assert da._replace(variable=v, attrs=attrs1).attrs == attrs1
1919+
attrs2 = {"b1": "val2", "b2": 1312}
1920+
va = Variable(["x", "y"], x, attrs2)
1921+
# assuming passed `attrs` should prevail
1922+
assert da._replace(variable=va, attrs=attrs1).attrs == attrs1
1923+
# assuming `va.attrs` should be adopted
1924+
assert da._replace(variable=va).attrs == attrs2
1925+
19111926
def test_init_value(self) -> None:
19121927
expected = DataArray(
19131928
np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)]
@@ -2991,8 +3006,11 @@ def test_assign_attrs(self) -> None:
29913006

29923007
def test_drop_attrs(self) -> None:
29933008
# Mostly tested in test_dataset.py, but adding a very small test here
2994-
da = DataArray([], attrs=dict(a=1, b=2))
3009+
coord_ = DataArray([], attrs=dict(d=3, e=4))
3010+
da = DataArray([], attrs=dict(a=1, b=2)).assign_coords(dict(coord_=coord_))
29953011
assert da.drop_attrs().attrs == {}
3012+
assert da.drop_attrs().coord_.attrs == {}
3013+
assert da.drop_attrs(deep=False).coord_.attrs == dict(d=3, e=4)
29963014

29973015
@pytest.mark.parametrize(
29983016
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]

0 commit comments

Comments
 (0)