From 0104039988140ad5e79502acf77b682516e9dad2 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 10 Feb 2025 18:33:18 +0100 Subject: [PATCH 01/11] map_over_datasets: skip empty nodes --- xarray/core/datatree_mapping.py | 53 ++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 6262c7f19cd..8bcd3941e3b 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -107,21 +107,35 @@ def map_over_datasets( # We don't know which arguments are DataTrees so we zip all arguments together as iterables # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {} + func_called: dict[str, bool] = {} tree_args = [arg for arg in args if isinstance(arg, DataTree)] name = result_name(tree_args) for path, node_tree_args in group_subtrees(*tree_args): - node_dataset_args = [arg.dataset for arg in node_tree_args] - for i, arg in enumerate(args): - if not isinstance(arg, DataTree): - node_dataset_args.insert(i, arg) + if node_tree_args[0].has_data: + node_dataset_args = [arg.dataset for arg in node_tree_args] + for i, arg in enumerate(args): + if not isinstance(arg, DataTree): + node_dataset_args.insert(i, arg) + + func_with_error_context = _handle_errors_with_path_context(path)(func) + results = func_with_error_context(*node_dataset_args, **kwargs) + func_called[path] = True + + elif node_tree_args[0].has_attrs: + # propagate attrs + results = node_tree_args[0].dataset + func_called[path] = False + + else: + # use Dataset instead of None so it has copy method + results = Dataset() + func_called[path] = False - func_with_error_context = _handle_errors_with_path_context(path)(func) - results = func_with_error_context(*node_dataset_args, **kwargs) out_data_objects[path] = results - num_return_values = _check_all_return_values(out_data_objects) + num_return_values = _check_all_return_values(out_data_objects, func_called) if num_return_values is None: # one return value @@ -134,6 +148,10 @@ def map_over_datasets( {} for _ in range(num_return_values) ] for path, outputs in out_data_tuples.items(): + # duplicate outputs when func was not called (empty nodes) + if not func_called[path]: + outputs = tuple(outputs.copy() for _ in range(num_return_values)) + for output_dict, output in zip(output_dicts, outputs, strict=False): output_dict[path] = output @@ -185,17 +203,30 @@ def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None: return len(obj) -def _check_all_return_values(returned_objects) -> int | None: +def _check_all_return_values(returned_objects, func_called) -> int | None: """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" result_data_objects = list(returned_objects.items()) - first_path, result = result_data_objects[0] - return_values = _check_single_set_return_values(first_path, result) + func_called_before = False + + # initialize to None if all nodes are empty + return_values = None - for path_to_node, obj in result_data_objects[1:]: + for path_to_node, obj in result_data_objects: cur_return_values = _check_single_set_return_values(path_to_node, obj) + cur_func_called = func_called[path_to_node] + + # the first node where func was actually called + if cur_func_called and not func_called_before: + return_values = cur_return_values + func_called_before = True + first_path = path_to_node + + if not cur_func_called: + continue + if return_values != cur_return_values: if return_values is None: raise TypeError( From a23fb44189866bd600d566b799b0f0b4c0a83267 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 10 Feb 2025 19:33:40 +0100 Subject: [PATCH 02/11] fix typing --- xarray/core/datatree_mapping.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 8bcd3941e3b..69fbcd3c09d 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -150,7 +150,8 @@ def map_over_datasets( for path, outputs in out_data_tuples.items(): # duplicate outputs when func was not called (empty nodes) if not func_called[path]: - outputs = tuple(outputs.copy() for _ in range(num_return_values)) + out = cast(Dataset, outputs) + outputs = tuple(out.copy() for _ in range(num_return_values)) for output_dict, output in zip(output_dicts, outputs, strict=False): output_dict[path] = output From f787a769df43c69b886810ec0a06a56bc979175f Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 6 Mar 2025 16:55:04 +0100 Subject: [PATCH 03/11] changelog --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a3e30e58a2b..cceadfca540 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,9 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- Skip empty nodes in :py:func:`map_over_datasets`. This is a breaking change in xarray, but + restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). + By `Mathias Hauser `_. - Warn instead of raise if phony_dims are detected when using h5netcdf-backend and ``phony_dims=None`` (:issue:`10049`, :pull:`10058`) By `Kai Mühlbauer `_. From f2a924dc2b2761e4f66d641385bb41a74bd25d8f Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 6 Mar 2025 16:55:35 +0100 Subject: [PATCH 04/11] update docstring & comments --- xarray/core/datatree_mapping.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 69fbcd3c09d..93ebd2e88a6 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -47,15 +47,15 @@ def map_over_datasets( kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: """ - Applies a function to every dataset in one or more DataTree objects with - the same structure (ie.., that are isomorphic), returning new trees which + Applies a function to every non-empty dataset in one or more DataTree objects + with the same structure (i.e., that are isomorphic), returning new trees which store the results. - The function will be applied to any dataset stored in any of the nodes in - the trees. The returned trees will have the same structure as the supplied - trees. + The function will be applied to every node containing data (i.e., which has + ``data_vars`` and/ or ``coordinates``) in the trees. The returned tree(s) will + have the same structure as the supplied trees. - ``func`` needs to return a Dataset, tuple of Dataset objects or None in order + ``func`` needs to return a Dataset, tuple of Dataset objects or None to be able to rebuild the subtrees after mapping, as each result will be assigned to its respective node of a new tree via `DataTree.from_dict`. Any returned value that is one of these types will be stacked into a separate @@ -63,7 +63,7 @@ def map_over_datasets( ``map_over_datasets`` is essentially syntactic sugar for the combination of ``group_subtrees`` and ``DataTree.from_dict``. For example, in the case of - a two argument function that return one result, it is equivalent to:: + a two argument function that returns one result, it is equivalent to:: results = {} for path, (left, right) in group_subtrees(left_tree, right_tree): @@ -107,7 +107,7 @@ def map_over_datasets( # We don't know which arguments are DataTrees so we zip all arguments together as iterables # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {} - func_called: dict[str, bool] = {} + func_called: dict[str, bool] = {} # empty nodes don't call `func` tree_args = [arg for arg in args if isinstance(arg, DataTree)] name = result_name(tree_args) @@ -129,7 +129,7 @@ def map_over_datasets( func_called[path] = False else: - # use Dataset instead of None so it has copy method + # use Dataset instead of None to ensure it has copy method results = Dataset() func_called[path] = False @@ -219,7 +219,8 @@ def _check_all_return_values(returned_objects, func_called) -> int | None: cur_func_called = func_called[path_to_node] - # the first node where func was actually called + # the first node where func was actually called - needed to find the number of + # return values if cur_func_called and not func_called_before: return_values = cur_return_values func_called_before = True From 82de573c7e73bb19f915a02cf4ce85fc99cee175 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 6 Mar 2025 18:09:50 +0100 Subject: [PATCH 05/11] more comments --- xarray/core/datatree_mapping.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 93ebd2e88a6..a1ae531338b 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -226,6 +226,7 @@ def _check_all_return_values(returned_objects, func_called) -> int | None: func_called_before = True first_path = path_to_node + # no need to check if the function was not called if not cur_func_called: continue From 0e34067b25a4bbf73de2a4693338af9752e2d486 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 6 Mar 2025 18:12:14 +0100 Subject: [PATCH 06/11] tests --- xarray/tests/test_datatree_mapping.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 6cb4455b739..304023b54e7 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -64,6 +64,13 @@ def multiply_by_kwarg(ds, **kwargs): ) assert_equal(result_tree, expected) + def test_single_tree_skip_empty_nodes(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(lambda ds: ds.rename(a="c")) + # this would fail on empty nodes + result_tree = map_over_datasets(lambda ds: ds.rename(a="c"), dt) + assert_equal(result_tree, expected) + def test_multiple_tree_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() @@ -79,6 +86,17 @@ def test_return_multiple_trees(self, create_test_datatree): expected_max = create_test_datatree(modify=lambda ds: ds.max()) assert_equal(dt_max, expected_max) + def test_return_multiple_trees_empty_first_node(self): + # check result tree is constructed correctly even if first nodes are empty + ds = xr.Dataset(data_vars={"a": ("x", [1, 2, 3])}) + dt = xr.DataTree.from_dict({"set1": None, "set2": ds}) + res_min, res_max = xr.map_over_datasets(lambda ds: (ds.min(), ds.max()), dt) + assert_equal(res_min, dt.min()) + assert_equal(res_max, dt.max()) + + # ensure they are different objects + assert res_min["set1"].dataset is not res_max["set2"].dataset + def test_return_wrong_type(self, simple_datatree): dt1 = simple_datatree From bbfedf2c92eeda4f16ebffcf4f55b309f9d0e3b6 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 7 Mar 2025 04:55:17 +0100 Subject: [PATCH 07/11] remove unnecessary test --- xarray/tests/test_datatree_mapping.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 304023b54e7..c7b47d7a974 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -94,9 +94,6 @@ def test_return_multiple_trees_empty_first_node(self): assert_equal(res_min, dt.min()) assert_equal(res_max, dt.max()) - # ensure they are different objects - assert res_min["set1"].dataset is not res_max["set2"].dataset - def test_return_wrong_type(self, simple_datatree): dt1 = simple_datatree @@ -201,7 +198,6 @@ def empty_func(ds): def test_error_contains_path_of_offending_node(self, create_test_datatree): dt = create_test_datatree() dt["set1"]["bad_var"] = 0 - print(dt) def fail_on_specific_node(ds): if "bad_var" in ds: From 4ba1d27d27b75f96ecbd98fac433dfb84dab31a9 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 7 Mar 2025 05:05:32 +0100 Subject: [PATCH 08/11] add binary op test --- xarray/tests/test_datatree.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index c87a1e1329e..fb81d319aa2 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2084,6 +2084,24 @@ def test_binary_op_on_dataset(self) -> None: result = dt * other_ds assert_equal(result, expected) + def test_binary_op_on_dataset_skip_empty_nodes(self) -> None: + # https://github.com/pydata/xarray/issues/10013 + + dt = xr.DataTree() + + a = xr.Dataset(data_vars={"x": ("time", [10])}, coords={"time": [0]}) + b = xr.Dataset(data_vars={"x": ("time", [11, 22])}, coords={"time": [0, 1]}) + + dt = DataTree.from_dict({"a": a, "b": b}) + + expected = DataTree.from_dict({"a": a - b, "b": b - b}) + + # if the empty root node is not skipped its coordinates become inconsistent + # with the ones of node a + result = dt - b + + assert_equal(result, expected) + def test_binary_op_on_datatree(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) From 6e18bd7491585cf5db023addb659361b8a0c9298 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 7 Mar 2025 05:16:46 +0100 Subject: [PATCH 09/11] mention binary ops --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cceadfca540..f9999898e00 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,8 +35,8 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ -- Skip empty nodes in :py:func:`map_over_datasets`. This is a breaking change in xarray, but - restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). +- Skip empty nodes in :py:func:`map_over_datasets`. Also affects binary operations. + This is a breaking change in xarray, but restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). By `Mathias Hauser `_. - Warn instead of raise if phony_dims are detected when using h5netcdf-backend and ``phony_dims=None`` (:issue:`10049`, :pull:`10058`) By `Kai Mühlbauer `_. From 65181a4660b386b0dbd689abb2ec50866f4d55e6 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Fri, 7 Mar 2025 05:19:20 +0100 Subject: [PATCH 10/11] clean test --- xarray/tests/test_datatree.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index fb81d319aa2..0ca1e621bc9 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2087,8 +2087,6 @@ def test_binary_op_on_dataset(self) -> None: def test_binary_op_on_dataset_skip_empty_nodes(self) -> None: # https://github.com/pydata/xarray/issues/10013 - dt = xr.DataTree() - a = xr.Dataset(data_vars={"x": ("time", [10])}, coords={"time": [0]}) b = xr.Dataset(data_vars={"x": ("time", [11, 22])}, coords={"time": [0, 1]}) From d4c8bdca84e4d2c36241c8d4f04dace03f1fe4e0 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 12 Jun 2025 16:16:10 +0200 Subject: [PATCH 11/11] move changelog entry --- doc/whats-new.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3cecc3611be..b0e22c3eab5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -16,7 +16,9 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ - +- Skip empty nodes in :py:func:`map_over_datasets`. Also affects binary operations. + This is a breaking change in xarray, but restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). + By `Mathias Hauser `_. Deprecations ~~~~~~~~~~~~ @@ -310,9 +312,6 @@ Performance Breaking changes ~~~~~~~~~~~~~~~~ -- Skip empty nodes in :py:func:`map_over_datasets`. Also affects binary operations. - This is a breaking change in xarray, but restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). - By `Mathias Hauser `_. - Rolled back code that would attempt to catch integer overflow when encoding times with small integer dtypes (:issue:`8542`), since it was inconsistent with xarray's handling of standard integers, and interfered with encoding