Skip to content

Commit 13d4a70

Browse files
authored
Merge pull request #329 from crusaderky/mypy
TYP: replace basedmypy with mypy
2 parents a127376 + 9375f63 commit 13d4a70

File tree

16 files changed

+359
-107
lines changed

16 files changed

+359
-107
lines changed

pixi.lock

Lines changed: 308 additions & 54 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ array-api-extra = { path = ".", editable = true }
5757
typing-extensions = ">=4.13.2"
5858
pre-commit = ">=4.2.0"
5959
pylint = ">=3.3.7"
60-
basedmypy = ">=2.10.0"
60+
mypy = ">=1.16.0"
6161
basedpyright = ">=1.29.2"
6262
numpydoc = ">=1.8.0,<2"
6363
# import dependencies for mypy:
@@ -227,16 +227,17 @@ python_version = "3.10"
227227
warn_unused_configs = true
228228
strict = true
229229
enable_error_code = ["ignore-without-code", "truthy-bool"]
230-
# https://github.com/data-apis/array-api-typing
231-
disallow_any_expr = false
232-
# false positives with input validation
233-
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
230+
disable_error_code = ["no-any-return"]
234231

235232
[[tool.mypy.overrides]]
236233
# slow or unavailable on Windows; do not add to the lint env
237234
module = ["cupy.*", "jax.*", "sparse.*", "torch.*"]
238235
ignore_missing_imports = true
239236

237+
[[tool.mypy.overrides]]
238+
module = ["tests/*"]
239+
disable_error_code = ["no-untyped-def"] # test(...) without -> None
240+
240241
# pyright
241242

242243
[tool.basedpyright]

src/array_api_extra/_lib/_at.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _AtOp(Enum):
3737
MAX = "max"
3838

3939
# @override from Python 3.12
40-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride]
40+
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
4141
"""
4242
Return string representation (useful for pytest logs).
4343

src/array_api_extra/_lib/_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Backend(Enum): # numpydoc ignore=PR02
3030
JAX = "jax.numpy"
3131
JAX_GPU = "jax.numpy:gpu"
3232

33-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
33+
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
3434
"""Pretty-print parameterized test names."""
3535
return (
3636
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")

src/array_api_extra/_lib/_funcs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
@overload
37-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
37+
def apply_where( # numpydoc ignore=GL08
3838
cond: Array,
3939
args: Array | tuple[Array, ...],
4040
f1: Callable[..., Array],
@@ -46,7 +46,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
4646

4747

4848
@overload
49-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
49+
def apply_where( # numpydoc ignore=GL08
5050
cond: Array,
5151
args: Array | tuple[Array, ...],
5252
f1: Callable[..., Array],
@@ -57,7 +57,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
5757
) -> Array: ...
5858

5959

60-
def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
60+
def apply_where( # numpydoc ignore=PR01,PR02
6161
cond: Array,
6262
args: Array | tuple[Array, ...],
6363
f1: Callable[..., Array],
@@ -143,7 +143,7 @@ def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
143143
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
144144

145145

146-
def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
146+
def _apply_where( # numpydoc ignore=PR01,RT01
147147
cond: Array,
148148
f1: Callable[..., Array],
149149
f2: Callable[..., Array] | None,
@@ -813,8 +813,7 @@ def pad(
813813
else:
814814
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
815815

816-
# https://github.com/python/typeshed/issues/13376
817-
slices: list[slice] = [] # type: ignore[explicit-any]
816+
slices: list[slice] = []
818817
newshape: list[int] = []
819818
for ax, w_tpl in enumerate(pad_width_seq):
820819
if len(w_tpl) != 2:
@@ -826,6 +825,7 @@ def pad(
826825
if w_tpl[0] == 0 and w_tpl[1] == 0:
827826
sl = slice(None, None, None)
828827
else:
828+
stop: int | None
829829
start, stop = w_tpl
830830
stop = None if stop == 0 else -stop
831831

src/array_api_extra/_lib/_lazy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from numpy.typing import ArrayLike
2424

25-
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[explicit-any]
25+
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic
2626
else:
2727
# Sphinx hack
2828
NumPyObject = Any
@@ -31,7 +31,7 @@
3131

3232

3333
@overload
34-
def lazy_apply( # type: ignore[decorated-any, valid-type]
34+
def lazy_apply( # type: ignore[valid-type]
3535
func: Callable[P, Array | ArrayLike],
3636
*args: Array | complex | None,
3737
shape: tuple[int | None, ...] | None = None,
@@ -43,7 +43,7 @@ def lazy_apply( # type: ignore[decorated-any, valid-type]
4343

4444

4545
@overload
46-
def lazy_apply( # type: ignore[decorated-any, valid-type]
46+
def lazy_apply( # type: ignore[valid-type]
4747
func: Callable[P, Sequence[Array | ArrayLike]],
4848
*args: Array | complex | None,
4949
shape: Sequence[tuple[int | None, ...]],
@@ -313,7 +313,7 @@ def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
313313
return True
314314

315315

316-
def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
316+
def _lazy_apply_wrapper( # numpydoc ignore=PR01,RT01
317317
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
318318
as_numpy: bool,
319319
multi_output: bool,
@@ -331,7 +331,7 @@ def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,R
331331

332332
# On Dask, @wraps causes the graph key to contain the wrapped function's name
333333
@wraps(func)
334-
def wrapper( # type: ignore[decorated-any,explicit-any]
334+
def wrapper(
335335
*args: Array | complex | None, **kwargs: Any
336336
) -> tuple[Array, ...]: # numpydoc ignore=GL08
337337
args_list = []
@@ -343,7 +343,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
343343
if as_numpy:
344344
import numpy as np
345345

346-
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901
346+
arg = cast(Array, np.asarray(arg)) # noqa: PLW2901
347347
args_list.append(arg)
348348
assert device is not None
349349

src/array_api_extra/_lib/_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _is_materializable(x: Array) -> bool:
110110
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111111

112112

113-
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
113+
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
114114
"""
115115
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
116116
"""

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
3636
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
3737
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
3838
def size(x: Array, /) -> int | None: ...
39-
def to_device( # type: ignore[explicit-any]
39+
def to_device(
4040
x: Array,
4141
device: Device, # pylint: disable=redefined-outer-name
4242
/,

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def asarrays(
210210
float: ("real floating", "complex floating"),
211211
complex: "complex floating",
212212
}
213-
kind = same_dtype[type(cast(complex, b))] # type: ignore[index]
213+
kind = same_dtype[type(cast(complex, b))]
214214
if xp.isdtype(a.dtype, kind):
215215
xb = xp.asarray(b, dtype=a.dtype)
216216
else:
@@ -458,7 +458,7 @@ def persistent_id(
458458
return instances, (f.getvalue(), *rest)
459459

460460

461-
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
461+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:
462462
"""
463463
Reverse of ``pickle_flatten``.
464464
@@ -521,7 +521,7 @@ def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
521521
self.obj = obj
522522

523523
@classmethod
524-
def _register(cls): # numpydoc ignore=SS06
524+
def _register(cls) -> None: # numpydoc ignore=SS06
525525
"""
526526
Register upon first use instead of at import time, to avoid
527527
globally importing JAX.
@@ -583,7 +583,7 @@ def f(x: Array, y: float, plus: bool) -> Array:
583583
import jax
584584

585585
@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
586-
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
586+
def inner( # numpydoc ignore=GL08
587587
wargs: _AutoJITWrapper[Any],
588588
) -> _AutoJITWrapper[T]:
589589
args, kwargs = wargs.obj

src/array_api_extra/_lib/_utils/_typing.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ class DType(Protocol): # pylint: disable=missing-class-docstring
9595
class Device(Protocol): # pylint: disable=missing-class-docstring
9696
pass
9797

98-
SetIndex: TypeAlias = ( # type: ignore[explicit-any]
98+
SetIndex: TypeAlias = (
9999
int | slice | EllipsisType | Array | tuple[int | slice | EllipsisType | Array, ...]
100100
)
101-
GetIndex: TypeAlias = ( # type: ignore[explicit-any]
101+
GetIndex: TypeAlias = (
102102
SetIndex | None | tuple[int | slice | EllipsisType | None | Array, ...]
103103
)
104104

0 commit comments

Comments
 (0)