Skip to content

Commit b444438

Browse files
authored
Adapt map_blocks to use new Coordinates API (pydata#8560)
* Adapt map_blocks to use new Coordinates API * cleanup * typing fixes * optimize * small cleanups * Typing fixes
1 parent b3890a3 commit b444438

File tree

5 files changed

+79
-35
lines changed

5 files changed

+79
-35
lines changed

xarray/core/coordinates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates):
213213
:py:class:`~xarray.Coordinates` object is passed, its indexes
214214
will be added to the newly created object.
215215
indexes: dict-like, optional
216-
Mapping of where keys are coordinate names and values are
216+
Mapping where keys are coordinate names and values are
217217
:py:class:`~xarray.indexes.Index` objects. If None (default),
218218
pandas indexes will be created for each dimension coordinate.
219219
Passing an empty dictionary will skip this default behavior.

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
try:
8181
from dask.dataframe import DataFrame as DaskDataFrame
8282
except ImportError:
83-
DaskDataFrame = None # type: ignore
83+
DaskDataFrame = None
8484
try:
8585
from dask.delayed import Delayed
8686
except ImportError:

xarray/core/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@
171171
try:
172172
from dask.dataframe import DataFrame as DaskDataFrame
173173
except ImportError:
174-
DaskDataFrame = None # type: ignore
174+
DaskDataFrame = None
175175

176176

177177
# list of attributes of pd.DatetimeIndex that are ndarrays of time info

xarray/core/parallel.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,29 @@
44
import itertools
55
import operator
66
from collections.abc import Hashable, Iterable, Mapping, Sequence
7-
from typing import TYPE_CHECKING, Any, Callable
7+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
88

99
import numpy as np
1010

1111
from xarray.core.alignment import align
12+
from xarray.core.coordinates import Coordinates
1213
from xarray.core.dataarray import DataArray
1314
from xarray.core.dataset import Dataset
15+
from xarray.core.indexes import Index
16+
from xarray.core.merge import merge
1417
from xarray.core.pycompat import is_dask_collection
1518

1619
if TYPE_CHECKING:
1720
from xarray.core.types import T_Xarray
1821

1922

23+
class ExpectedDict(TypedDict):
24+
shapes: dict[Hashable, int]
25+
coords: set[Hashable]
26+
data_vars: set[Hashable]
27+
indexes: dict[Hashable, Index]
28+
29+
2030
def unzip(iterable):
2131
return zip(*iterable)
2232

@@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):
3141

3242

3343
def check_result_variables(
34-
result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
44+
result: DataArray | Dataset,
45+
expected: ExpectedDict,
46+
kind: Literal["coords", "data_vars"],
3547
):
3648
if kind == "coords":
3749
nice_str = "coordinate"
@@ -254,7 +266,7 @@ def _wrapper(
254266
args: list,
255267
kwargs: dict,
256268
arg_is_array: Iterable[bool],
257-
expected: dict,
269+
expected: ExpectedDict,
258270
):
259271
"""
260272
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -345,33 +357,45 @@ def _wrapper(
345357
for arg in aligned
346358
)
347359

360+
merged_coordinates = merge([arg.coords for arg in aligned]).coords
361+
348362
_, npargs = unzip(
349363
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
350364
)
351365

352366
# check that chunk sizes are compatible
353367
input_chunks = dict(npargs[0].chunks)
354-
input_indexes = dict(npargs[0]._indexes)
355368
for arg in xarray_objs[1:]:
356369
assert_chunks_compatible(npargs[0], arg)
357370
input_chunks.update(arg.chunks)
358-
input_indexes.update(arg._indexes)
359371

372+
coordinates: Coordinates
360373
if template is None:
361374
# infer template by providing zero-shaped arrays
362375
template = infer_template(func, aligned[0], *args, **kwargs)
363-
template_indexes = set(template._indexes)
364-
preserved_indexes = template_indexes & set(input_indexes)
365-
new_indexes = template_indexes - set(input_indexes)
366-
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
367-
indexes.update({k: template._indexes[k] for k in new_indexes})
376+
template_coords = set(template.coords)
377+
preserved_coord_vars = template_coords & set(merged_coordinates)
378+
new_coord_vars = template_coords - set(merged_coordinates)
379+
380+
preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
381+
# preserved_coords contains all coordinate variables that share a dimension
382+
# with any index variable in preserved_indexes
383+
# Drop any unneeded vars in a second pass, this is required for e.g.
384+
# if the mapped function were to drop a non-dimension coordinate variable.
385+
preserved_coords = preserved_coords.drop_vars(
386+
tuple(k for k in preserved_coords.variables if k not in template_coords)
387+
)
388+
389+
coordinates = merge(
390+
(preserved_coords, template.coords.to_dataset()[new_coord_vars])
391+
).coords
368392
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
369393
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
370394
}
371395

372396
else:
373397
# template xarray object has been provided with proper sizes and chunk shapes
374-
indexes = dict(template._indexes)
398+
coordinates = template.coords
375399
output_chunks = template.chunksizes
376400
if not output_chunks:
377401
raise ValueError(
@@ -473,6 +497,9 @@ def subset_dataset_to_block(
473497

474498
return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
475499

500+
# variable names that depend on the computation. Currently, indexes
501+
# cannot be modified in the mapped function, so we exclude those
502+
computed_variables = set(template.variables) - set(coordinates.xindexes)
476503
# iterate over all possible chunk combinations
477504
for chunk_tuple in itertools.product(*ichunk.values()):
478505
# mapping from dimension name to chunk index
@@ -485,29 +512,32 @@ def subset_dataset_to_block(
485512
for isxr, arg in zip(is_xarray, npargs)
486513
]
487514

488-
# expected["shapes", "coords", "data_vars", "indexes"] are used to
489515
# raise nice error messages in _wrapper
490-
expected = {}
491-
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
492-
# even if length of dimension is changed by the applied function
493-
expected["shapes"] = {
494-
k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
495-
}
496-
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
497-
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
498-
expected["indexes"] = {
499-
dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
500-
for dim in indexes
516+
expected: ExpectedDict = {
517+
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
518+
# even if length of dimension is changed by the applied function
519+
"shapes": {
520+
k: output_chunks[k][v]
521+
for k, v in chunk_index.items()
522+
if k in output_chunks
523+
},
524+
"data_vars": set(template.data_vars.keys()),
525+
"coords": set(template.coords.keys()),
526+
"indexes": {
527+
dim: coordinates.xindexes[dim][
528+
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
529+
]
530+
for dim in coordinates.xindexes
531+
},
501532
}
502533

503534
from_wrapper = (gname,) + chunk_tuple
504535
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
505536

506537
# mapping from variable name to dask graph key
507538
var_key_map: dict[Hashable, str] = {}
508-
for name, variable in template.variables.items():
509-
if name in indexes:
510-
continue
539+
for name in computed_variables:
540+
variable = template.variables[name]
511541
gname_l = f"{name}-{gname}"
512542
var_key_map[name] = gname_l
513543

@@ -543,12 +573,7 @@ def subset_dataset_to_block(
543573
},
544574
)
545575

546-
# TODO: benbovy - flexible indexes: make it work with custom indexes
547-
# this will need to pass both indexes and coords to the Dataset constructor
548-
result = Dataset(
549-
coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
550-
attrs=template.attrs,
551-
)
576+
result = Dataset(coords=coordinates, attrs=template.attrs)
552577

553578
for index in result._indexes:
554579
result[index].attrs = template[index].attrs

xarray/tests/test_dask.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj):
13671367
assert_identical(actual, template)
13681368

13691369

1370+
def test_map_blocks_roundtrip_string_index():
1371+
ds = xr.Dataset(
1372+
{"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]}
1373+
).chunk(label=1)
1374+
assert ds.label.dtype == np.dtype("<U3")
1375+
1376+
mapped = ds.map_blocks(lambda x: x, template=ds)
1377+
assert mapped.label.dtype == ds.label.dtype
1378+
1379+
mapped = ds.map_blocks(lambda x: x, template=None)
1380+
assert mapped.label.dtype == ds.label.dtype
1381+
1382+
mapped = ds.data.map_blocks(lambda x: x, template=ds.data)
1383+
assert mapped.label.dtype == ds.label.dtype
1384+
1385+
mapped = ds.data.map_blocks(lambda x: x, template=None)
1386+
assert mapped.label.dtype == ds.label.dtype
1387+
1388+
13701389
def test_map_blocks_template_convert_object():
13711390
da = make_da()
13721391
func = lambda x: x.to_dataset().isel(x=[1])

0 commit comments

Comments
 (0)