Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Commit ff86d3c

Browse files
Alternative fix for #188 (#268)
* alternative fix for 188 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 86e9f59 commit ff86d3c

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

datatree/datatree.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
HybridMappingProxy,
3737
_default,
3838
either_dict_or_kwargs,
39+
maybe_wrap_array,
3940
)
4041
from xarray.core.variable import Variable
4142

@@ -235,6 +236,65 @@ def _replace(
235236
inplace=inplace,
236237
)
237238

239+
def map(
240+
self,
241+
func: Callable,
242+
keep_attrs: bool | None = None,
243+
args: Iterable[Any] = (),
244+
**kwargs: Any,
245+
) -> Dataset:
246+
"""Apply a function to each data variable in this dataset
247+
248+
Parameters
249+
----------
250+
func : callable
251+
Function which can be called in the form `func(x, *args, **kwargs)`
252+
to transform each DataArray `x` in this dataset into another
253+
DataArray.
254+
keep_attrs : bool or None, optional
255+
If True, both the dataset's and variables' attributes (`attrs`) will be
256+
copied from the original objects to the new ones. If False, the new dataset
257+
and variables will be returned without copying the attributes.
258+
args : iterable, optional
259+
Positional arguments passed on to `func`.
260+
**kwargs : Any
261+
Keyword arguments passed on to `func`.
262+
263+
Returns
264+
-------
265+
applied : Dataset
266+
Resulting dataset from applying ``func`` to each data variable.
267+
268+
Examples
269+
--------
270+
>>> da = xr.DataArray(np.random.randn(2, 3))
271+
>>> ds = xr.Dataset({"foo": da, "bar": ("x", [-1, 2])})
272+
>>> ds
273+
<xarray.Dataset>
274+
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
275+
Dimensions without coordinates: dim_0, dim_1, x
276+
Data variables:
277+
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773
278+
bar (x) int64 -1 2
279+
>>> ds.map(np.fabs)
280+
<xarray.Dataset>
281+
Dimensions: (dim_0: 2, dim_1: 3, x: 2)
282+
Dimensions without coordinates: dim_0, dim_1, x
283+
Data variables:
284+
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
285+
bar (x) float64 1.0 2.0
286+
"""
287+
288+
# Copied from xarray.Dataset so as not to call type(self), which causes problems (see datatree GH188).
289+
# TODO Refactor xarray upstream to avoid needing to overwrite this.
290+
# TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated
291+
variables = {
292+
k: maybe_wrap_array(v, func(v, *args, **kwargs))
293+
for k, v in self.data_vars.items()
294+
}
295+
# return type(self)(variables, attrs=attrs)
296+
return Dataset(variables)
297+
238298

239299
class DataTree(
240300
NamedNode,

0 commit comments

Comments
 (0)