Skip to content

Commit 1334009

Browse files
committed
Trim some more.
1 parent a106569 commit 1334009

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

xarray/core/parallel.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ def _wrapper(
270270

271271
result = func(*converted_args, **kwargs)
272272

273+
merged_coordinates = merge(
274+
[arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))]
275+
).coords
276+
273277
# check all dims are present
274278
missing_dimensions = set(expected["shapes"]) - set(result.sizes)
275279
if missing_dimensions:
@@ -285,12 +289,15 @@ def _wrapper(
285289
f"Received dimension {name!r} of length {result.sizes[name]}. "
286290
f"Expected length {expected['shapes'][name]}."
287291
)
288-
if name in expected["indexes"]:
289-
expected_index = expected["indexes"][name]
290-
if not index.equals(expected_index):
291-
raise ValueError(
292-
f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
293-
)
292+
293+
merged_indexes = collections.ChainMap(
294+
expected["indexes"], merged_coordinates.xindexes
295+
)
296+
expected_index = merged_indexes.get(name, None)
297+
if expected_index is not None and not index.equals(expected_index):
298+
raise ValueError(
299+
f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
300+
)
294301

295302
# check that all expected variables were returned
296303
check_result_variables(result, expected, "coords")
@@ -364,11 +371,11 @@ def _wrapper(
364371
# infer template by providing zero-shaped arrays
365372
template = infer_template(func, aligned[0], *args, **kwargs)
366373
template_coords = set(template.coords)
367-
preserved_indexes = template_coords & set(merged_coordinates)
368-
new_indexes = template_coords - set(merged_coordinates)
374+
preserved_coord_names = template_coords & set(merged_coordinates)
375+
new_indexes = set(template.xindexes) - set(merged_coordinates)
369376

370-
preserved_coords = merged_coordinates.to_dataset()[preserved_indexes]
371-
# preserved_coords contains all coordinates bariables that share a dimension
377+
preserved_coords = merged_coordinates.to_dataset()[preserved_coord_names]
378+
# preserved_coords contains all coordinate variables that share a dimension
372379
# with any index variable in preserved_indexes
373380
# Drop any unneeded vars in a second pass, this is required for e.g.
374381
# if the mapped function were to drop a non-dimension coordinate variable.
@@ -393,6 +400,13 @@ def _wrapper(
393400
" Please construct a template with appropriately chunked dask arrays."
394401
)
395402

403+
new_indexes = set(template.xindexes) - set(merged_coordinates)
404+
modified_indexes = set(
405+
name
406+
for name, xindex in coordinates.xindexes.items()
407+
if not xindex.equals(merged_coordinates.xindexes.get(name, None))
408+
)
409+
396410
for dim in output_chunks:
397411
if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
398412
raise ValueError(
@@ -521,9 +535,14 @@ def subset_dataset_to_block(
521535
}
522536
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
523537
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
538+
539+
# Minimize duplication due to broadcasting by only including any new or modified indexes
540+
# Others can be inferred by inputs to wrapper (GH8412)
524541
expected["indexes"] = {
525-
dim: index[_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
526-
for dim, index in coordinates.xindexes.items()
542+
name: coordinates.xindexes[name][
543+
_get_chunk_slicer(name, chunk_index, output_chunk_bounds)
544+
]
545+
for name in (new_indexes | modified_indexes)
527546
}
528547

529548
from_wrapper = (gname,) + chunk_tuple

0 commit comments

Comments
 (0)