Skip to content

Commit b22d24b

Browse files
Prune data tree for Isomorphic operations (#10097)
* prune data tree function with tests * update api.rst with prune * refactor prune to filter_like and use self.filter in filter_like * Added filter_like to whats-new
1 parent 282235f commit b22d24b

File tree

4 files changed

+80
-0
lines changed

4 files changed

+80
-0
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure.
697697
DataTree.pipe
698698
DataTree.match
699699
DataTree.filter
700+
DataTree.filter_like
700701

701702
Pathlib-like Interface
702703
----------------------

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ v2025.02.0 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24+
- Added :py:meth:`DataTree.filter_like` to conveniently filter a DataTree down to the nodes that are also present in another DataTree (:issue:`10096`, :pull:`10097`).
25+
By `Kobe Vandelanotte <https://github.com/kobebryant432>`_.
2426
- Added :py:meth:`Coordinates.from_xindex` as convenience for creating a new :py:class:`Coordinates` object
2527
directly from an existing Xarray index object if the latter supports it (:pull:`10000`)
2628
By `Benoit Bovy <https://github.com/benbovy>`_.

xarray/core/datatree.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,54 @@ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
13941394
}
13951395
return DataTree.from_dict(filtered_nodes, name=self.name)
13961396

1397+
def filter_like(self, other: DataTree) -> DataTree:
    """
    Keep only the nodes of this tree that also exist in another tree.

    A node is retained when its path relative to this tree appears among
    the keys yielded by ``other.subtree_with_keys``. Only the paths of
    ``other`` are consulted; its nodes themselves are not inspected.

    Parameters
    ----------
    other : DataTree
        The tree whose node paths decide which nodes of this tree are kept.

    Returns
    -------
    DataTree
        The result of calling ``self.filter`` with the path-membership test.

    See Also
    --------
    filter
    isomorphic

    Examples
    --------

    >>> dt = DataTree.from_dict(
    ...     {
    ...         "/a/A": None,
    ...         "/a/B": None,
    ...         "/b/A": None,
    ...         "/b/B": None,
    ...     }
    ... )
    >>> other = DataTree.from_dict(
    ...     {
    ...         "/a/A": None,
    ...         "/b/A": None,
    ...     }
    ... )
    >>> dt.filter_like(other)
    <xarray.DataTree>
    Group: /
    ├── Group: /a
    │   └── Group: /a/A
    └── Group: /b
        └── Group: /b/A
    """
    # Collect every path present in ``other`` once, so that each membership
    # test performed during filtering is a constant-time set lookup.
    paths_in_other = set()
    for path, _ in other.subtree_with_keys:
        paths_in_other.add(path)

    def _also_in_other(node):
        # Paths are compared relative to this tree's root.
        return node.relative_to(self) in paths_in_other

    return self.filter(_also_in_other)
1444+
13971445
def match(self, pattern: str) -> DataTree:
13981446
"""
13991447
Return nodes with paths matching pattern.

xarray/tests/test_datatree.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,6 +1587,35 @@ def test_assign(self) -> None:
15871587
result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()})
15881588
assert_equal(result, expected)
15891589

1590+
def test_filter_like(self) -> None:
    """Exercise ``DataTree.filter_like`` on flat and nested trees."""
    flower_tree = DataTree.from_dict(
        {"root": None, "trunk": None, "leaves": None, "flowers": None}
    )
    fruit_tree = DataTree.from_dict(
        {"root": None, "trunk": None, "leaves": None, "fruit": None}
    )
    barren_tree = DataTree.from_dict({"root": None, "trunk": None})

    # Filtering against a smaller tree leaves exactly that smaller tree.
    pruned = flower_tree.filter_like(barren_tree)
    assert pruned.equals(barren_tree)
    assert "flowers" not in pruned.children

    # Symmetrical pruning of two trees yields isomorphic results.
    flowers_like_fruit = flower_tree.filter_like(fruit_tree)
    fruit_like_flowers = fruit_tree.filter_like(flower_tree)
    assert flowers_like_fruit.isomorphic(fruit_like_flowers)

    # "Deep" pruning: nodes below the first level are filtered as well.
    nested = DataTree.from_dict(
        {"/a/A": None, "/a/B": None, "/b/A": None, "/b/B": None}
    )
    template = DataTree.from_dict({"/a/A": None, "/b/A": None})
    nested_pruned = nested.filter_like(template)
    assert nested_pruned.equals(template)
1618+
15901619

15911620
class TestPipe:
15921621
def test_noop(self, create_test_datatree: Callable[[], DataTree]) -> None:

0 commit comments

Comments
 (0)