Skip to content

Commit 68b040a

Browse files
TomNicholas, flamingbear, and pre-commit-ci[bot]
authored
Shallow copy parent and children in DataTree constructor (#9297)
* add tests
* fix by shallow copying
* correct first few tests
* replace constructors in tests with DataTree.from_dict
* rewrite simple_datatree fixture to use DataTree.from_dict
* fix incorrect creation of nested tree in formatting test
* Update doctests for from_dict constructor
* swap i and h in doctest example for clarity.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix a few mypy errors.
* Bonkers way to set type checking. I will happily take something better, but this was the error I was getting: xarray/tests/test_datatree.py:127: error: Argument 1 to "relative_to" of "NamedNode" has incompatible type "DataTree[Any] | DataArray"; expected "NamedNode[Any]" [arg-type]
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Removes parent keyword from DataTree constructor. But it doesn't fix all the tests: there are three tests where I don't fully know what should be tested or if they still make sense.
* fix test_setparent_unnamed_child_node_fails
* fix test_dont_modify_parent_inplace -> bug?
* fix test_create_two_children
* make .parent read-only, and remove tests which test the parent setter
* update error message to reflect fact that .children is Frozen
* fix another test
* add test that parent setter tells you to set children instead
* fix mypy error due to overriding settable property with read-only property
* fix test by not trying to set parent via kwarg
---------
Co-authored-by: Matt Savoie <matthew.savoie@colorado.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 25debff commit 68b040a

File tree

9 files changed

+216
-161
lines changed

9 files changed

+216
-161
lines changed

xarray/core/datatree.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,6 @@ class DataTree(
423423
def __init__(
424424
self,
425425
data: Dataset | DataArray | None = None,
426-
parent: DataTree | None = None,
427426
children: Mapping[str, DataTree] | None = None,
428427
name: str | None = None,
429428
):
@@ -439,8 +438,6 @@ def __init__(
439438
data : Dataset, DataArray, or None, optional
440439
Data to store under the .ds attribute of this node. DataArrays will
441440
be promoted to Datasets. Default is None.
442-
parent : DataTree, optional
443-
Parent node to this node. Default is None.
444441
children : Mapping[str, DataTree], optional
445442
Any child nodes of this node. Default is None.
446443
name : str, optional
@@ -459,8 +456,9 @@ def __init__(
459456

460457
super().__init__(name=name)
461458
self._set_node_data(_coerce_to_dataset(data))
462-
self.parent = parent
463-
self.children = children
459+
460+
# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
461+
self.children = {name: child.copy() for name, child in children.items()}
464462

465463
def _set_node_data(self, ds: Dataset):
466464
data_vars, coord_vars = _collect_data_and_coord_variables(ds)
@@ -497,17 +495,6 @@ def _dims(self) -> ChainMap[Hashable, int]:
497495
def _indexes(self) -> ChainMap[Hashable, Index]:
498496
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))
499497

500-
@property
501-
def parent(self: DataTree) -> DataTree | None:
502-
"""Parent of this node."""
503-
return self._parent
504-
505-
@parent.setter
506-
def parent(self: DataTree, new_parent: DataTree) -> None:
507-
if new_parent and self.name is None:
508-
raise ValueError("Cannot set an unnamed node as a child of another node")
509-
self._set_parent(new_parent, self.name)
510-
511498
def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView:
512499
variables = dict(self._data_variables)
513500
variables |= self._coord_variables
@@ -896,7 +883,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
896883
# create and assign a shallow copy here so as not to alter original name of node in grafted tree
897884
new_node = val.copy(deep=False)
898885
new_node.name = key
899-
new_node.parent = self
886+
new_node._set_parent(new_parent=self, child_name=key)
900887
else:
901888
if not isinstance(val, DataArray | Variable):
902889
# accommodate other types that can be coerced into Variables
@@ -1097,7 +1084,7 @@ def from_dict(
10971084
obj = root_data.copy()
10981085
obj.orphan()
10991086
else:
1100-
obj = cls(name=name, data=root_data, parent=None, children=None)
1087+
obj = cls(name=name, data=root_data, children=None)
11011088

11021089
def depth(item) -> int:
11031090
pathstr, _ = item

xarray/core/datatree_render.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,16 @@ def __init__(self):
5151
5252
>>> from xarray.core.datatree import DataTree
5353
>>> from xarray.core.datatree_render import RenderDataTree
54-
>>> root = DataTree(name="root")
55-
>>> s0 = DataTree(name="sub0", parent=root)
56-
>>> s0b = DataTree(name="sub0B", parent=s0)
57-
>>> s0a = DataTree(name="sub0A", parent=s0)
58-
>>> s1 = DataTree(name="sub1", parent=root)
54+
>>> root = DataTree.from_dict(
55+
... {
56+
... "/": None,
57+
... "/sub0": None,
58+
... "/sub0/sub0B": None,
59+
... "/sub0/sub0A": None,
60+
... "/sub1": None,
61+
... },
62+
... name="root",
63+
... )
5964
>>> print(RenderDataTree(root))
6065
<xarray.DataTree 'root'>
6166
Group: /
@@ -98,11 +103,16 @@ def __init__(
98103
>>> from xarray import Dataset
99104
>>> from xarray.core.datatree import DataTree
100105
>>> from xarray.core.datatree_render import RenderDataTree
101-
>>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1}))
102-
>>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3}))
103-
>>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4}))
104-
>>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6}))
105-
>>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7}))
106+
>>> root = DataTree.from_dict(
107+
... {
108+
... "/": Dataset({"a": 0, "b": 1}),
109+
... "/sub0": Dataset({"c": 2, "d": 3}),
110+
... "/sub0/sub0B": Dataset({"e": 4}),
111+
... "/sub0/sub0A": Dataset({"f": 5, "g": 6}),
112+
... "/sub1": Dataset({"h": 7}),
113+
... },
114+
... name="root",
115+
... )
106116
107117
# Simple one line:
108118
@@ -208,17 +218,16 @@ def by_attr(self, attrname: str = "name") -> str:
208218
>>> from xarray import Dataset
209219
>>> from xarray.core.datatree import DataTree
210220
>>> from xarray.core.datatree_render import RenderDataTree
211-
>>> root = DataTree(name="root")
212-
>>> s0 = DataTree(name="sub0", parent=root)
213-
>>> s0b = DataTree(
214-
... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109})
221+
>>> root = DataTree.from_dict(
222+
... {
223+
... "/sub0/sub0B": Dataset({"foo": 4, "bar": 109}),
224+
... "/sub0/sub0A": None,
225+
... "/sub1/sub1A": None,
226+
... "/sub1/sub1B": Dataset({"bar": 8}),
227+
... "/sub1/sub1C/sub1Ca": None,
228+
... },
229+
... name="root",
215230
... )
216-
>>> s0a = DataTree(name="sub0A", parent=s0)
217-
>>> s1 = DataTree(name="sub1", parent=root)
218-
>>> s1a = DataTree(name="sub1A", parent=s1)
219-
>>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8}))
220-
>>> s1c = DataTree(name="sub1C", parent=s1)
221-
>>> s1ca = DataTree(name="sub1Ca", parent=s1c)
222231
>>> print(RenderDataTree(root).by_attr("name"))
223232
root
224233
├── sub0

xarray/core/iterators.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,9 @@ class LevelOrderIter(Iterator):
2828
--------
2929
>>> from xarray.core.datatree import DataTree
3030
>>> from xarray.core.iterators import LevelOrderIter
31-
>>> f = DataTree(name="f")
32-
>>> b = DataTree(name="b", parent=f)
33-
>>> a = DataTree(name="a", parent=b)
34-
>>> d = DataTree(name="d", parent=b)
35-
>>> c = DataTree(name="c", parent=d)
36-
>>> e = DataTree(name="e", parent=d)
37-
>>> g = DataTree(name="g", parent=f)
38-
>>> i = DataTree(name="i", parent=g)
39-
>>> h = DataTree(name="h", parent=i)
31+
>>> f = DataTree.from_dict(
32+
... {"/b/a": None, "/b/d/c": None, "/b/d/e": None, "/g/h/i": None}, name="f"
33+
... )
4034
>>> print(f)
4135
<xarray.DataTree 'f'>
4236
Group: /
@@ -46,19 +40,19 @@ class LevelOrderIter(Iterator):
4640
│ ├── Group: /b/d/c
4741
│ └── Group: /b/d/e
4842
└── Group: /g
49-
└── Group: /g/i
50-
└── Group: /g/i/h
43+
└── Group: /g/h
44+
└── Group: /g/h/i
5145
>>> [node.name for node in LevelOrderIter(f)]
52-
['f', 'b', 'g', 'a', 'd', 'i', 'c', 'e', 'h']
46+
['f', 'b', 'g', 'a', 'd', 'h', 'c', 'e', 'i']
5347
>>> [node.name for node in LevelOrderIter(f, maxlevel=3)]
54-
['f', 'b', 'g', 'a', 'd', 'i']
48+
['f', 'b', 'g', 'a', 'd', 'h']
5549
>>> [
5650
... node.name
5751
... for node in LevelOrderIter(f, filter_=lambda n: n.name not in ("e", "g"))
5852
... ]
59-
['f', 'b', 'a', 'd', 'i', 'c', 'h']
53+
['f', 'b', 'a', 'd', 'h', 'c', 'i']
6054
>>> [node.name for node in LevelOrderIter(f, stop=lambda n: n.name == "d")]
61-
['f', 'b', 'g', 'a', 'i', 'h']
55+
['f', 'b', 'g', 'a', 'h', 'i']
6256
"""
6357

6458
def __init__(

xarray/core/treenode.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def parent(self) -> Tree | None:
8686
"""Parent of this node."""
8787
return self._parent
8888

89+
@parent.setter
90+
def parent(self: Tree, new_parent: Tree) -> None:
91+
raise AttributeError(
92+
"Cannot set parent attribute directly, you must modify the children of the other node instead using dict-like syntax"
93+
)
94+
8995
def _set_parent(
9096
self, new_parent: Tree | None, child_name: str | None = None
9197
) -> None:

xarray/tests/conftest.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,17 @@ def _create_test_datatree(modify=lambda ds: ds):
196196
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
197197
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
198198

199-
# Avoid using __init__ so we can independently test it
200-
root: DataTree = DataTree(data=root_data)
201-
set1: DataTree = DataTree(name="set1", parent=root, data=set1_data)
202-
DataTree(name="set1", parent=set1)
203-
DataTree(name="set2", parent=set1)
204-
set2: DataTree = DataTree(name="set2", parent=root, data=set2_data)
205-
DataTree(name="set1", parent=set2)
206-
DataTree(name="set3", parent=root)
199+
root = DataTree.from_dict(
200+
{
201+
"/": root_data,
202+
"/set1": set1_data,
203+
"/set1/set1": None,
204+
"/set1/set2": None,
205+
"/set2": set2_data,
206+
"/set2/set1": None,
207+
"/set3": None,
208+
}
209+
)
207210

208211
return root
209212

0 commit comments

Comments
 (0)