Skip to content

Commit abafd27

Browse files
committed
Add creation function wrappers with the device keyword
Also add _typing.py with type annotation definitions
1 parent c037d2b commit abafd27

File tree

2 files changed

+194
-3
lines changed

2 files changed

+194
-3
lines changed

numpy_array_api_compat/_aliases.py

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
99
from typing import Optional, Tuple, Union
10-
from numpy import ndarray, dtype
10+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1111

1212
from typing import NamedTuple
1313

@@ -107,7 +107,7 @@ def unique_values(x: ndarray, /) -> ndarray:
107107
equal_nan=False,
108108
)
109109

110-
def astype(x: ndarray, dtype: dtype, /, *, copy: bool = True) -> ndarray:
110+
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
111111
if not copy and dtype == x.dtype:
112112
return x
113113
return x.astype(dtype=dtype, copy=copy)
@@ -138,6 +138,136 @@ def var(
138138
def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
139139
return np.transpose(x, axes)
140140

141+
# Creation functions add the device keyword (which does nothing for NumPy)
142+
143+
def _check_device(device):
144+
if device not in ["cpu", None]:
145+
raise ValueError(f"Unsupported device {device!r}")
146+
147+
def asarray(
148+
obj: Union[
149+
ndarray,
150+
bool,
151+
int,
152+
float,
153+
NestedSequence[bool | int | float],
154+
SupportsBufferProtocol,
155+
],
156+
/,
157+
*,
158+
dtype: Optional[Dtype] = None,
159+
device: Optional[Device] = None,
160+
copy: Optional[Union[bool, np._CopyMode]] = None,
161+
) -> ndarray:
162+
_check_device(device)
163+
if copy in (False, np._CopyMode.IF_NEEDED):
164+
# copy=False is not yet implemented in np.asarray
165+
raise NotImplementedError("copy=False is not yet implemented")
166+
return np.asarray(obj, dtype=dtype)
167+
168+
def arange(
169+
start: Union[int, float],
170+
/,
171+
stop: Optional[Union[int, float]] = None,
172+
step: Union[int, float] = 1,
173+
*,
174+
dtype: Optional[Dtype] = None,
175+
device: Optional[Device] = None,
176+
) -> ndarray:
177+
_check_device(device)
178+
return np.arange(start, stop=stop, step=step, dtype=dtype)
179+
180+
def empty(
181+
shape: Union[int, Tuple[int, ...]],
182+
*,
183+
dtype: Optional[Dtype] = None,
184+
device: Optional[Device] = None,
185+
) -> ndarray:
186+
_check_device(device)
187+
return np.empty(shape, dtype=dtype)
188+
189+
def empty_like(
190+
x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
191+
) -> ndarray:
192+
_check_device(device)
193+
return np.empty_like(x, dtype=dtype)
194+
195+
def eye(
196+
n_rows: int,
197+
n_cols: Optional[int] = None,
198+
/,
199+
*,
200+
k: int = 0,
201+
dtype: Optional[Dtype] = None,
202+
device: Optional[Device] = None,
203+
) -> ndarray:
204+
_check_device(device)
205+
return np.eye(n_rows, M=n_cols, k=k, dtype=dtype)
206+
207+
def full(
208+
shape: Union[int, Tuple[int, ...]],
209+
fill_value: Union[int, float],
210+
*,
211+
dtype: Optional[Dtype] = None,
212+
device: Optional[Device] = None,
213+
) -> ndarray:
214+
_check_device(device)
215+
return np.full(shape, fill_value, dtype=dtype)
216+
217+
def full_like(
218+
x: ndarray,
219+
/,
220+
fill_value: Union[int, float],
221+
*,
222+
dtype: Optional[Dtype] = None,
223+
device: Optional[Device] = None,
224+
) -> ndarray:
225+
_check_device(device)
226+
return np.full_like(x, fill_value, dtype=dtype)
227+
228+
def linspace(
229+
start: Union[int, float],
230+
stop: Union[int, float],
231+
/,
232+
num: int,
233+
*,
234+
dtype: Optional[Dtype] = None,
235+
device: Optional[Device] = None,
236+
endpoint: bool = True,
237+
) -> ndarray:
238+
_check_device(device)
239+
return np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
240+
241+
def ones(
242+
shape: Union[int, Tuple[int, ...]],
243+
*,
244+
dtype: Optional[Dtype] = None,
245+
device: Optional[Device] = None,
246+
) -> ndarray:
247+
_check_device(device)
248+
return np.ones(shape, dtype=dtype)
249+
250+
def ones_like(
251+
x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
252+
) -> ndarray:
253+
_check_device(device)
254+
return np.ones_like(x, dtype=dtype)
255+
256+
def zeros(
257+
shape: Union[int, Tuple[int, ...]],
258+
*,
259+
dtype: Optional[Dtype] = None,
260+
device: Optional[Device] = None,
261+
) -> ndarray:
262+
_check_device(device)
263+
return np.zeros(shape, dtype=dtype)
264+
265+
def zeros_like(
266+
x: ndarray, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
267+
) -> ndarray:
268+
_check_device(device)
269+
return np.zeros_like(x, dtype=dtype)
270+
141271
# from numpy import * doesn't overwrite these builtin names
142272
from numpy import abs, max, min, round
143273

@@ -146,4 +276,6 @@ def permute_dims(x: ndarray, /, axes: Tuple[int, ...]) -> ndarray:
146276
'bool', 'concat', 'pow', 'UniqueAllResult', 'UniqueCountsResult',
147277
'UniqueInverseResult', 'unique_all', 'unique_counts',
148278
'unique_inverse', 'unique_values', 'astype', 'abs', 'max', 'min',
149-
'round', 'std', 'var', 'permute_dims']
279+
'round', 'std', 'var', 'permute_dims', 'asarray', 'arange',
280+
'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace',
281+
'ones', 'ones_like', 'zeros', 'zeros_like']

numpy_array_api_compat/_typing.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from __future__ import annotations
2+
3+
__all__ = [
4+
"ndarray",
5+
"Device",
6+
"Dtype",
7+
"NestedSequence",
8+
"SupportsBufferProtocol",
9+
]
10+
11+
import sys
12+
from typing import (
13+
Any,
14+
Literal,
15+
Union,
16+
TYPE_CHECKING,
17+
TypeVar,
18+
Protocol,
19+
)
20+
21+
from numpy import (
22+
ndarray,
23+
dtype,
24+
int8,
25+
int16,
26+
int32,
27+
int64,
28+
uint8,
29+
uint16,
30+
uint32,
31+
uint64,
32+
float32,
33+
float64,
34+
)
35+
36+
_T_co = TypeVar("_T_co", covariant=True)
37+
38+
class NestedSequence(Protocol[_T_co]):
39+
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
40+
def __len__(self, /) -> int: ...
41+
42+
Device = Literal["cpu"]
43+
if TYPE_CHECKING or sys.version_info >= (3, 9):
44+
Dtype = dtype[Union[
45+
int8,
46+
int16,
47+
int32,
48+
int64,
49+
uint8,
50+
uint16,
51+
uint32,
52+
uint64,
53+
float32,
54+
float64,
55+
]]
56+
else:
57+
Dtype = dtype
58+
59+
SupportsBufferProtocol = Any

0 commit comments

Comments
 (0)