Skip to content

Merging coordinates computes array values #9481

Open
@shoyer

Description

@shoyer

What is your issue?

Xarray's default handling of coordinate merging (e.g., as used in arithmetic) computes the values of lazy coordinate arrays (such as Dask arrays), which is not ideal.

(There is probably an older issue discussing this, but I couldn't find it with a quick search.)

This is easiest to see using Dask:

import xarray
import numpy as np
import dask.array

def r(*args):  # stand-in task body for dask.delayed below: any attempt to compute the data ends up here
    raise RuntimeError('data accessed')  # fail loudly so an unintended compute is impossible to miss

x1 = dask.array.from_delayed(dask.delayed(r)(1), shape=(), dtype=np.float64)  # lazy 0-d array; computing it calls r
x2 = dask.array.from_delayed(dask.delayed(r)(2), shape=(), dtype=np.float64)  # a separate lazy array, not the same object as x1
ds1 = xarray.Dataset(coords={'x': x1})  # coordinate 'x' backed by the lazy array x1
ds2 = xarray.Dataset(coords={'x': x2})  # same coordinate name, different lazy array
ds1 + ds2  # RuntimeError: data accessed

Traceback:

RuntimeError                              Traceback (most recent call last)
Cell In[2], line 12
     10 ds1 = xarray.Dataset(coords={'x': x1})
     11 ds2 = xarray.Dataset(coords={'x': x2})
---> 12 ds1 + ds2

File ~/dev/xarray/xarray/core/_typed_ops.py:35, in DatasetOpsMixin.__add__(self, other)
     34 def __add__(self, other: DsCompatible) -> Self:
---> 35     return self._binary_op(other, operator.add)

File ~/dev/xarray/xarray/core/dataset.py:7783, in Dataset._binary_op(self, other, f, reflexive, join)
   7781     self, other = align(self, other, join=align_type, copy=False)
   7782 g = f if not reflexive else lambda x, y: f(y, x)
-> 7783 ds = self._calculate_binary_op(g, other, join=align_type)
   7784 keep_attrs = _get_keep_attrs(default=False)
   7785 if keep_attrs:

File ~/dev/xarray/xarray/core/dataset.py:7844, in Dataset._calculate_binary_op(self, f, other, join, inplace)
   7841     return type(self)(new_data_vars)
   7843 other_coords: Coordinates | None = getattr(other, "coords", None)
-> 7844 ds = self.coords.merge(other_coords)
   7846 if isinstance(other, Dataset):
   7847     new_vars = apply_over_both(
   7848         self.data_vars, other.data_vars, self.variables, other.variables
   7849     )

File ~/dev/xarray/xarray/core/coordinates.py:522, in Coordinates.merge(self, other)
    519 if not isinstance(other, Coordinates):
    520     other = Dataset(coords=other).coords
--> 522 coords, indexes = merge_coordinates_without_align([self, other])
    523 coord_names = set(coords)
    524 return Dataset._construct_direct(
    525     variables=coords, coord_names=coord_names, indexes=indexes
    526 )

File ~/dev/xarray/xarray/core/merge.py:413, in merge_coordinates_without_align(objects, prioritized, exclude_dims, combine_attrs)
    409     filtered = collected
    411 # TODO: indexes should probably be filtered in collected elements
    412 # before merging them
--> 413 merged_coords, merged_indexes = merge_collected(
    414     filtered, prioritized, combine_attrs=combine_attrs
    415 )
    416 merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords))
    418 return merged_coords, merged_indexes

File ~/dev/xarray/xarray/core/merge.py:290, in merge_collected(grouped, prioritized, compat, combine_attrs, equals)
    288 variables = [variable for variable, _ in elements_list]
    289 try:
--> 290     merged_vars[name] = unique_variable(
    291         name, variables, compat, equals.get(name, None)
    292     )
    293 except MergeError:
    294     if compat != "minimal":
    295         # we need more than "minimal" compatibility (for which
    296         # we drop conflicting coordinates)

File ~/dev/xarray/xarray/core/merge.py:137, in unique_variable(name, variables, compat, equals)
    133         break
    135 if equals is None:
    136     # now compare values with minimum number of computes
--> 137     out = out.compute()
    138     for var in variables[1:]:
    139         equals = getattr(out, compat)(var)

File ~/dev/xarray/xarray/core/variable.py:1003, in Variable.compute(self, **kwargs)
    985 """Manually trigger loading of this variable's data from disk or a
    986 remote source into memory and return a new variable. The original is
    987 left unaltered.
   (...)
   1000 dask.array.compute
   1001 """
   1002 new = self.copy(deep=False)
-> 1003 return new.load(**kwargs)

File ~/dev/xarray/xarray/core/variable.py:981, in Variable.load(self, **kwargs)
    964 def load(self, **kwargs):
    965     """Manually trigger loading of this variable's data from disk or a
    966     remote source into memory and return this variable.
    967
   (...)
    979     dask.array.compute
    980     """
--> 981     self._data = to_duck_array(self._data, **kwargs)
    982     return self

File ~/dev/xarray/xarray/namedarray/pycompat.py:130, in to_duck_array(data, **kwargs)
    128 if is_chunked_array(data):
    129     chunkmanager = get_chunked_array_type(data)
--> 130     loaded_data, *_ = chunkmanager.compute(data, **kwargs)  # type: ignore[var-annotated]
    131     return loaded_data
    133 if isinstance(data, ExplicitlyIndexed):

File ~/dev/xarray/xarray/namedarray/daskmanager.py:86, in DaskManager.compute(self, *data, **kwargs)
     81 def compute(
     82     self, *data: Any, **kwargs: Any
     83 ) -> tuple[np.ndarray[Any, _DType_co], ...]:
     84     from dask.array import compute
---> 86     return compute(*data, **kwargs)

File ~/miniconda3/envs/xarray-py312/lib/python3.12/site-packages/dask/base.py:664, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    661     postcomputes.append(x.__dask_postcompute__())
    663 with shorten_traceback():
--> 664     results = schedule(dsk, keys, **kwargs)
    666 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

Cell In[2], line 6, in r(*args)
      5 def r(*args):
----> 6     raise RuntimeError('data accessed')

RuntimeError: data accessed

We use this equality check to decide whether to preserve coordinates on result objects: if a coordinate is the same across all arguments, it is kept; otherwise, it is dropped.

There are checks for matching array identity inside `Variable.equals`, so in practice the computation is often skipped, but this isn't ideal. It's basically the only case where Xarray operations on Xarray objects require computing lazy array values.

The simplest fix would be to switch the default compat option used for merging inside arithmetic (and other xarray internal operations) to "override", so coordinates are simply copied from the first object on which they appear. Would this make sense?

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions