4
4
import itertools
5
5
import operator
6
6
from collections .abc import Hashable , Iterable , Mapping , Sequence
7
- from typing import TYPE_CHECKING , Any , Callable
7
+ from typing import TYPE_CHECKING , Any , Callable , Literal , TypedDict
8
8
9
9
import numpy as np
10
10
11
11
from xarray .core .alignment import align
12
+ from xarray .core .coordinates import Coordinates
12
13
from xarray .core .dataarray import DataArray
13
14
from xarray .core .dataset import Dataset
15
+ from xarray .core .indexes import Index
16
+ from xarray .core .merge import merge
14
17
from xarray .core .pycompat import is_dask_collection
15
18
16
19
if TYPE_CHECKING :
17
20
from xarray .core .types import T_Xarray
18
21
19
22
23
+ class ExpectedDict (TypedDict ):
24
+ shapes : dict [Hashable , int ]
25
+ coords : set [Hashable ]
26
+ data_vars : set [Hashable ]
27
+ indexes : dict [Hashable , Index ]
28
+
29
+
20
30
def unzip (iterable ):
21
31
return zip (* iterable )
22
32
@@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):
31
41
32
42
33
43
def check_result_variables (
34
- result : DataArray | Dataset , expected : Mapping [str , Any ], kind : str
44
+ result : DataArray | Dataset ,
45
+ expected : ExpectedDict ,
46
+ kind : Literal ["coords" , "data_vars" ],
35
47
):
36
48
if kind == "coords" :
37
49
nice_str = "coordinate"
@@ -254,7 +266,7 @@ def _wrapper(
254
266
args : list ,
255
267
kwargs : dict ,
256
268
arg_is_array : Iterable [bool ],
257
- expected : dict ,
269
+ expected : ExpectedDict ,
258
270
):
259
271
"""
260
272
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -345,33 +357,45 @@ def _wrapper(
345
357
for arg in aligned
346
358
)
347
359
360
+ merged_coordinates = merge ([arg .coords for arg in aligned ]).coords
361
+
348
362
_ , npargs = unzip (
349
363
sorted (list (zip (xarray_indices , xarray_objs )) + others , key = lambda x : x [0 ])
350
364
)
351
365
352
366
# check that chunk sizes are compatible
353
367
input_chunks = dict (npargs [0 ].chunks )
354
- input_indexes = dict (npargs [0 ]._indexes )
355
368
for arg in xarray_objs [1 :]:
356
369
assert_chunks_compatible (npargs [0 ], arg )
357
370
input_chunks .update (arg .chunks )
358
- input_indexes .update (arg ._indexes )
359
371
372
+ coordinates : Coordinates
360
373
if template is None :
361
374
# infer template by providing zero-shaped arrays
362
375
template = infer_template (func , aligned [0 ], * args , ** kwargs )
363
- template_indexes = set (template ._indexes )
364
- preserved_indexes = template_indexes & set (input_indexes )
365
- new_indexes = template_indexes - set (input_indexes )
366
- indexes = {dim : input_indexes [dim ] for dim in preserved_indexes }
367
- indexes .update ({k : template ._indexes [k ] for k in new_indexes })
376
+ template_coords = set (template .coords )
377
+ preserved_coord_vars = template_coords & set (merged_coordinates )
378
+ new_coord_vars = template_coords - set (merged_coordinates )
379
+
380
+ preserved_coords = merged_coordinates .to_dataset ()[preserved_coord_vars ]
381
+ # preserved_coords contains all coordinates bariables that share a dimension
382
+ # with any index variable in preserved_indexes
383
+ # Drop any unneeded vars in a second pass, this is required for e.g.
384
+ # if the mapped function were to drop a non-dimension coordinate variable.
385
+ preserved_coords = preserved_coords .drop_vars (
386
+ tuple (k for k in preserved_coords .variables if k not in template_coords )
387
+ )
388
+
389
+ coordinates = merge (
390
+ (preserved_coords , template .coords .to_dataset ()[new_coord_vars ])
391
+ ).coords
368
392
output_chunks : Mapping [Hashable , tuple [int , ...]] = {
369
393
dim : input_chunks [dim ] for dim in template .dims if dim in input_chunks
370
394
}
371
395
372
396
else :
373
397
# template xarray object has been provided with proper sizes and chunk shapes
374
- indexes = dict ( template ._indexes )
398
+ coordinates = template .coords
375
399
output_chunks = template .chunksizes
376
400
if not output_chunks :
377
401
raise ValueError (
@@ -473,6 +497,9 @@ def subset_dataset_to_block(
473
497
474
498
return (Dataset , (dict , data_vars ), (dict , coords ), dataset .attrs )
475
499
500
+ # variable names that depend on the computation. Currently, indexes
501
+ # cannot be modified in the mapped function, so we exclude thos
502
+ computed_variables = set (template .variables ) - set (coordinates .xindexes )
476
503
# iterate over all possible chunk combinations
477
504
for chunk_tuple in itertools .product (* ichunk .values ()):
478
505
# mapping from dimension name to chunk index
@@ -485,29 +512,32 @@ def subset_dataset_to_block(
485
512
for isxr , arg in zip (is_xarray , npargs )
486
513
]
487
514
488
- # expected["shapes", "coords", "data_vars", "indexes"] are used to
489
515
# raise nice error messages in _wrapper
490
- expected = {}
491
- # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
492
- # even if length of dimension is changed by the applied function
493
- expected ["shapes" ] = {
494
- k : output_chunks [k ][v ] for k , v in chunk_index .items () if k in output_chunks
495
- }
496
- expected ["data_vars" ] = set (template .data_vars .keys ()) # type: ignore[assignment]
497
- expected ["coords" ] = set (template .coords .keys ()) # type: ignore[assignment]
498
- expected ["indexes" ] = {
499
- dim : indexes [dim ][_get_chunk_slicer (dim , chunk_index , output_chunk_bounds )]
500
- for dim in indexes
516
+ expected : ExpectedDict = {
517
+ # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
518
+ # even if length of dimension is changed by the applied function
519
+ "shapes" : {
520
+ k : output_chunks [k ][v ]
521
+ for k , v in chunk_index .items ()
522
+ if k in output_chunks
523
+ },
524
+ "data_vars" : set (template .data_vars .keys ()),
525
+ "coords" : set (template .coords .keys ()),
526
+ "indexes" : {
527
+ dim : coordinates .xindexes [dim ][
528
+ _get_chunk_slicer (dim , chunk_index , output_chunk_bounds )
529
+ ]
530
+ for dim in coordinates .xindexes
531
+ },
501
532
}
502
533
503
534
from_wrapper = (gname ,) + chunk_tuple
504
535
graph [from_wrapper ] = (_wrapper , func , blocked_args , kwargs , is_array , expected )
505
536
506
537
# mapping from variable name to dask graph key
507
538
var_key_map : dict [Hashable , str ] = {}
508
- for name , variable in template .variables .items ():
509
- if name in indexes :
510
- continue
539
+ for name in computed_variables :
540
+ variable = template .variables [name ]
511
541
gname_l = f"{ name } -{ gname } "
512
542
var_key_map [name ] = gname_l
513
543
@@ -543,12 +573,7 @@ def subset_dataset_to_block(
543
573
},
544
574
)
545
575
546
- # TODO: benbovy - flexible indexes: make it work with custom indexes
547
- # this will need to pass both indexes and coords to the Dataset constructor
548
- result = Dataset (
549
- coords = {k : idx .to_pandas_index () for k , idx in indexes .items ()},
550
- attrs = template .attrs ,
551
- )
576
+ result = Dataset (coords = coordinates , attrs = template .attrs )
552
577
553
578
for index in result ._indexes :
554
579
result [index ].attrs = template [index ].attrs
0 commit comments