Skip to content

Commit 430d642

Browse files
authored
DatasetView.map fix keep_attrs (pydata#10219)
* DatasetView.map fix keep_attrs * fix comment * add changelog
1 parent aa9e2bd commit 430d642

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ Bug fixes
3939
- :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance.
4040
This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array`
4141
had no effect. (Mentioned in :issue:`9921`)
42+
- Enable ``keep_attrs`` in ``DatasetView.map`` relevant for :py:func:`map_over_datasets` (:pull:`10219`)
43+
By `Mathias Hauser <https://github.com/mathause>`_.
4244

4345
Documentation
4446
~~~~~~~~~~~~~

xarray/core/datatree.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from xarray.core.indexes import Index, Indexes
4747
from xarray.core.options import OPTIONS as XR_OPTS
48+
from xarray.core.options import _get_keep_attrs
4849
from xarray.core.treenode import NamedNode, NodePath, zip_subtrees
4950
from xarray.core.types import Self
5051
from xarray.core.utils import (
@@ -421,13 +422,18 @@ def map( # type: ignore[override]
421422

422423
# Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188).
423424
# TODO Refactor xarray upstream to avoid needing to overwrite this.
424-
# TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated
425+
if keep_attrs is None:
426+
keep_attrs = _get_keep_attrs(default=False)
425427
variables = {
426428
k: maybe_wrap_array(v, func(v, *args, **kwargs))
427429
for k, v in self.data_vars.items()
428430
}
431+
if keep_attrs:
432+
for k, v in variables.items():
433+
v._copy_attrs_from(self.data_vars[k])
434+
attrs = self.attrs if keep_attrs else None
429435
# return type(self)(variables, attrs=attrs)
430-
return Dataset(variables)
436+
return Dataset(variables, attrs=attrs)
431437

432438

433439
class DataTree(

xarray/tests/test_datatree.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,29 @@ def weighted_mean(ds):
10191019

10201020
weighted_mean(dt.dataset)
10211021

1022+
def test_map_keep_attrs(self) -> None:
1023+
# test DatasetView.map(..., keep_attrs=...)
1024+
data = xr.DataArray([1, 2, 3], dims="x", attrs={"da": "attrs"})
1025+
ds = xr.Dataset({"data": data}, attrs={"ds": "attrs"})
1026+
dt = DataTree(ds)
1027+
1028+
def func(ds):
1029+
# x.mean() removes the attrs of the data_vars
1030+
return ds.map(lambda x: x.mean(), keep_attrs=True)
1031+
1032+
result = xr.map_over_datasets(func, dt)
1033+
expected = dt.mean(keep_attrs=True)
1034+
xr.testing.assert_identical(result, expected)
1035+
1036+
# per default DatasetView.map does not keep attrs
1037+
def func(ds):
1038+
# x.mean() removes the attrs of the data_vars
1039+
return ds.map(lambda x: x.mean())
1040+
1041+
result = xr.map_over_datasets(func, dt)
1042+
expected = dt.mean()
1043+
xr.testing.assert_identical(result, expected.mean())
1044+
10221045

10231046
class TestAccess:
10241047
def test_attribute_access(self, create_test_datatree) -> None:

0 commit comments

Comments
 (0)