1
+ # pyright: reportPrivateUsage=false
1
2
from __future__ import annotations
2
3
3
- from typing import Optional , Union
4
+ from builtins import bool as py_bool
5
+ from typing import TYPE_CHECKING , cast
6
+
7
+ import numpy as np
4
8
5
9
from .._internal import get_xp
6
10
from ..common import _aliases
7
11
from ..common ._typing import NestedSequence , SupportsBufferProtocol
8
12
from ._info import __array_namespace_info__
9
13
from ._typing import Array , Device , DType
10
14
11
- import numpy as np
15
+ if TYPE_CHECKING :
16
+ from typing import Any , Literal , TypeAlias
17
+
18
+ from typing_extensions import Buffer , TypeIs
19
+
20
+ _Copy : TypeAlias = py_bool | Literal [2 ] | np ._CopyMode
12
21
13
22
bool = np .bool_
14
23
63
72
sign = get_xp (np )(_aliases .sign )
64
73
65
74
66
- def _supports_buffer_protocol (obj ):
75
+ def _supports_buffer_protocol (obj : object ) -> TypeIs [ Buffer ]: # pyright: ignore[reportUnusedFunction]
67
76
try :
68
- memoryview (obj )
77
+ memoryview (obj ) # pyright: ignore[reportArgumentType]
69
78
except TypeError :
70
79
return False
71
80
return True
@@ -76,15 +85,13 @@ def _supports_buffer_protocol(obj):
76
85
# complicated enough that it's easier to define it separately for each module
77
86
# rather than trying to combine everything into one function in common/
78
87
def asarray (
79
- obj : (
80
- Array | bool | complex | NestedSequence [bool | complex ] | SupportsBufferProtocol
81
- ),
88
+ obj : Array | complex | NestedSequence [complex ] | SupportsBufferProtocol ,
82
89
/ ,
83
90
* ,
84
- dtype : Optional [ DType ] = None ,
85
- device : Optional [ Device ] = None ,
86
- copy : "Optional[Union[bool, np._CopyMode]]" = None ,
87
- ** kwargs ,
91
+ dtype : DType | None = None ,
92
+ device : Device | None = None ,
93
+ copy : _Copy | None = None ,
94
+ ** kwargs : Any ,
88
95
) -> Array :
89
96
"""
90
97
Array API compatibility wrapper for asarray().
@@ -108,24 +115,28 @@ def asarray(
108
115
if copy is False :
109
116
raise NotImplementedError ("asarray(copy=False) requires a newer version of NumPy." )
110
117
111
- return np .array (obj , copy = copy , dtype = dtype , ** kwargs )
118
+ return np .array (obj , copy = copy , dtype = dtype , ** kwargs ) # pyright: ignore
112
119
113
120
114
121
def astype (
115
122
x : Array ,
116
123
dtype : DType ,
117
124
/ ,
118
125
* ,
119
- copy : bool = True ,
120
- device : Optional [ Device ] = None ,
126
+ copy : py_bool = True ,
127
+ device : Device | None = None ,
121
128
) -> Array :
122
129
return x .astype (dtype = dtype , copy = copy )
123
130
124
131
125
132
# count_nonzero returns a python int for axis=None and keepdims=False
126
133
# https://github.com/numpy/numpy/issues/17562
127
- def count_nonzero (x : Array , axis = None , keepdims = False ) -> Array :
128
- result = np .count_nonzero (x , axis = axis , keepdims = keepdims )
134
+ def count_nonzero (
135
+ x : Array ,
136
+ axis : int | tuple [int , ...] | None = None ,
137
+ keepdims : py_bool = False ,
138
+ ) -> Array :
139
+ result = cast ("Any" , np .count_nonzero (x , axis = axis , keepdims = keepdims )) # pyright: ignore
129
140
if axis is None and not keepdims :
130
141
return np .asarray (result )
131
142
return result
@@ -148,10 +159,25 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
148
159
else :
149
160
unstack = get_xp (np )(_aliases .unstack )
150
161
151
- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype' ,
152
- 'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
153
- 'atan2' , 'atanh' , 'bitwise_left_shift' ,
154
- 'bitwise_invert' , 'bitwise_right_shift' ,
155
- 'bool' , 'concat' , 'count_nonzero' , 'pow' ]
162
+ __all__ = [
163
+ "__array_namespace_info__" ,
164
+ "asarray" ,
165
+ "astype" ,
166
+ "acos" ,
167
+ "acosh" ,
168
+ "asin" ,
169
+ "asinh" ,
170
+ "atan" ,
171
+ "atan2" ,
172
+ "atanh" ,
173
+ "bitwise_left_shift" ,
174
+ "bitwise_invert" ,
175
+ "bitwise_right_shift" ,
176
+ "bool" ,
177
+ "concat" ,
178
+ "count_nonzero" ,
179
+ "pow" ,
180
+ ]
181
+ __all__ += _aliases .__all__
156
182
157
183
_all_ignore = ['np' , 'get_xp' ]
0 commit comments