Skip to content

Commit 857c783

Browse files
committed
convert apply_to_dataset to a top-level function
1 parent fd2c897 commit 857c783

File tree

6 files changed

+89
-34
lines changed

6 files changed

+89
-34
lines changed

xarray/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,15 @@
1818
from .core.alignment import align, broadcast
1919
from .core.combine import combine_by_coords, combine_nested
2020
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
21-
from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
21+
from .core.computation import (
22+
apply_to_dataset,
23+
apply_ufunc,
24+
corr,
25+
cov,
26+
dot,
27+
polyval,
28+
where,
29+
)
2230
from .core.concat import concat
2331
from .core.dataarray import DataArray
2432
from .core.dataset import Dataset
@@ -46,6 +54,7 @@
4654
# Top-level functions
4755
"align",
4856
"apply_ufunc",
57+
"apply_to_dataset",
4958
"as_variable",
5059
"broadcast",
5160
"cftime_range",

xarray/core/common.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -635,21 +635,6 @@ def pipe(
635635
else:
636636
return func(self, *args, **kwargs)
637637

638-
def apply_to_dataset(self, f):
639-
from .dataarray import DataArray
640-
641-
if isinstance(self, DataArray):
642-
ds = self._to_temp_dataset()
643-
else:
644-
ds = self
645-
646-
result = f(ds)
647-
648-
if isinstance(self, DataArray):
649-
return self._from_temp_dataset(result, name=self.name)
650-
else:
651-
return result
652-
653638
def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None):
654639
"""Returns a GroupBy object for performing grouped operations.
655640

xarray/core/computation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,42 @@ def earth_mover_distance(first_samples,
11421142
return apply_array_ufunc(func, *args, dask=dask)
11431143

11441144

1145+
def apply_to_dataset(func, obj, *args, **kwargs):
1146+
"""apply a function expecting a Dataset to a xarray object
1147+
1148+
Parameters
1149+
----------
1150+
func : callable
1151+
A function expecting a Dataset as its first parameter.
1152+
obj : DataArray or Dataset
1153+
The dataset to apply ``func`` to. If a ``DataArray``, convert it to a single
1154+
variable ``Dataset`` first.
1155+
*args, **kwargs
1156+
Additional arguments to ``func``
1157+
1158+
Returns
1159+
-------
1160+
DataArray or Dataset
1161+
The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``.
1162+
1163+
Notes
1164+
-----
1165+
If a ``DataArray``, result will have the same name as ``obj`` but the single data
1166+
variable in the temporary ``Dataset`` will always have a generic name.
1167+
"""
1168+
from .dataarray import DataArray
1169+
1170+
ds = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
1171+
1172+
result = func(ds, *args, **kwargs)
1173+
1174+
return (
1175+
obj._from_temp_dataset(result, name=obj.name)
1176+
if isinstance(obj, DataArray)
1177+
else result
1178+
)
1179+
1180+
11451181
def cov(da_a, da_b, dim=None, ddof=1):
11461182
"""
11471183
Compute covariance between two DataArray objects along a shared dimension.

xarray/tests/test_computation.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,49 @@ def test_apply_groupby_add():
468468
add(data_array.groupby("y"), data_array.groupby("x"))
469469

470470

471+
@pytest.mark.parametrize(
472+
["obj", "expected"],
473+
(
474+
pytest.param(
475+
xr.DataArray(
476+
[0, 1],
477+
coords={
478+
"x": ("x", [-1, 1], {"a": 1, "b": 2}),
479+
"u": ("x", [2, 3], {"c": 3}),
480+
},
481+
dims="x",
482+
attrs={"d": 4, "e": 5},
483+
),
484+
xr.DataArray([0, 1], coords={"x": [-1, 1], "u": ("x", [2, 3])}, dims="x"),
485+
id="DataArray",
486+
),
487+
pytest.param(
488+
xr.Dataset(
489+
{"a": ("x", [1, 2], {"a": 1, "b": 2}), "b": ("x", [0, 1], {"c": 3})},
490+
coords={
491+
"x": ("x", [-1, 1], {"d": 4, "e": 5}),
492+
"u": ("x", [2, 3], {"f": 6}),
493+
},
494+
),
495+
xr.Dataset(
496+
{"a": ("x", [1, 2]), "b": ("x", [0, 1])},
497+
coords={"x": [-1, 1], "u": ("x", [2, 3])},
498+
),
499+
id="Dataset",
500+
),
501+
),
502+
)
503+
def test_apply_to_dataset(obj, expected):
504+
def clear_all_attrs(ds):
505+
new_ds = ds.copy()
506+
for var in new_ds.variables.values():
507+
var.attrs.clear()
508+
new_ds.attrs.clear()
509+
return new_ds
510+
511+
assert_identical(expected, xr.apply_to_dataset(clear_all_attrs, obj))
512+
513+
471514
def test_unified_dim_sizes():
472515
assert unified_dim_sizes([xr.Variable((), 0)]) == {}
473516
assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1}

xarray/tests/test_dataarray.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,17 +2617,6 @@ def test_fillna(self):
26172617
actual = a.groupby("b").fillna(DataArray([0, 2], dims="b"))
26182618
assert_identical(expected, actual)
26192619

2620-
def test_apply_to_dataset(self):
2621-
def func(ds):
2622-
return Dataset(ds.data_vars, coords=ds.coords)
2623-
2624-
da = DataArray(
2625-
[[0, 1], [2, 3], [4, 5]],
2626-
dims=("x", "y"),
2627-
name="abc",
2628-
)
2629-
assert_identical(da, da.apply_to_dataset(func))
2630-
26312620
def test_groupby_iter(self):
26322621
for ((act_x, act_dv), (exp_x, exp_ds)) in zip(
26332622
self.dv.groupby("y"), self.ds.groupby("y")

xarray/tests/test_dataset.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5028,13 +5028,6 @@ def test_count(self):
50285028
actual = ds.count()
50295029
assert_identical(expected, actual)
50305030

5031-
def test_apply_to_dataset(self):
5032-
def func(ds):
5033-
return Dataset(ds.data_vars, coords=ds.coords)
5034-
5035-
ds = create_test_data()
5036-
assert_identical(ds, ds.apply_to_dataset(func))
5037-
50385031
def test_map(self):
50395032
data = create_test_data()
50405033
data.attrs["foo"] = "bar"

0 commit comments

Comments
 (0)