From b1a390d58af1f1a38f2be04fa9b53228610309de Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 15:29:42 -0400 Subject: [PATCH 1/7] remove dask_array_type checks --- xarray/core/arithmetic.py | 3 +-- xarray/core/duck_array_ops.py | 4 ++-- xarray/core/nanops.py | 15 ++++----------- xarray/core/variable.py | 13 ++++--------- 4 files changed, 11 insertions(+), 24 deletions(-) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index ff7af02abfc..b88189c5f32 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -16,7 +16,6 @@ from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type class SupportsArithmetic: @@ -38,7 +37,7 @@ class SupportsArithmetic: numbers.Number, bytes, str, - ) + dask_array_type + ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): from .computation import apply_ufunc diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2bf05abb96a..40d4a720606 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -25,7 +25,7 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import cupy_array_type, dask_array_type, is_duck_dask_array +from .pycompat import cupy_array_type, is_duck_dask_array from .utils import is_duck_array try: @@ -113,7 +113,7 @@ def isnull(data): return zeros_like(data, dtype=bool) else: # at this point, array should have dtype=object - if isinstance(data, (np.ndarray, dask_array_type)): + if isinstance(data, np.ndarray): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 920fd5a094e..d6f6af32769 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -6,7 +6,6 @@ from . import dtypes, nputils, utils from .duck_array_ops import count, fillna, isnull, where, where_method -from .pycompat import dask_array_type try: import dask.array as dask_array @@ -65,16 +64,14 @@ def nanmin(a, axis=None, out=None): if a.dtype.kind == "O": return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanmin(a, axis=axis) + return nputils.nanmin(a, axis=axis) def nanmax(a, axis=None, out=None): if a.dtype.kind == "O": return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanmax(a, axis=axis) + return nputils.nanmax(a, axis=axis) def nanargmin(a, axis=None): @@ -82,8 +79,7 @@ def nanargmin(a, axis=None): fill_value = dtypes.get_pos_infinity(a.dtype) return _nan_argminmax_object("argmin", fill_value, a, axis=axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanargmin(a, axis=axis) + return nputils.nanargmin(a, axis=axis) def nanargmax(a, axis=None): @@ -91,8 +87,7 @@ def nanargmax(a, axis=None): fill_value = dtypes.get_neg_infinity(a.dtype) return _nan_argminmax_object("argmax", fill_value, a, axis=axis) - module = dask_array if isinstance(a, dask_array_type) else nputils - return module.nanargmax(a, axis=axis) + return nputils.nanargmax(a, axis=axis) def nansum(a, axis=None, dtype=None, out=None, min_count=None): @@ -128,8 +123,6 @@ def nanmean(a, axis=None, dtype=None, out=None): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning ) - if isinstance(a, dask_array_type): - return dask_array.nanmean(a, axis=axis, dtype=dtype) return np.nanmean(a, axis=axis, dtype=dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c650464bbcb..709f6a660e1 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -39,7 +39,6 @@ from .pycompat import ( DuckArrayModule, cupy_array_type, - dask_array_type, integer_types, is_duck_dask_array, sparse_array_type, @@ -59,13 +58,9 @@ ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - ( - indexing.ExplicitlyIndexed, - pd.Index, - ) - + dask_array_type - + cupy_array_type -) + indexing.ExplicitlyIndexed, + pd.Index, +) + cupy_array_type # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -1150,7 +1145,7 @@ def to_numpy(self) -> np.ndarray: data = self.data # TODO first attempt to call .to_numpy() once some libraries implement it - if isinstance(data, dask_array_type): + if hasattr(data, "chunks"): data = data.compute() if isinstance(data, cupy_array_type): data = data.get() From 32a00fc5b56a9aad4beb391522af1c1f99d7bab4 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 15:37:01 -0400 Subject: [PATCH 2/7] remove dask_array_compat --- xarray/core/dask_array_compat.py | 52 +------------------------------- xarray/core/nanops.py | 8 ----- 2 files changed, 1 insertion(+), 59 deletions(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index b53947e88eb..3144cd68cb2 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -6,57 +6,7 @@ try: import dask.array as da -except ImportError: - da = None # type: ignore - - -def _validate_pad_output_shape(input_shape, pad_width, output_shape): - """Validates the output shape of dask.array.pad, raising a RuntimeError if they do not match. - In the current versions of dask (2.2/2.4), dask.array.pad with mode='reflect' sometimes returns - an invalid shape. - """ - isint = lambda i: isinstance(i, int) - - if isint(pad_width): - pass - elif len(pad_width) == 2 and all(map(isint, pad_width)): - pad_width = sum(pad_width) - elif ( - len(pad_width) == len(input_shape) - and all(map(lambda x: len(x) == 2, pad_width)) - and all(isint(i) for p in pad_width for i in p) - ): - pad_width = np.sum(pad_width, axis=1) - else: - # unreachable: dask.array.pad should already have thrown an error - raise ValueError("Invalid value for `pad_width`") - - if not np.array_equal(np.array(input_shape) + pad_width, output_shape): - raise RuntimeError( - "There seems to be something wrong with the shape of the output of dask.array.pad, " - "try upgrading Dask, use a different pad mode e.g. mode='constant' or first convert " - "your DataArray/Dataset to one backed by a numpy array by calling the `compute()` method." - "See: https://github.com/dask/dask/issues/5303" - ) - -def pad(array, pad_width, mode="constant", **kwargs): - padded = da.pad(array, pad_width, mode=mode, **kwargs) - # workaround for inconsistency between numpy and dask: https://github.com/dask/dask/issues/5303 - if mode == "mean" and issubclass(array.dtype.type, np.integer): - warnings.warn( - 'dask.array.pad(mode="mean") converts integers to floats. xarray converts ' - "these floats back to integers to keep the interface consistent. There is a chance that " - "this introduces rounding errors. If you wish to keep the values as floats, first change " - "the dtype to a float before calling pad.", - UserWarning, - ) - return da.round(padded).astype(array.dtype) - _validate_pad_output_shape(array.shape, pad_width, padded.shape) - return padded - - -if da is not None: sliding_window_view = da.lib.stride_tricks.sliding_window_view -else: +except ImportError: sliding_window_view = None diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index d6f6af32769..4c71658b577 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -7,14 +7,6 @@ from . import dtypes, nputils, utils from .duck_array_ops import count, fillna, isnull, where, where_method -try: - import dask.array as dask_array - - from . import dask_array_compat -except ImportError: - dask_array = None # type: ignore[assignment] - dask_array_compat = None # type: ignore[assignment] - def _maybe_null_out(result, axis, mask, min_count=1): """ From f95eb1f30e50d71d8ff711e2d23bf71b139a68ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Sep 2022 19:39:13 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/dask_array_compat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 3144cd68cb2..e52ff7c7ed8 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - import numpy as np try: From 9d84a847e07678e115bb1b3fa65981bec1b0af4e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 16:13:34 -0400 Subject: [PATCH 4/7] restore xarray wraps dask precedence --- xarray/core/arithmetic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index b88189c5f32..cad8ff7b5bb 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -16,6 +16,7 @@ from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs +from .pycompat import dask_array_type class SupportsArithmetic: @@ -37,6 +38,7 @@ class SupportsArithmetic: numbers.Number, bytes, str, + dask_array_type, ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): From 7b8e6d44b2d169fcddb86230627a16bbc6ed8e2d Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 16:20:01 -0400 Subject: [PATCH 5/7] fully remove dask_aray_compat.py file --- xarray/core/dask_array_compat.py | 12 ------------ xarray/core/duck_array_ops.py | 6 ++++-- 2 files changed, 4 insertions(+), 14 deletions(-) delete mode 100644 xarray/core/dask_array_compat.py diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py deleted file mode 100644 index 3144cd68cb2..00000000000 --- a/xarray/core/dask_array_compat.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -import warnings - -import numpy as np - -try: - import dask.array as da - - sliding_window_view = da.lib.stride_tricks.sliding_window_view -except ImportError: - sliding_window_view = None diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 40d4a720606..4f42f497a69 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -23,7 +23,7 @@ from numpy import take, tensordot, transpose, unravel_index # noqa from numpy import where as _where -from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils +from . import dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast from .pycompat import cupy_array_type, is_duck_dask_array from .utils import is_duck_array @@ -631,7 +631,9 @@ def sliding_window_view(array, window_shape, axis): The rolling dimension will be placed at the last dimension. """ if is_duck_dask_array(array): - return dask_array_compat.sliding_window_view(array, window_shape, axis) + import dask.array as da + + return da.lib.stride_tricks.sliding_window_view(array, window_shape, axis) else: return npcompat.sliding_window_view(array, window_shape, axis) From 5e8cd0d3195ad1ee6bdd48a4cae926713348fbca Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 16:52:58 -0400 Subject: [PATCH 6/7] remove one reference to cupy --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 709f6a660e1..cfe1db6a24e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -60,7 +60,7 @@ NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, pd.Index, -) + cupy_array_type +) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) From 8b1c5628792a2c27a08ea978cea95aaf9077f43e Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Mon, 12 Sep 2022 18:01:46 -0400 Subject: [PATCH 7/7] simplify __apply_ufunc__ check --- xarray/core/arithmetic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index cad8ff7b5bb..8d6a1d3ed8c 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -16,7 +16,7 @@ from .common import ImplementsArrayReduce, ImplementsDatasetReduce from .ops import IncludeCumMethods, IncludeNumpySameMethods, IncludeReduceMethods from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_array class SupportsArithmetic: @@ -33,12 +33,10 @@ class SupportsArithmetic: # TODO: allow extending this with some sort of registration system _HANDLED_TYPES = ( - np.ndarray, np.generic, numbers.Number, bytes, str, - dask_array_type, ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @@ -47,7 +45,9 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. out = kwargs.get("out", ()) for x in inputs + out: - if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): + if not is_duck_array(x) and not isinstance( + x, self._HANDLED_TYPES + (SupportsArithmetic,) + ): return NotImplemented if ufunc.signature is not None: