Skip to content

Commit b018442

Browse files
authored
Add Ellipsis typehints (#7017)
* use ellipsis in dot * add ellipsis to more funcs
1 parent c0011e1 commit b018442

File tree

9 files changed

+51
-34
lines changed

9 files changed

+51
-34
lines changed

setup.cfg

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,11 @@ ignore =
152152
E501 # line too long - let black worry about that
153153
E731 # do not assign a lambda expression, use a def
154154
W503 # line break before binary operator
155-
exclude=
155+
exclude =
156156
.eggs
157157
doc
158+
builtins =
159+
ellipsis
158160

159161
[isort]
160162
profile = black

xarray/core/computation.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .coordinates import Coordinates
4141
from .dataarray import DataArray
4242
from .dataset import Dataset
43-
from .types import CombineAttrsOptions, JoinOptions
43+
from .types import CombineAttrsOptions, Ellipsis, JoinOptions
4444

4545
_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
4646
_DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -1622,7 +1622,11 @@ def cross(
16221622
return c
16231623

16241624

1625-
def dot(*arrays, dims=None, **kwargs):
1625+
def dot(
1626+
*arrays,
1627+
dims: str | Iterable[Hashable] | Ellipsis | None = None,
1628+
**kwargs: Any,
1629+
):
16261630
"""Generalized dot product for xarray objects. Like np.einsum, but
16271631
provides a simpler interface based on array dimensions.
16281632
@@ -1711,10 +1715,7 @@ def dot(*arrays, dims=None, **kwargs):
17111715
if len(arrays) == 0:
17121716
raise TypeError("At least one array should be given.")
17131717

1714-
if isinstance(dims, str):
1715-
dims = (dims,)
1716-
1717-
common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
1718+
common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays))
17181719
all_dims = []
17191720
for arr in arrays:
17201721
all_dims += [d for d in arr.dims if d not in all_dims]
@@ -1724,21 +1725,25 @@ def dot(*arrays, dims=None, **kwargs):
17241725

17251726
if dims is ...:
17261727
dims = all_dims
1728+
elif isinstance(dims, str):
1729+
dims = (dims,)
17271730
elif dims is None:
17281731
# find dimensions that occur more than one times
1729-
dim_counts = Counter()
1732+
dim_counts: Counter = Counter()
17301733
for arr in arrays:
17311734
dim_counts.update(arr.dims)
17321735
dims = tuple(d for d, c in dim_counts.items() if c > 1)
17331736

1734-
dims = tuple(dims) # make dims a tuple
1737+
dot_dims: set[Hashable] = set(dims) # type:ignore[arg-type]
17351738

17361739
# dimensions to be parallelized
1737-
broadcast_dims = tuple(d for d in all_dims if d in common_dims and d not in dims)
1740+
broadcast_dims = common_dims - dot_dims
17381741
input_core_dims = [
17391742
[d for d in arr.dims if d not in broadcast_dims] for arr in arrays
17401743
]
1741-
output_core_dims = [tuple(d for d in all_dims if d not in dims + broadcast_dims)]
1744+
output_core_dims = [
1745+
[d for d in all_dims if d not in dot_dims and d not in broadcast_dims]
1746+
]
17421747

17431748
# construct einsum subscripts, such as '...abc,...ab->...c'
17441749
# Note: input_core_dims are always moved to the last position

xarray/core/dataarray.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from .types import (
7979
CoarsenBoundaryOptions,
8080
DatetimeUnitOptions,
81+
Ellipsis,
8182
ErrorOptions,
8283
ErrorOptionsWithWarn,
8384
InterpOptions,
@@ -3769,7 +3770,7 @@ def imag(self: T_DataArray) -> T_DataArray:
37693770
def dot(
37703771
self: T_DataArray,
37713772
other: T_DataArray,
3772-
dims: Hashable | Sequence[Hashable] | None = None,
3773+
dims: str | Iterable[Hashable] | Ellipsis | None = None,
37733774
) -> T_DataArray:
37743775
"""Perform dot product of two DataArrays along their shared dims.
37753776
@@ -3779,7 +3780,7 @@ def dot(
37793780
----------
37803781
other : DataArray
37813782
The other array with which the dot product is performed.
3782-
dims : ..., Hashable or sequence of Hashable, optional
3783+
dims : ..., str or Iterable of Hashable, optional
37833784
Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions.
37843785
If not specified, then all the common dimensions are summed over.
37853786
@@ -4773,7 +4774,7 @@ def idxmax(
47734774
# https://github.com/python/mypy/issues/12846 is resolved
47744775
def argmin(
47754776
self,
4776-
dim: Hashable | Sequence[Hashable] | None = None,
4777+
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
47774778
axis: int | None = None,
47784779
keep_attrs: bool | None = None,
47794780
skipna: bool | None = None,
@@ -4878,7 +4879,7 @@ def argmin(
48784879
# https://github.com/python/mypy/issues/12846 is resolved
48794880
def argmax(
48804881
self,
4881-
dim: Hashable | Sequence[Hashable] = None,
4882+
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
48824883
axis: int | None = None,
48834884
keep_attrs: bool | None = None,
48844885
skipna: bool | None = None,

xarray/core/dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
CombineAttrsOptions,
108108
CompatOptions,
109109
DatetimeUnitOptions,
110+
Ellipsis,
110111
ErrorOptions,
111112
ErrorOptionsWithWarn,
112113
InterpOptions,
@@ -4255,7 +4256,7 @@ def _get_stack_index(
42554256

42564257
def _stack_once(
42574258
self: T_Dataset,
4258-
dims: Sequence[Hashable],
4259+
dims: Sequence[Hashable | Ellipsis],
42594260
new_dim: Hashable,
42604261
index_cls: type[Index],
42614262
create_index: bool | None = True,
@@ -4314,10 +4315,10 @@ def _stack_once(
43144315

43154316
def stack(
43164317
self: T_Dataset,
4317-
dimensions: Mapping[Any, Sequence[Hashable]] | None = None,
4318+
dimensions: Mapping[Any, Sequence[Hashable | Ellipsis]] | None = None,
43184319
create_index: bool | None = True,
43194320
index_cls: type[Index] = PandasMultiIndex,
4320-
**dimensions_kwargs: Sequence[Hashable],
4321+
**dimensions_kwargs: Sequence[Hashable | Ellipsis],
43214322
) -> T_Dataset:
43224323
"""
43234324
Stack any number of existing dimensions into a single new dimension.

xarray/core/types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import numpy as np
66

77
if TYPE_CHECKING:
8-
from .common import DataWithCoords
8+
9+
from .common import AbstractArray, DataWithCoords
910
from .dataarray import DataArray
1011
from .dataset import Dataset
1112
from .groupby import DataArrayGroupBy, GroupBy
@@ -29,13 +30,19 @@
2930
# from typing_extensions import Self
3031
# except ImportError:
3132
# Self: Any = None
32-
Self: Any = None
33+
Self = TypeVar("Self")
34+
35+
Ellipsis = ellipsis
36+
3337
else:
3438
Self: Any = None
39+
Ellipsis: Any = None
40+
3541

3642
T_Dataset = TypeVar("T_Dataset", bound="Dataset")
3743
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
3844
T_Variable = TypeVar("T_Variable", bound="Variable")
45+
T_Array = TypeVar("T_Array", bound="AbstractArray")
3946
T_Index = TypeVar("T_Index", bound="Index")
4047

4148
T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])

xarray/core/variable.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171

7272
if TYPE_CHECKING:
7373
from .types import (
74+
Ellipsis,
7475
ErrorOptionsWithWarn,
7576
PadModeOptions,
7677
PadReflectOptions,
@@ -1478,7 +1479,7 @@ def roll(self, shifts=None, **shifts_kwargs):
14781479

14791480
def transpose(
14801481
self,
1481-
*dims: Hashable,
1482+
*dims: Hashable | Ellipsis,
14821483
missing_dims: ErrorOptionsWithWarn = "raise",
14831484
) -> Variable:
14841485
"""Return a new Variable object with transposed dimensions.
@@ -2555,7 +2556,7 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float):
25552556
def _unravel_argminmax(
25562557
self,
25572558
argminmax: str,
2558-
dim: Hashable | Sequence[Hashable] | None,
2559+
dim: Hashable | Sequence[Hashable] | Ellipsis | None,
25592560
axis: int | None,
25602561
keep_attrs: bool | None,
25612562
skipna: bool | None,
@@ -2624,7 +2625,7 @@ def _unravel_argminmax(
26242625

26252626
def argmin(
26262627
self,
2627-
dim: Hashable | Sequence[Hashable] = None,
2628+
dim: Hashable | Sequence[Hashable] | Ellipsis | None = None,
26282629
axis: int = None,
26292630
keep_attrs: bool = None,
26302631
skipna: bool = None,

xarray/core/weighted.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .computation import apply_ufunc, dot
1010
from .npcompat import ArrayLike
1111
from .pycompat import is_duck_dask_array
12-
from .types import T_Xarray
12+
from .types import Ellipsis, T_Xarray
1313

1414
# Weighted quantile methods are a subset of the numpy supported quantile methods.
1515
QUANTILE_METHODS = Literal[
@@ -206,7 +206,7 @@ def _check_dim(self, dim: Hashable | Iterable[Hashable] | None):
206206
def _reduce(
207207
da: DataArray,
208208
weights: DataArray,
209-
dim: Hashable | Iterable[Hashable] | None = None,
209+
dim: str | Iterable[Hashable] | Ellipsis | None = None,
210210
skipna: bool | None = None,
211211
) -> DataArray:
212212
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
@@ -227,7 +227,7 @@ def _reduce(
227227
return dot(da, weights, dims=dim)
228228

229229
def _sum_of_weights(
230-
self, da: DataArray, dim: Hashable | Iterable[Hashable] | None = None
230+
self, da: DataArray, dim: str | Iterable[Hashable] | None = None
231231
) -> DataArray:
232232
"""Calculate the sum of weights, accounting for missing values"""
233233

@@ -251,7 +251,7 @@ def _sum_of_weights(
251251
def _sum_of_squares(
252252
self,
253253
da: DataArray,
254-
dim: Hashable | Iterable[Hashable] | None = None,
254+
dim: str | Iterable[Hashable] | None = None,
255255
skipna: bool | None = None,
256256
) -> DataArray:
257257
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
@@ -263,7 +263,7 @@ def _sum_of_squares(
263263
def _weighted_sum(
264264
self,
265265
da: DataArray,
266-
dim: Hashable | Iterable[Hashable] | None = None,
266+
dim: str | Iterable[Hashable] | None = None,
267267
skipna: bool | None = None,
268268
) -> DataArray:
269269
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
@@ -273,7 +273,7 @@ def _weighted_sum(
273273
def _weighted_mean(
274274
self,
275275
da: DataArray,
276-
dim: Hashable | Iterable[Hashable] | None = None,
276+
dim: str | Iterable[Hashable] | None = None,
277277
skipna: bool | None = None,
278278
) -> DataArray:
279279
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
@@ -287,7 +287,7 @@ def _weighted_mean(
287287
def _weighted_var(
288288
self,
289289
da: DataArray,
290-
dim: Hashable | Iterable[Hashable] | None = None,
290+
dim: str | Iterable[Hashable] | None = None,
291291
skipna: bool | None = None,
292292
) -> DataArray:
293293
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
@@ -301,7 +301,7 @@ def _weighted_var(
301301
def _weighted_std(
302302
self,
303303
da: DataArray,
304-
dim: Hashable | Iterable[Hashable] | None = None,
304+
dim: str | Iterable[Hashable] | None = None,
305305
skipna: bool | None = None,
306306
) -> DataArray:
307307
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""

xarray/tests/test_computation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,7 +1732,7 @@ def apply_truncate_x_x_valid(obj):
17321732

17331733

17341734
@pytest.mark.parametrize("use_dask", [True, False])
1735-
def test_dot(use_dask) -> None:
1735+
def test_dot(use_dask: bool) -> None:
17361736
if use_dask:
17371737
if not has_dask:
17381738
pytest.skip("test for dask.")
@@ -1862,7 +1862,7 @@ def test_dot(use_dask) -> None:
18621862

18631863

18641864
@pytest.mark.parametrize("use_dask", [True, False])
1865-
def test_dot_align_coords(use_dask) -> None:
1865+
def test_dot_align_coords(use_dask: bool) -> None:
18661866
# GH 3694
18671867

18681868
if use_dask:

xarray/tests/test_dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6420,7 +6420,7 @@ def test_deepcopy_obj_array() -> None:
64206420
assert x0.values[0] is not x1.values[0]
64216421

64226422

6423-
def test_clip(da) -> None:
6423+
def test_clip(da: DataArray) -> None:
64246424
with raise_if_dask_computes():
64256425
result = da.clip(min=0.5)
64266426
assert result.min(...) >= 0.5

0 commit comments

Comments
 (0)