Skip to content

Commit 5632c8e

Browse files
phofldcherian
andauthored
Reduce graph size through writing indexes directly into graph for map_blocks (#9658)
* Reduce graph size through writing indexes directly into graph for map_blocks * Reduce graph size through writing indexes directly into graph for map_blocks * Update xarray/core/parallel.py --------- Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 863184d commit 5632c8e

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

xarray/core/parallel.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class ExpectedDict(TypedDict):
2525
shapes: dict[Hashable, int]
2626
coords: set[Hashable]
2727
data_vars: set[Hashable]
28-
indexes: dict[Hashable, Index]
2928

3029

3130
def unzip(iterable):
@@ -337,6 +336,7 @@ def _wrapper(
337336
kwargs: dict,
338337
arg_is_array: Iterable[bool],
339338
expected: ExpectedDict,
339+
expected_indexes: dict[Hashable, Index],
340340
):
341341
"""
342342
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -372,7 +372,7 @@ def _wrapper(
372372

373373
# ChainMap wants MutableMapping, but xindexes is Mapping
374374
merged_indexes = collections.ChainMap(
375-
expected["indexes"],
375+
expected_indexes,
376376
merged_coordinates.xindexes, # type: ignore[arg-type]
377377
)
378378
expected_index = merged_indexes.get(name, None)
@@ -412,6 +412,7 @@ def _wrapper(
412412
try:
413413
import dask
414414
import dask.array
415+
from dask.base import tokenize
415416
from dask.highlevelgraph import HighLevelGraph
416417

417418
except ImportError:
@@ -551,6 +552,20 @@ def _wrapper(
551552
for isxr, arg in zip(is_xarray, npargs, strict=True)
552553
]
553554

555+
# only include new or modified indexes to minimize duplication of data
556+
indexes = {
557+
dim: coordinates.xindexes[dim][
558+
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
559+
]
560+
for dim in (new_indexes | modified_indexes)
561+
}
562+
563+
tokenized_indexes: dict[Hashable, str] = {}
564+
for k, v in indexes.items():
565+
tokenized_v = tokenize(v)
566+
graph[f"{k}-coordinate-{tokenized_v}"] = v
567+
tokenized_indexes[k] = f"{k}-coordinate-{tokenized_v}"
568+
554569
# raise nice error messages in _wrapper
555570
expected: ExpectedDict = {
556571
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
@@ -562,17 +577,18 @@ def _wrapper(
562577
},
563578
"data_vars": set(template.data_vars.keys()),
564579
"coords": set(template.coords.keys()),
565-
# only include new or modified indexes to minimize duplication of data, and graph size.
566-
"indexes": {
567-
dim: coordinates.xindexes[dim][
568-
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
569-
]
570-
for dim in (new_indexes | modified_indexes)
571-
},
572580
}
573581

574582
from_wrapper = (gname,) + chunk_tuple
575-
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
583+
graph[from_wrapper] = (
584+
_wrapper,
585+
func,
586+
blocked_args,
587+
kwargs,
588+
is_array,
589+
expected,
590+
(dict, [[k, v] for k, v in tokenized_indexes.items()]),
591+
)
576592

577593
# mapping from variable name to dask graph key
578594
var_key_map: dict[Hashable, str] = {}

xarray/tests/test_dask.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from xarray import DataArray, Dataset, Variable
1515
from xarray.core import duck_array_ops
1616
from xarray.core.duck_array_ops import lazy_array_equiv
17+
from xarray.core.indexes import PandasIndex
1718
from xarray.testing import assert_chunks_equal
1819
from xarray.tests import (
1920
assert_allclose,
@@ -1375,6 +1376,13 @@ def test_map_blocks_da_ds_with_template(obj):
13751376
actual = xr.map_blocks(func, obj, template=template)
13761377
assert_identical(actual, template)
13771378

1379+
# Check that indexes are written into the graph directly
1380+
dsk = dict(actual.__dask_graph__())
1381+
assert len({k for k in dsk if "x-coordinate" in k})
1382+
assert all(
1383+
isinstance(v, PandasIndex) for k, v in dsk.items() if "x-coordinate" in k
1384+
)
1385+
13781386
with raise_if_dask_computes():
13791387
actual = obj.map_blocks(func, template=template)
13801388
assert_identical(actual, template)

0 commit comments

Comments
 (0)