Skip to content

Commit e7e8c38

Browse files
authored
Start renaming dims to dim (#8487)
* Start renaming `dims` to `dim`. Begins the process of #6646. I don't think it's feasible / enjoyable to do this for everything at once, so I would suggest we do it gradually, while keeping the warnings quite quiet, so that by the time we convert to louder warnings, users can do a find/replace easily.
* No deprecation for internal methods.
* Simplify typing.
1 parent d3a1527 commit e7e8c38

File tree

9 files changed

+92
-55
lines changed

9 files changed

+92
-55
lines changed

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ Breaking changes
3838
Deprecations
3939
~~~~~~~~~~~~
4040

41+
- As part of an effort to standardize the API, we're renaming the ``dims``
42+
keyword arg to ``dim`` for the minority of functions which currently use
43+
``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot`
44+
and we'll gradually roll this out across all functions. The warnings are
45+
currently ``PendingDeprecationWarning``, which are silenced by default. We'll
46+
convert these to ``DeprecationWarning`` in a future release.
47+
By `Maximilian Roos <https://github.com/max-sixty>`_.
4148

4249
Bug fixes
4350
~~~~~~~~~

xarray/core/alignment.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def assert_no_index_conflict(self) -> None:
324324
"- they may be used to reindex data along common dimensions"
325325
)
326326

327-
def _need_reindex(self, dims, cmp_indexes) -> bool:
327+
def _need_reindex(self, dim, cmp_indexes) -> bool:
328328
"""Whether or not we need to reindex variables for a set of
329329
matching indexes.
330330
@@ -340,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
340340
return True
341341

342342
unindexed_dims_sizes = {}
343-
for dim in dims:
344-
if dim in self.unindexed_dim_sizes:
345-
sizes = self.unindexed_dim_sizes[dim]
343+
for d in dim:
344+
if d in self.unindexed_dim_sizes:
345+
sizes = self.unindexed_dim_sizes[d]
346346
if len(sizes) > 1:
347347
# reindex if different sizes are found for unindexed dims
348348
return True
349349
else:
350-
unindexed_dims_sizes[dim] = next(iter(sizes))
350+
unindexed_dims_sizes[d] = next(iter(sizes))
351351

352352
if unindexed_dims_sizes:
353353
indexed_dims_sizes = {}
@@ -356,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool:
356356
for var in index_vars.values():
357357
indexed_dims_sizes.update(var.sizes)
358358

359-
for dim, size in unindexed_dims_sizes.items():
360-
if indexed_dims_sizes.get(dim, -1) != size:
359+
for d, size in unindexed_dims_sizes.items():
360+
if indexed_dims_sizes.get(d, -1) != size:
361361
# reindex if unindexed dimension size doesn't match
362362
return True
363363

xarray/core/computation.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from xarray.core.types import Dims, T_DataArray
2727
from xarray.core.utils import is_dict_like, is_scalar
2828
from xarray.core.variable import Variable
29+
from xarray.util.deprecation_helpers import deprecate_dims
2930

3031
if TYPE_CHECKING:
3132
from xarray.core.coordinates import Coordinates
@@ -1691,9 +1692,10 @@ def cross(
16911692
return c
16921693

16931694

1695+
@deprecate_dims
16941696
def dot(
16951697
*arrays,
1696-
dims: Dims = None,
1698+
dim: Dims = None,
16971699
**kwargs: Any,
16981700
):
16991701
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
@@ -1703,7 +1705,7 @@ def dot(
17031705
----------
17041706
*arrays : DataArray or Variable
17051707
Arrays to compute.
1706-
dims : str, iterable of hashable, "..." or None, optional
1708+
dim : str, iterable of hashable, "..." or None, optional
17071709
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
17081710
If not specified, then all the common dimensions are summed over.
17091711
**kwargs : dict
@@ -1756,18 +1758,18 @@ def dot(
17561758
[3, 4, 5]])
17571759
Dimensions without coordinates: c, d
17581760
1759-
>>> xr.dot(da_a, da_b, dims=["a", "b"])
1761+
>>> xr.dot(da_a, da_b, dim=["a", "b"])
17601762
<xarray.DataArray (c: 2)>
17611763
array([110, 125])
17621764
Dimensions without coordinates: c
17631765
1764-
>>> xr.dot(da_a, da_b, dims=["a"])
1766+
>>> xr.dot(da_a, da_b, dim=["a"])
17651767
<xarray.DataArray (b: 2, c: 2)>
17661768
array([[40, 46],
17671769
[70, 79]])
17681770
Dimensions without coordinates: b, c
17691771
1770-
>>> xr.dot(da_a, da_b, da_c, dims=["b", "c"])
1772+
>>> xr.dot(da_a, da_b, da_c, dim=["b", "c"])
17711773
<xarray.DataArray (a: 3, d: 3)>
17721774
array([[ 9, 14, 19],
17731775
[ 93, 150, 207],
@@ -1779,7 +1781,7 @@ def dot(
17791781
array([110, 125])
17801782
Dimensions without coordinates: c
17811783
1782-
>>> xr.dot(da_a, da_b, dims=...)
1784+
>>> xr.dot(da_a, da_b, dim=...)
17831785
<xarray.DataArray ()>
17841786
array(235)
17851787
"""
@@ -1803,18 +1805,18 @@ def dot(
18031805
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
18041806
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
18051807

1806-
if dims is ...:
1807-
dims = all_dims
1808-
elif isinstance(dims, str):
1809-
dims = (dims,)
1810-
elif dims is None:
1808+
if dim is ...:
1809+
dim = all_dims
1810+
elif isinstance(dim, str):
1811+
dim = (dim,)
1812+
elif dim is None:
18111813
# find dimensions that occur more than one times
18121814
dim_counts: Counter = Counter()
18131815
for arr in arrays:
18141816
dim_counts.update(arr.dims)
1815-
dims = tuple(d for d, c in dim_counts.items() if c > 1)
1817+
dim = tuple(d for d, c in dim_counts.items() if c > 1)
18161818

1817-
dot_dims: set[Hashable] = set(dims)
1819+
dot_dims: set[Hashable] = set(dim)
18181820

18191821
# dimensions to be parallelized
18201822
broadcast_dims = common_dims - dot_dims

xarray/core/dataarray.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
)
6666
from xarray.plot.accessor import DataArrayPlotAccessor
6767
from xarray.plot.utils import _get_units_from_attrs
68-
from xarray.util.deprecation_helpers import _deprecate_positional_args
68+
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
6969

7070
if TYPE_CHECKING:
7171
from typing import TypeVar, Union
@@ -115,14 +115,14 @@
115115
T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset])
116116

117117

118-
def _check_coords_dims(shape, coords, dims):
119-
sizes = dict(zip(dims, shape))
118+
def _check_coords_dims(shape, coords, dim):
119+
sizes = dict(zip(dim, shape))
120120
for k, v in coords.items():
121-
if any(d not in dims for d in v.dims):
121+
if any(d not in dim for d in v.dims):
122122
raise ValueError(
123123
f"coordinate {k} has dimensions {v.dims}, but these "
124124
"are not a subset of the DataArray "
125-
f"dimensions {dims}"
125+
f"dimensions {dim}"
126126
)
127127

128128
for d, s in v.sizes.items():
@@ -4895,10 +4895,11 @@ def imag(self) -> Self:
48954895
"""
48964896
return self._replace(self.variable.imag)
48974897

4898+
@deprecate_dims
48984899
def dot(
48994900
self,
49004901
other: T_Xarray,
4901-
dims: Dims = None,
4902+
dim: Dims = None,
49024903
) -> T_Xarray:
49034904
"""Perform dot product of two DataArrays along their shared dims.
49044905
@@ -4908,7 +4909,7 @@ def dot(
49084909
----------
49094910
other : DataArray
49104911
The other array with which the dot product is performed.
4911-
dims : ..., str, Iterable of Hashable or None, optional
4912+
dim : ..., str, Iterable of Hashable or None, optional
49124913
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
49134914
If not specified, then all the common dimensions are summed over.
49144915
@@ -4947,7 +4948,7 @@ def dot(
49474948
if not isinstance(other, DataArray):
49484949
raise TypeError("dot only operates on DataArrays.")
49494950

4950-
return computation.dot(self, other, dims=dims)
4951+
return computation.dot(self, other, dim=dim)
49514952

49524953
def sortby(
49534954
self,

xarray/core/variable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,15 +1541,15 @@ def stack(self, dimensions=None, **dimensions_kwargs):
15411541
result = result._stack_once(dims, new_dim)
15421542
return result
15431543

1544-
def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self:
1544+
def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
15451545
"""
15461546
Unstacks the variable without needing an index.
15471547
15481548
Unlike `_unstack_once`, this function requires the existing dimension to
15491549
contain the full product of the new dimensions.
15501550
"""
1551-
new_dim_names = tuple(dims.keys())
1552-
new_dim_sizes = tuple(dims.values())
1551+
new_dim_names = tuple(dim.keys())
1552+
new_dim_sizes = tuple(dim.values())
15531553

15541554
if old_dim not in self.dims:
15551555
raise ValueError(f"invalid existing dimension: {old_dim}")

xarray/core/weighted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _reduce(
228228

229229
# `dot` does not broadcast arrays, so this avoids creating a large
230230
# DataArray (if `weights` has additional dimensions)
231-
return dot(da, weights, dims=dim)
231+
return dot(da, weights, dim=dim)
232232

233233
def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
234234
"""Calculate the sum of weights, accounting for missing values"""

xarray/tests/test_computation.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,7 +1936,7 @@ def test_dot(use_dask: bool) -> None:
19361936
da_a = da_a.chunk({"a": 3})
19371937
da_b = da_b.chunk({"a": 3})
19381938
da_c = da_c.chunk({"c": 3})
1939-
actual = xr.dot(da_a, da_b, dims=["a", "b"])
1939+
actual = xr.dot(da_a, da_b, dim=["a", "b"])
19401940
assert actual.dims == ("c",)
19411941
assert (actual.data == np.einsum("ij,ijk->k", a, b)).all()
19421942
assert isinstance(actual.variable.data, type(da_a.variable.data))
@@ -1960,33 +1960,33 @@ def test_dot(use_dask: bool) -> None:
19601960
if use_dask:
19611961
da_a = da_a.chunk({"a": 3})
19621962
da_b = da_b.chunk({"a": 3})
1963-
actual = xr.dot(da_a, da_b, dims=["b"])
1963+
actual = xr.dot(da_a, da_b, dim=["b"])
19641964
assert actual.dims == ("a", "c")
19651965
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
19661966
assert isinstance(actual.variable.data, type(da_a.variable.data))
19671967

1968-
actual = xr.dot(da_a, da_b, dims=["b"])
1968+
actual = xr.dot(da_a, da_b, dim=["b"])
19691969
assert actual.dims == ("a", "c")
19701970
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
19711971

1972-
actual = xr.dot(da_a, da_b, dims="b")
1972+
actual = xr.dot(da_a, da_b, dim="b")
19731973
assert actual.dims == ("a", "c")
19741974
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
19751975

1976-
actual = xr.dot(da_a, da_b, dims="a")
1976+
actual = xr.dot(da_a, da_b, dim="a")
19771977
assert actual.dims == ("b", "c")
19781978
assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all()
19791979

1980-
actual = xr.dot(da_a, da_b, dims="c")
1980+
actual = xr.dot(da_a, da_b, dim="c")
19811981
assert actual.dims == ("a", "b")
19821982
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()
19831983

1984-
actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"])
1984+
actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"])
19851985
assert actual.dims == ("c", "e")
19861986
assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all()
19871987

19881988
# should work with tuple
1989-
actual = xr.dot(da_a, da_b, dims=("c",))
1989+
actual = xr.dot(da_a, da_b, dim=("c",))
19901990
assert actual.dims == ("a", "b")
19911991
assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all()
19921992

@@ -1996,47 +1996,47 @@ def test_dot(use_dask: bool) -> None:
19961996
assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all()
19971997

19981998
# 1 array summation
1999-
actual = xr.dot(da_a, dims="a")
1999+
actual = xr.dot(da_a, dim="a")
20002000
assert actual.dims == ("b",)
20012001
assert (actual.data == np.einsum("ij->j ", a)).all()
20022002

20032003
# empty dim
2004-
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a")
2004+
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a")
20052005
assert actual.dims == ("b",)
20062006
assert (actual.data == np.zeros(actual.shape)).all()
20072007

20082008
# Ellipsis (...) sums over all dimensions
2009-
actual = xr.dot(da_a, da_b, dims=...)
2009+
actual = xr.dot(da_a, da_b, dim=...)
20102010
assert actual.dims == ()
20112011
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()
20122012

2013-
actual = xr.dot(da_a, da_b, da_c, dims=...)
2013+
actual = xr.dot(da_a, da_b, da_c, dim=...)
20142014
assert actual.dims == ()
20152015
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()
20162016

2017-
actual = xr.dot(da_a, dims=...)
2017+
actual = xr.dot(da_a, dim=...)
20182018
assert actual.dims == ()
20192019
assert (actual.data == np.einsum("ij-> ", a)).all()
20202020

2021-
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
2021+
actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...)
20222022
assert actual.dims == ()
20232023
assert (actual.data == np.zeros(actual.shape)).all()
20242024

20252025
# Invalid cases
20262026
if not use_dask:
20272027
with pytest.raises(TypeError):
2028-
xr.dot(da_a, dims="a", invalid=None)
2028+
xr.dot(da_a, dim="a", invalid=None)
20292029
with pytest.raises(TypeError):
2030-
xr.dot(da_a.to_dataset(name="da"), dims="a")
2030+
xr.dot(da_a.to_dataset(name="da"), dim="a")
20312031
with pytest.raises(TypeError):
2032-
xr.dot(dims="a")
2032+
xr.dot(dim="a")
20332033

20342034
# einsum parameters
2035-
actual = xr.dot(da_a, da_b, dims=["b"], order="C")
2035+
actual = xr.dot(da_a, da_b, dim=["b"], order="C")
20362036
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
20372037
assert actual.values.flags["C_CONTIGUOUS"]
20382038
assert not actual.values.flags["F_CONTIGUOUS"]
2039-
actual = xr.dot(da_a, da_b, dims=["b"], order="F")
2039+
actual = xr.dot(da_a, da_b, dim=["b"], order="F")
20402040
assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all()
20412041
# dask converts Fortran arrays to C order when merging the final array
20422042
if not use_dask:
@@ -2078,7 +2078,7 @@ def test_dot_align_coords(use_dask: bool) -> None:
20782078
expected = (da_a * da_b).sum(["a", "b"])
20792079
xr.testing.assert_allclose(expected, actual)
20802080

2081-
actual = xr.dot(da_a, da_b, dims=...)
2081+
actual = xr.dot(da_a, da_b, dim=...)
20822082
expected = (da_a * da_b).sum()
20832083
xr.testing.assert_allclose(expected, actual)
20842084

xarray/tests/test_dataarray.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3964,13 +3964,13 @@ def test_dot(self) -> None:
39643964
assert_equal(expected3, actual3)
39653965

39663966
# Ellipsis: all dims are shared
3967-
actual4 = da.dot(da, dims=...)
3967+
actual4 = da.dot(da, dim=...)
39683968
expected4 = da.dot(da)
39693969
assert_equal(expected4, actual4)
39703970

39713971
# Ellipsis: not all dims are shared
3972-
actual5 = da.dot(dm3, dims=...)
3973-
expected5 = da.dot(dm3, dims=("j", "x", "y", "z"))
3972+
actual5 = da.dot(dm3, dim=...)
3973+
expected5 = da.dot(dm3, dim=("j", "x", "y", "z"))
39743974
assert_equal(expected5, actual5)
39753975

39763976
with pytest.raises(NotImplementedError):

xarray/util/deprecation_helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from functools import wraps
3737
from typing import Callable, TypeVar
3838

39+
from xarray.core.utils import emit_user_level_warning
40+
3941
T = TypeVar("T", bound=Callable)
4042

4143
POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
@@ -115,3 +117,28 @@ def inner(*args, **kwargs):
115117
return inner
116118

117119
return _decorator
120+
121+
122+
def deprecate_dims(func: T) -> T:
    """
    Accept the deprecated ``dims`` keyword on a function that now takes ``dim``.

    For functions that previously took `dims` as a kwarg, and have now
    transitioned to `dim`. If the decorated function is called with ``dims=...``,
    a ``PendingDeprecationWarning`` is issued and the value is forwarded to the
    function as ``dim``. Calls that do not pass ``dims`` are forwarded unchanged.

    Parameters
    ----------
    func : callable
        Function accepting a ``dim`` keyword argument.

    Returns
    -------
    callable
        Wrapper around ``func`` (metadata preserved via ``functools.wraps``).

    Raises
    ------
    TypeError
        Raised by the wrapper when both ``dims`` and ``dim`` are passed as
        keyword arguments, since it is ambiguous which value should be used.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if "dims" in kwargs:
            # Forwarding `dims` would silently overwrite an explicitly passed
            # `dim`; refuse the ambiguous call, mirroring Python's own
            # "multiple values for argument" error.
            if "dim" in kwargs:
                raise TypeError(
                    f"{func.__name__}() received both `dims` and `dim`; `dims` has "
                    "been renamed to `dim`, so pass only `dim`."
                )
            emit_user_level_warning(
                "The `dims` argument has been renamed to `dim`, and will be removed "
                "in the future. This renaming is taking place throughout xarray over the "
                "next few releases.",
                # Upgrade to `DeprecationWarning` in the future, when the renaming is complete.
                PendingDeprecationWarning,
            )
            kwargs["dim"] = kwargs.pop("dims")
        return func(*args, **kwargs)

    # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing
    # within the function.
    return wrapper  # type: ignore

0 commit comments

Comments
 (0)