Skip to content

Commit d88c363

Browse files
committed
TYP: fix typing errors in numpy._aliases
1 parent d8c5a33 commit d88c363

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

array_api_compat/numpy/_aliases.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
# pyright: reportPrivateUsage=false
12
from __future__ import annotations
23

3-
from typing import Optional, Union
4+
from builtins import bool as py_bool
5+
from typing import TYPE_CHECKING, cast
6+
7+
import numpy as np
48

59
from .._internal import get_xp
610
from ..common import _aliases
711
from ..common._typing import NestedSequence, SupportsBufferProtocol
812
from ._info import __array_namespace_info__
913
from ._typing import Array, Device, DType
1014

11-
import numpy as np
15+
if TYPE_CHECKING:
16+
from typing import Any, Literal, TypeAlias
17+
18+
from typing_extensions import Buffer, TypeIs
19+
20+
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
1221

1322
bool = np.bool_
1423

@@ -63,9 +72,9 @@
6372
sign = get_xp(np)(_aliases.sign)
6473

6574

66-
def _supports_buffer_protocol(obj):
75+
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
6776
try:
68-
memoryview(obj)
77+
memoryview(obj) # pyright: ignore[reportArgumentType]
6978
except TypeError:
7079
return False
7180
return True
@@ -76,15 +85,13 @@ def _supports_buffer_protocol(obj):
7685
# complicated enough that it's easier to define it separately for each module
7786
# rather than trying to combine everything into one function in common/
7887
def asarray(
79-
obj: (
80-
Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol
81-
),
88+
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
8289
/,
8390
*,
84-
dtype: Optional[DType] = None,
85-
device: Optional[Device] = None,
86-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
87-
**kwargs,
91+
dtype: DType | None = None,
92+
device: Device | None = None,
93+
copy: _Copy | None = None,
94+
**kwargs: Any,
8895
) -> Array:
8996
"""
9097
Array API compatibility wrapper for asarray().
@@ -108,24 +115,28 @@ def asarray(
108115
if copy is False:
109116
raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
110117

111-
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
118+
return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore
112119

113120

114121
def astype(
115122
x: Array,
116123
dtype: DType,
117124
/,
118125
*,
119-
copy: bool = True,
120-
device: Optional[Device] = None,
126+
copy: py_bool = True,
127+
device: Device | None = None,
121128
) -> Array:
122129
return x.astype(dtype=dtype, copy=copy)
123130

124131

125132
# count_nonzero returns a python int for axis=None and keepdims=False
126133
# https://github.com/numpy/numpy/issues/17562
127-
def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
128-
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
134+
def count_nonzero(
135+
x: Array,
136+
axis: int | tuple[int, ...] | None = None,
137+
keepdims: py_bool = False,
138+
) -> Array:
139+
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore
129140
if axis is None and not keepdims:
130141
return np.asarray(result)
131142
return result
@@ -148,10 +159,25 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
148159
else:
149160
unstack = get_xp(np)(_aliases.unstack)
150161

151-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
152-
'acos', 'acosh', 'asin', 'asinh', 'atan',
153-
'atan2', 'atanh', 'bitwise_left_shift',
154-
'bitwise_invert', 'bitwise_right_shift',
155-
'bool', 'concat', 'count_nonzero', 'pow']
162+
__all__ = [
163+
"__array_namespace_info__",
164+
"asarray",
165+
"astype",
166+
"acos",
167+
"acosh",
168+
"asin",
169+
"asinh",
170+
"atan",
171+
"atan2",
172+
"atanh",
173+
"bitwise_left_shift",
174+
"bitwise_invert",
175+
"bitwise_right_shift",
176+
"bool",
177+
"concat",
178+
"count_nonzero",
179+
"pow",
180+
]
181+
__all__ += _aliases.__all__
156182

157183
_all_ignore = ['np', 'get_xp']

0 commit comments

Comments
 (0)