@@ -25,7 +25,6 @@ class ExpectedDict(TypedDict):
25
25
shapes : dict [Hashable , int ]
26
26
coords : set [Hashable ]
27
27
data_vars : set [Hashable ]
28
- indexes : dict [Hashable , Index ]
29
28
30
29
31
30
def unzip (iterable ):
@@ -337,6 +336,7 @@ def _wrapper(
337
336
kwargs : dict ,
338
337
arg_is_array : Iterable [bool ],
339
338
expected : ExpectedDict ,
339
+ expected_indexes : dict [Hashable , Index ],
340
340
):
341
341
"""
342
342
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -372,7 +372,7 @@ def _wrapper(
372
372
373
373
# ChainMap wants MutableMapping, but xindexes is Mapping
374
374
merged_indexes = collections .ChainMap (
375
- expected [ "indexes" ] ,
375
+ expected_indexes ,
376
376
merged_coordinates .xindexes , # type: ignore[arg-type]
377
377
)
378
378
expected_index = merged_indexes .get (name , None )
@@ -412,6 +412,7 @@ def _wrapper(
412
412
try :
413
413
import dask
414
414
import dask .array
415
+ from dask .base import tokenize
415
416
from dask .highlevelgraph import HighLevelGraph
416
417
417
418
except ImportError :
@@ -551,6 +552,20 @@ def _wrapper(
551
552
for isxr , arg in zip (is_xarray , npargs , strict = True )
552
553
]
553
554
555
+ # only include new or modified indexes to minimize duplication of data
556
+ indexes = {
557
+ dim : coordinates .xindexes [dim ][
558
+ _get_chunk_slicer (dim , chunk_index , output_chunk_bounds )
559
+ ]
560
+ for dim in (new_indexes | modified_indexes )
561
+ }
562
+
563
+ tokenized_indexes : dict [Hashable , str ] = {}
564
+ for k , v in indexes .items ():
565
+ tokenized_v = tokenize (v )
566
+ graph [f"{ k } -coordinate-{ tokenized_v } " ] = v
567
+ tokenized_indexes [k ] = f"{ k } -coordinate-{ tokenized_v } "
568
+
554
569
# raise nice error messages in _wrapper
555
570
expected : ExpectedDict = {
556
571
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
@@ -562,17 +577,18 @@ def _wrapper(
562
577
},
563
578
"data_vars" : set (template .data_vars .keys ()),
564
579
"coords" : set (template .coords .keys ()),
565
- # only include new or modified indexes to minimize duplication of data, and graph size.
566
- "indexes" : {
567
- dim : coordinates .xindexes [dim ][
568
- _get_chunk_slicer (dim , chunk_index , output_chunk_bounds )
569
- ]
570
- for dim in (new_indexes | modified_indexes )
571
- },
572
580
}
573
581
574
582
from_wrapper = (gname ,) + chunk_tuple
575
- graph [from_wrapper ] = (_wrapper , func , blocked_args , kwargs , is_array , expected )
583
+ graph [from_wrapper ] = (
584
+ _wrapper ,
585
+ func ,
586
+ blocked_args ,
587
+ kwargs ,
588
+ is_array ,
589
+ expected ,
590
+ (dict , [[k , v ] for k , v in tokenized_indexes .items ()]),
591
+ )
576
592
577
593
# mapping from variable name to dask graph key
578
594
var_key_map : dict [Hashable , str ] = {}
0 commit comments