Commit 5fd50c5

update mypy to 1.13 (#9687)
* update mypy to 1.13
* workaround netcdf4 typing issue
* fix further mypy issues
* fix nested paths and add real NestedSequence type
* fix some mypy errors that only appear locally
* fix infinite recursion (damn str is sequence)
1 parent 7467b1e commit 5fd50c5

File tree: 9 files changed, +116 -49 lines changed

.github/workflows/ci-additional.yaml

Lines changed: 2 additions & 2 deletions
@@ -122,7 +122,7 @@ jobs:
           python xarray/util/print_versions.py
       - name: Install mypy
         run: |
-          python -m pip install "mypy==1.11.2" --force-reinstall
+          python -m pip install "mypy==1.13" --force-reinstall

       - name: Run mypy
         run: |
@@ -176,7 +176,7 @@ jobs:
           python xarray/util/print_versions.py
       - name: Install mypy
         run: |
-          python -m pip install "mypy==1.11.2" --force-reinstall
+          python -m pip install "mypy==1.13" --force-reinstall

       - name: Run mypy
         run: |

xarray/backends/api.py

Lines changed: 7 additions & 4 deletions
@@ -1338,7 +1338,7 @@ def open_groups(


 def open_mfdataset(
-    paths: str | NestedSequence[str | os.PathLike],
+    paths: str | os.PathLike | NestedSequence[str | os.PathLike],
     chunks: T_Chunks | None = None,
     concat_dim: (
         str
@@ -1541,6 +1541,7 @@ def open_mfdataset(
     if not paths:
         raise OSError("no files to open")

+    paths1d: list[str]
     if combine == "nested":
         if isinstance(concat_dim, str | DataArray) or concat_dim is None:
             concat_dim = [concat_dim]  # type: ignore[assignment]
@@ -1549,7 +1550,7 @@ def open_mfdataset(
         # encoding the originally-supplied structure as "ids".
         # The "ids" are not used at all if combine='by_coords`.
         combined_ids_paths = _infer_concat_order_from_positions(paths)
-        ids, paths = (
+        ids, paths1d = (
             list(combined_ids_paths.keys()),
             list(combined_ids_paths.values()),
         )
@@ -1559,6 +1560,8 @@ def open_mfdataset(
             "effect. To manually combine along a specific dimension you should "
             "instead specify combine='nested' along with a value for `concat_dim`.",
         )
+    else:
+        paths1d = paths  # type: ignore[assignment]

     open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs)

@@ -1574,7 +1577,7 @@ def open_mfdataset(
         open_ = open_dataset
         getattr_ = getattr

-    datasets = [open_(p, **open_kwargs) for p in paths]
+    datasets = [open_(p, **open_kwargs) for p in paths1d]
     closers = [getattr_(ds, "_close") for ds in datasets]
     if preprocess is not None:
         datasets = [preprocess(ds) for ds in datasets]
@@ -1626,7 +1629,7 @@ def open_mfdataset(
     if attrs_file is not None:
         if isinstance(attrs_file, os.PathLike):
             attrs_file = cast(str, os.fspath(attrs_file))
-        combined.attrs = datasets[paths.index(attrs_file)].attrs
+        combined.attrs = datasets[paths1d.index(attrs_file)].attrs

     return combined
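The new `paths1d` name exists so mypy can track the flattened `list[str]` separately from the possibly-nested `paths` argument. A minimal sketch of the pattern (function name hypothetical, not xarray API):

from collections.abc import Sequence


def flatten_paths(paths: str | Sequence[str]) -> list[str]:
    # Declare the flattened type once, then assign in every branch; mypy
    # checks each assignment against list[str] instead of re-using the
    # (possibly nested) type of the input.
    paths1d: list[str]
    if isinstance(paths, str):
        paths1d = [paths]
    else:
        paths1d = list(paths)
    return paths1d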

xarray/backends/common.py

Lines changed: 43 additions & 11 deletions
@@ -4,9 +4,9 @@
 import os
 import time
 import traceback
-from collections.abc import Iterable, Mapping
+from collections.abc import Iterable, Mapping, Sequence
 from glob import glob
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast, overload

 import numpy as np

@@ -29,8 +29,18 @@

 NONE_VAR_NAME = "__values__"

+T = TypeVar("T")

-def _normalize_path(path):
+
+@overload
+def _normalize_path(path: str | os.PathLike) -> str: ...
+
+
+@overload
+def _normalize_path(path: T) -> T: ...
+
+
+def _normalize_path(path: str | os.PathLike | T) -> str | T:
     """
     Normalize pathlikes to string.
@@ -55,12 +65,24 @@ def _normalize_path(path):
     if isinstance(path, str) and not is_remote_uri(path):
         path = os.path.abspath(os.path.expanduser(path))

-    return path
+    return cast(str, path)
+
+
+@overload
+def _find_absolute_paths(
+    paths: str | os.PathLike | Sequence[str | os.PathLike], **kwargs
+) -> list[str]: ...
+
+
+@overload
+def _find_absolute_paths(
+    paths: NestedSequence[str | os.PathLike], **kwargs
+) -> NestedSequence[str]: ...


 def _find_absolute_paths(
     paths: str | os.PathLike | NestedSequence[str | os.PathLike], **kwargs
-) -> list[str]:
+) -> NestedSequence[str]:
     """
     Find absolute paths from the pattern.

@@ -99,21 +121,31 @@ def _find_absolute_paths(
                 expand=False,
             )
             tmp_paths = fs.glob(fs._strip_protocol(paths))  # finds directories
-            paths = [fs.get_mapper(path) for path in tmp_paths]
+            return [fs.get_mapper(path) for path in tmp_paths]
         elif is_remote_uri(paths):
             raise ValueError(
                 "cannot do wild-card matching for paths that are remote URLs "
                 f"unless engine='zarr' is specified. Got paths: {paths}. "
                 "Instead, supply paths as an explicit list of strings."
             )
         else:
-            paths = sorted(glob(_normalize_path(paths)))
+            return sorted(glob(_normalize_path(paths)))
     elif isinstance(paths, os.PathLike):
-        paths = [os.fspath(paths)]
-    else:
-        paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in paths]
+        return [_normalize_path(paths)]
+
+    def _normalize_path_list(
+        lpaths: NestedSequence[str | os.PathLike],
+    ) -> NestedSequence[str]:
+        return [
+            (
+                _normalize_path(p)
+                if isinstance(p, str | os.PathLike)
+                else _normalize_path_list(p)
+            )
+            for p in lpaths
+        ]

-    return paths
+    return _normalize_path_list(paths)


 def _encode_variable_name(name):
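The paired `@overload`s above give callers precise types: path-likes normalize to `str`, while any other input (such as an fsspec mapper) passes through as its own type. A self-contained sketch of the same pattern, with plain `os.fspath` standing in for xarray's fuller normalization:

import os
from pathlib import Path
from typing import TypeVar, overload

T = TypeVar("T")


@overload
def normalize(path: str | os.PathLike) -> str: ...
@overload
def normalize(path: T) -> T: ...
def normalize(path: str | os.PathLike | T) -> str | T:
    # Path-likes become plain strings; any other object passes through.
    if isinstance(path, str | os.PathLike):
        return os.fspath(path)
    return path


assert normalize(Path("data") / "a.nc") == os.path.join("data", "a.nc")
assert normalize(42) == 42  # non-path input is returned unchanged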

xarray/backends/netCDF4_.py

Lines changed: 1 addition & 0 deletions
@@ -550,6 +550,7 @@ def prepare_variable(
         _ensure_no_forward_slash_in_name(name)
         attrs = variable.attrs.copy()
         fill_value = attrs.pop("_FillValue", None)
+        datatype: np.dtype | ncEnumType | h5EnumType
         datatype = _get_datatype(
             variable, self.format, raise_on_invalid_encoding=check_encoding
         )
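The added annotation pre-declares `datatype` as a union so mypy does not pin the variable to the first type assigned to it. The trick in miniature:

# Annotating a variable before its first assignment widens the type mypy
# tracks for it, so later assignments to other union members are accepted.
value: int | str
value = 0        # ok
value = "zero"   # ok; without the annotation mypy would infer plain int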

xarray/core/combine.py

Lines changed: 19 additions & 7 deletions
@@ -2,8 +2,8 @@

 import itertools
 from collections import Counter
-from collections.abc import Iterable, Sequence
-from typing import TYPE_CHECKING, Literal, Union
+from collections.abc import Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast

 import pandas as pd

@@ -15,14 +15,26 @@
 from xarray.core.utils import iterate_nested

 if TYPE_CHECKING:
-    from xarray.core.types import CombineAttrsOptions, CompatOptions, JoinOptions
+    from xarray.core.types import (
+        CombineAttrsOptions,
+        CompatOptions,
+        JoinOptions,
+        NestedSequence,
+    )
+
+
+T = TypeVar("T")


-def _infer_concat_order_from_positions(datasets):
+def _infer_concat_order_from_positions(
+    datasets: NestedSequence[T],
+) -> dict[tuple[int, ...], T]:
     return dict(_infer_tile_ids_from_nested_list(datasets, ()))


-def _infer_tile_ids_from_nested_list(entry, current_pos):
+def _infer_tile_ids_from_nested_list(
+    entry: NestedSequence[T], current_pos: tuple[int, ...]
+) -> Iterator[tuple[tuple[int, ...], T]]:
     """
     Given a list of lists (of lists...) of objects, returns a iterator
     which returns a tuple containing the index of each object in the nested
@@ -44,11 +56,11 @@ def _infer_tile_ids_from_nested_list(entry, current_pos):
     combined_tile_ids : dict[tuple(int, ...), obj]
     """

-    if isinstance(entry, list):
+    if not isinstance(entry, str) and isinstance(entry, Sequence):
         for i, item in enumerate(entry):
             yield from _infer_tile_ids_from_nested_list(item, current_pos + (i,))
     else:
-        yield current_pos, entry
+        yield current_pos, cast(T, entry)


 def _ensure_same_types(series, dim):
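Widening the check from `isinstance(entry, list)` to any `Sequence` is what makes the `str` guard essential: a string is itself a `Sequence` whose elements are length-1 strings, so recursing into one never reaches a leaf. A small demonstration of the pitfall and the fix:

from collections.abc import Sequence


def flatten(entry):
    # The str guard is load-bearing: "a"[0] == "a", so a bare Sequence
    # check would recurse on strings forever.
    if not isinstance(entry, str) and isinstance(entry, Sequence):
        for item in entry:
            yield from flatten(item)
    else:
        yield entry


assert list(flatten([["a.nc", "b.nc"], ["c.nc"]])) == ["a.nc", "b.nc", "c.nc"]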

xarray/core/datatree_io.py

Lines changed: 25 additions & 13 deletions
@@ -2,36 +2,48 @@

 from collections.abc import Mapping, MutableMapping
 from os import PathLike
-from typing import Any, Literal, get_args
+from typing import TYPE_CHECKING, Any, Literal, get_args

 from xarray.core.datatree import DataTree
 from xarray.core.types import NetcdfWriteModes, ZarrWriteModes

+if TYPE_CHECKING:
+    from h5netcdf.legacyapi import Dataset as h5Dataset
+    from netCDF4 import Dataset as ncDataset
+
 T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
 T_DataTreeNetcdfTypes = Literal["NETCDF4"]


-def _get_nc_dataset_class(engine: T_DataTreeNetcdfEngine | None):
+def _get_nc_dataset_class(
+    engine: T_DataTreeNetcdfEngine | None,
+) -> type[ncDataset] | type[h5Dataset]:
     if engine == "netcdf4":
-        from netCDF4 import Dataset
-    elif engine == "h5netcdf":
-        from h5netcdf.legacyapi import Dataset
-    elif engine is None:
+        from netCDF4 import Dataset as ncDataset
+
+        return ncDataset
+    if engine == "h5netcdf":
+        from h5netcdf.legacyapi import Dataset as h5Dataset
+
+        return h5Dataset
+    if engine is None:
         try:
-            from netCDF4 import Dataset
+            from netCDF4 import Dataset as ncDataset
+
+            return ncDataset
         except ImportError:
-            from h5netcdf.legacyapi import Dataset
-    else:
-        raise ValueError(f"unsupported engine: {engine}")
-    return Dataset
+            from h5netcdf.legacyapi import Dataset as h5Dataset
+
+            return h5Dataset
+    raise ValueError(f"unsupported engine: {engine}")


 def _create_empty_netcdf_group(
     filename: str | PathLike,
     group: str,
     mode: NetcdfWriteModes,
     engine: T_DataTreeNetcdfEngine | None,
-):
+) -> None:
     ncDataset = _get_nc_dataset_class(engine)

     with ncDataset(filename, mode=mode) as rootgrp:
@@ -49,7 +61,7 @@ def _datatree_to_netcdf(
     group: str | None = None,
     compute: bool = True,
     **kwargs,
-):
+) -> None:
     """This function creates an appropriate datastore for writing a datatree to
     disk as a netCDF file.
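Rebinding two different classes to a single `Dataset` name forces mypy to unify their types, while returning directly from each branch keeps them distinct; the `if TYPE_CHECKING` imports make both names available to the return annotation without importing either backend at runtime. The same shape sketched with stdlib classes standing in for the netCDF4/h5netcdf backends (purely for illustration):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Available to annotations only; no import cost at runtime.
    from decimal import Decimal
    from fractions import Fraction


def _number_class(kind: str | None) -> type[Decimal] | type[Fraction]:
    # Each branch imports and returns its own class, so mypy sees a
    # precise type per branch instead of one rebound name.
    if kind == "decimal":
        from decimal import Decimal

        return Decimal
    if kind == "fraction" or kind is None:
        from fractions import Fraction

        return Fraction
    raise ValueError(f"unsupported kind: {kind}")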

xarray/core/types.py

Lines changed: 13 additions & 9 deletions
@@ -12,6 +12,7 @@
     SupportsIndex,
     TypeVar,
     Union,
+    overload,
 )

 import numpy as np
@@ -285,15 +286,18 @@ def copy(
 AspectOptions = Union[Literal["auto", "equal"], float, None]
 ExtendOptions = Literal["neither", "both", "min", "max", None]

-# TODO: Wait until mypy supports recursive objects in combination with typevars
-_T = TypeVar("_T")
-NestedSequence = Union[
-    _T,
-    Sequence[_T],
-    Sequence[Sequence[_T]],
-    Sequence[Sequence[Sequence[_T]]],
-    Sequence[Sequence[Sequence[Sequence[_T]]]],
-]
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class NestedSequence(Protocol[_T_co]):
+    def __len__(self, /) -> int: ...
+    @overload
+    def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]: ...
+    @overload
+    def __getitem__(self, index: slice, /) -> NestedSequence[_T_co]: ...
+    def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...
+    def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]: ...


 QuantileMethods = Literal[
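`NestedSequence` is now a genuinely recursive structural `Protocol` rather than a `Union` capped at four levels of nesting, so arbitrarily deep lists of lists match. Note that `str` also satisfies the protocol structurally, which is why call sites such as `_infer_tile_ids_from_nested_list` keep an explicit `str` check. A reduced sketch (protocol and function names hypothetical):

from __future__ import annotations

from collections.abc import Iterator
from typing import Protocol, TypeVar, overload

_T_co = TypeVar("_T_co", covariant=True)


class Nested(Protocol[_T_co]):
    # Matches any sequence whose items are leaves or further nesting.
    def __len__(self, /) -> int: ...
    @overload
    def __getitem__(self, index: int, /) -> _T_co | Nested[_T_co]: ...
    @overload
    def __getitem__(self, index: slice, /) -> Nested[_T_co]: ...
    def __iter__(self, /) -> Iterator[_T_co | Nested[_T_co]]: ...


def count_leaves(entry: str | Nested[str]) -> int:
    if isinstance(entry, str):  # str matches the protocol, so guard it
        return 1
    return sum(count_leaves(item) for item in entry)


assert count_leaves([["a", "b"], ["c", ["d"]]]) == 4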

xarray/plot/facetgrid.py

Lines changed: 3 additions & 0 deletions
@@ -119,6 +119,7 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
     col_labels: list[Annotation | None]
     _x_var: None
     _y_var: None
+    _hue_var: DataArray | None
     _cmap_extend: Any | None
     _mappables: list[ScalarMappable]
     _finalized: bool
@@ -271,6 +272,7 @@ def __init__(
         self.col_labels = [None] * ncol
         self._x_var = None
         self._y_var = None
+        self._hue_var = None
         self._cmap_extend = None
         self._mappables = []
         self._finalized = False
@@ -720,6 +722,7 @@ def add_legend(
         if use_legend_elements:
             self.figlegend = _add_legend(**kwargs)
         else:
+            assert self._hue_var is not None
             self.figlegend = self.fig.legend(
                 handles=self._mappables[-1],
                 labels=list(self._hue_var.to_numpy()),
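With `_hue_var` annotated as `DataArray | None`, the added `assert` is the narrowing step: past it, mypy treats the attribute as a plain `DataArray`. The idiom in isolation:

class Legend:
    label: str | None = None

    def render(self) -> str:
        # The assert narrows str | None to str for the lines below.
        assert self.label is not None
        return self.label.upper()


leg = Legend()
leg.label = "t2m"
assert leg.render() == "T2M"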

xarray/tests/test_backends.py

Lines changed: 3 additions & 3 deletions
@@ -1879,7 +1879,7 @@ def test_encoding_enum__no_fill_value(self, recwarn):
         cloud_type_dict = {"clear": 0, "cloudy": 1}
         with nc4.Dataset(tmp_file, mode="w") as nc:
             nc.createDimension("time", size=2)
-            cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+            cloud_type = nc.createEnumType(np.uint8, "cloud_type", cloud_type_dict)
             v = nc.createVariable(
                 "clouds",
                 cloud_type,
@@ -1926,7 +1926,7 @@ def test_encoding_enum__multiple_variable_with_enum(self):
         cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
         with nc4.Dataset(tmp_file, mode="w") as nc:
             nc.createDimension("time", size=2)
-            cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+            cloud_type = nc.createEnumType(np.uint8, "cloud_type", cloud_type_dict)
             nc.createVariable(
                 "clouds",
                 cloud_type,
@@ -1975,7 +1975,7 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
         cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
         with nc4.Dataset(tmp_file, mode="w") as nc:
             nc.createDimension("time", size=2)
-            cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+            cloud_type = nc.createEnumType(np.uint8, "cloud_type", cloud_type_dict)
             nc.createVariable(
                 "clouds",
                 cloud_type,
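Passing `np.uint8` instead of the string code `"u1"` does not change behavior, since both name the same dtype; the numpy scalar type simply carries static type information that netCDF4's stubs accept (presumably the "netcdf4 typing issue" the commit message mentions):

import numpy as np

# "u1" and np.uint8 denote the identical unsigned 8-bit dtype.
assert np.dtype("u1") == np.dtype(np.uint8)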
