@@ -270,6 +270,10 @@ def _wrapper(
270
270
271
271
result = func (* converted_args , ** kwargs )
272
272
273
+ merged_coordinates = merge (
274
+ [arg .coords for arg in args if isinstance (arg , (Dataset , DataArray ))]
275
+ ).coords
276
+
273
277
# check all dims are present
274
278
missing_dimensions = set (expected ["shapes" ]) - set (result .sizes )
275
279
if missing_dimensions :
@@ -285,12 +289,15 @@ def _wrapper(
285
289
f"Received dimension { name !r} of length { result .sizes [name ]} . "
286
290
f"Expected length { expected ['shapes' ][name ]} ."
287
291
)
288
- if name in expected ["indexes" ]:
289
- expected_index = expected ["indexes" ][name ]
290
- if not index .equals (expected_index ):
291
- raise ValueError (
292
- f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
293
- )
292
+
293
+ merged_indexes = collections .ChainMap (
294
+ expected ["indexes" ], merged_coordinates .xindexes
295
+ )
296
+ expected_index = merged_indexes .get (name , None )
297
+ if expected_index is not None and not index .equals (expected_index ):
298
+ raise ValueError (
299
+ f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
300
+ )
294
301
295
302
# check that all expected variables were returned
296
303
check_result_variables (result , expected , "coords" )
@@ -364,11 +371,11 @@ def _wrapper(
364
371
# infer template by providing zero-shaped arrays
365
372
template = infer_template (func , aligned [0 ], * args , ** kwargs )
366
373
template_coords = set (template .coords )
367
- preserved_indexes = template_coords & set (merged_coordinates )
368
- new_indexes = template_coords - set (merged_coordinates )
374
+ preserved_coord_names = template_coords & set (merged_coordinates )
375
+ new_indexes = set ( template . xindexes ) - set (merged_coordinates )
369
376
370
- preserved_coords = merged_coordinates .to_dataset ()[preserved_indexes ]
371
- # preserved_coords contains all coordinates bariables that share a dimension
377
+ preserved_coords = merged_coordinates .to_dataset ()[preserved_coord_names ]
378
+ # preserved_coords contains all coordinate variables that share a dimension
372
379
# with any index variable in preserved_indexes
373
380
# Drop any unneeded vars in a second pass, this is required for e.g.
374
381
# if the mapped function were to drop a non-dimension coordinate variable.
@@ -393,6 +400,13 @@ def _wrapper(
393
400
" Please construct a template with appropriately chunked dask arrays."
394
401
)
395
402
403
+ new_indexes = set (template .xindexes ) - set (merged_coordinates )
404
+ modified_indexes = set (
405
+ name
406
+ for name , xindex in coordinates .xindexes .items ()
407
+ if not xindex .equals (merged_coordinates .xindexes .get (name , None ))
408
+ )
409
+
396
410
for dim in output_chunks :
397
411
if dim in input_chunks and len (input_chunks [dim ]) != len (output_chunks [dim ]):
398
412
raise ValueError (
@@ -521,9 +535,14 @@ def subset_dataset_to_block(
521
535
}
522
536
expected ["data_vars" ] = set (template .data_vars .keys ()) # type: ignore[assignment]
523
537
expected ["coords" ] = set (template .coords .keys ()) # type: ignore[assignment]
538
+
539
+ # Minimize duplication due to broadcasting by only including any new or modified indexes
540
+ # Others can be inferred by inputs to wrapper (GH8412)
524
541
expected ["indexes" ] = {
525
- dim : index [_get_chunk_slicer (dim , chunk_index , output_chunk_bounds )]
526
- for dim , index in coordinates .xindexes .items ()
542
+ name : coordinates .xindexes [name ][
543
+ _get_chunk_slicer (name , chunk_index , output_chunk_bounds )
544
+ ]
545
+ for name in (new_indexes | modified_indexes )
527
546
}
528
547
529
548
from_wrapper = (gname ,) + chunk_tuple
0 commit comments