Skip to content

Commit 179c670

Browse files
authored
Fix two bugs in DataTree.update() (#9214)
* Fix two bugs in DataTree.update() 1. Fix handling of coordinates on a Dataset argument (previously these were silently dropped). 2. Do not copy inherited coordinates down to lower level nodes. * add mypy annotation
1 parent bac01c0 commit 179c670

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

xarray/core/datatree.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
import pandas as pd
6262

6363
from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
64-
from xarray.core.merge import CoercibleValue
64+
from xarray.core.merge import CoercibleMapping, CoercibleValue
6565
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
6666

6767
# """
@@ -954,23 +954,29 @@ def update(
954954
955955
Just like `dict.update` this is an in-place operation.
956956
"""
957-
# TODO separate by type
958957
new_children: dict[str, DataTree] = {}
959-
new_variables = {}
960-
for k, v in other.items():
961-
if isinstance(v, DataTree):
962-
# avoid named node being stored under inconsistent key
963-
new_child: DataTree = v.copy()
964-
# Datatree's name is always a string until we fix that (#8836)
965-
new_child.name = str(k)
966-
new_children[str(k)] = new_child
967-
elif isinstance(v, (DataArray, Variable)):
968-
# TODO this should also accommodate other types that can be coerced into Variables
969-
new_variables[k] = v
970-
else:
971-
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")
972-
973-
vars_merge_result = dataset_update_method(self.to_dataset(), new_variables)
958+
new_variables: CoercibleMapping
959+
960+
if isinstance(other, Dataset):
961+
new_variables = other
962+
else:
963+
new_variables = {}
964+
for k, v in other.items():
965+
if isinstance(v, DataTree):
966+
# avoid named node being stored under inconsistent key
967+
new_child: DataTree = v.copy()
968+
# Datatree's name is always a string until we fix that (#8836)
969+
new_child.name = str(k)
970+
new_children[str(k)] = new_child
971+
elif isinstance(v, (DataArray, Variable)):
972+
# TODO this should also accommodate other types that can be coerced into Variables
973+
new_variables[k] = v
974+
else:
975+
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")
976+
977+
vars_merge_result = dataset_update_method(
978+
self.to_dataset(inherited=False), new_variables
979+
)
974980
data = Dataset._construct_direct(**vars_merge_result._asdict())
975981

976982
# TODO are there any subtleties with preserving order of children like this?

xarray/tests/test_datatree.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,6 @@ def test_update(self):
244244
dt: DataTree = DataTree()
245245
dt.update({"foo": xr.DataArray(0), "a": DataTree()})
246246
expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None})
247-
print(dt)
248-
print(dt.children)
249-
print(dt._children)
250-
print(dt["a"])
251-
print(expected)
252247
assert_equal(dt, expected)
253248

254249
def test_update_new_named_dataarray(self):
@@ -268,14 +263,38 @@ def test_update_doesnt_alter_child_name(self):
268263
def test_update_overwrite(self):
269264
actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))})
270265
actual.update({"a": DataTree(xr.Dataset({"x": 2}))})
271-
272266
expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))})
267+
assert_equal(actual, expected)
273268

274-
print(actual)
275-
print(expected)
276-
269+
def test_update_coordinates(self):
270+
expected = DataTree.from_dict({"/": xr.Dataset(coords={"a": 1})})
271+
actual = DataTree.from_dict({"/": xr.Dataset()})
272+
actual.update(xr.Dataset(coords={"a": 1}))
277273
assert_equal(actual, expected)
278274

275+
def test_update_inherited_coords(self):
276+
expected = DataTree.from_dict(
277+
{
278+
"/": xr.Dataset(coords={"a": 1}),
279+
"/b": xr.Dataset(coords={"c": 1}),
280+
}
281+
)
282+
actual = DataTree.from_dict(
283+
{
284+
"/": xr.Dataset(coords={"a": 1}),
285+
"/b": xr.Dataset(),
286+
}
287+
)
288+
actual["/b"].update(xr.Dataset(coords={"c": 1}))
289+
assert_identical(actual, expected)
290+
291+
# DataTree.identical() currently does not require that non-inherited
292+
# coordinates are defined identically, so we need to check this
293+
# explicitly
294+
actual_node = actual.children["b"].to_dataset(inherited=False)
295+
expected_node = expected.children["b"].to_dataset(inherited=False)
296+
assert_identical(actual_node, expected_node)
297+
279298

280299
class TestCopy:
281300
def test_copy(self, create_test_datatree):

0 commit comments

Comments
 (0)