Skip to content

Commit 076848e

Browse files
committed
Add a cupy submodule
Full support still needs to be tested, and also we need to double check if any of the aliases are unnecessary for cupy.
1 parent 7333696 commit 076848e

File tree

7 files changed

+174
-99
lines changed

7 files changed

+174
-99
lines changed

numpy_array_api_compat/common/_aliases.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
from types import ModuleType
1414

15-
from ..numpy._helpers import _is_numpy_array, get_namespace
15+
from ._helpers import _check_device, _is_numpy_array, get_namespace
1616
from .._internal import get_xp
1717

1818
# Basic renames
@@ -189,10 +189,6 @@ def permute_dims(x: ndarray, /, xp, axes: Tuple[int, ...]) -> ndarray:
189189

190190
# Creation functions add the device keyword (which does nothing for NumPy)
191191

192-
def _check_device(device):
193-
if device not in ["cpu", None]:
194-
raise ValueError(f"Unsupported device {device!r}")
195-
196192
# asarray also adds the copy keyword
197193
def _asarray(
198194
obj: Union[
@@ -232,7 +228,7 @@ def _asarray(
232228
else:
233229
raise ValueError("Unrecognized namespace argument to asarray()")
234230

235-
_check_device(device)
231+
_check_device(xp, device)
236232
if _is_numpy_array(obj):
237233
import numpy as np
238234
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
@@ -263,7 +259,7 @@ def arange(
263259
dtype: Optional[Dtype] = None,
264260
device: Optional[Device] = None,
265261
) -> ndarray:
266-
_check_device(device)
262+
_check_device(xp, device)
267263
return xp.arange(start, stop=stop, step=step, dtype=dtype)
268264

269265
@get_xp
@@ -274,14 +270,14 @@ def empty(
274270
dtype: Optional[Dtype] = None,
275271
device: Optional[Device] = None,
276272
) -> ndarray:
277-
_check_device(device)
273+
_check_device(xp, device)
278274
return xp.empty(shape, dtype=dtype)
279275

280276
@get_xp
281277
def empty_like(
282278
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
283279
) -> ndarray:
284-
_check_device(device)
280+
_check_device(xp, device)
285281
return xp.empty_like(x, dtype=dtype)
286282

287283
@get_xp
@@ -295,7 +291,7 @@ def eye(
295291
dtype: Optional[Dtype] = None,
296292
device: Optional[Device] = None,
297293
) -> ndarray:
298-
_check_device(device)
294+
_check_device(xp, device)
299295
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype)
300296

301297
@get_xp
@@ -307,7 +303,7 @@ def full(
307303
dtype: Optional[Dtype] = None,
308304
device: Optional[Device] = None,
309305
) -> ndarray:
310-
_check_device(device)
306+
_check_device(xp, device)
311307
return xp.full(shape, fill_value, dtype=dtype)
312308

313309
@get_xp
@@ -320,7 +316,7 @@ def full_like(
320316
dtype: Optional[Dtype] = None,
321317
device: Optional[Device] = None,
322318
) -> ndarray:
323-
_check_device(device)
319+
_check_device(xp, device)
324320
return xp.full_like(x, fill_value, dtype=dtype)
325321

326322
@get_xp
@@ -335,7 +331,7 @@ def linspace(
335331
device: Optional[Device] = None,
336332
endpoint: bool = True,
337333
) -> ndarray:
338-
_check_device(device)
334+
_check_device(xp, device)
339335
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
340336

341337
@get_xp
@@ -346,14 +342,14 @@ def ones(
346342
dtype: Optional[Dtype] = None,
347343
device: Optional[Device] = None,
348344
) -> ndarray:
349-
_check_device(device)
345+
_check_device(xp, device)
350346
return xp.ones(shape, dtype=dtype)
351347

352348
@get_xp
353349
def ones_like(
354350
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
355351
) -> ndarray:
356-
_check_device(device)
352+
_check_device(xp, device)
357353
return xp.ones_like(x, dtype=dtype)
358354

359355
@get_xp
@@ -364,14 +360,14 @@ def zeros(
364360
dtype: Optional[Dtype] = None,
365361
device: Optional[Device] = None,
366362
) -> ndarray:
367-
_check_device(device)
363+
_check_device(xp, device)
368364
return xp.zeros(shape, dtype=dtype)
369365

370366
@get_xp
371367
def zeros_like(
372368
x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
373369
) -> ndarray:
374-
_check_device(device)
370+
_check_device(xp, device)
375371
return xp.zeros_like(x, dtype=dtype)
376372

377373
# xp.reshape calls the keyword argument 'newshape' instead of 'shape'

numpy_array_api_compat/common/_helpers.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,40 @@
11
"""
22
Various helper functions which are not part of the spec.
3+
4+
Functions which start with an underscore are for internal use only but helpers
5+
that are in __all__ are intended as additional helper functions for use by end
6+
users of the compat library.
37
"""
8+
from __future__ import annotations
9+
10+
import sys
11+
12+
def _is_numpy_array(x):
13+
# Avoid importing NumPy if it isn't already
14+
if 'numpy' not in sys.modules:
15+
return False
16+
17+
import numpy as np
18+
19+
# TODO: Should we reject ndarray subclasses?
20+
return isinstance(x, (np.ndarray, np.generic))
21+
22+
def _is_cupy_array(x):
23+
# Avoid importing NumPy if it isn't already
24+
if 'cupy' not in sys.modules:
25+
return False
26+
27+
import cupy as cp
28+
29+
# TODO: Should we reject ndarray subclasses?
30+
return isinstance(x, (cp.ndarray, cp.generic))
31+
32+
def is_array_api_obj(x):
33+
"""
34+
Check if x is an array API compatible array object.
35+
"""
36+
return _is_numpy_array(x) or _is_cupy_array(x) or hasattr(x, '__array_namespace__')
37+
438
def get_namespace(*xs, _use_compat=True):
539
"""
640
Get the array API compatible namespace for the arrays `xs`.
@@ -35,3 +69,102 @@ def get_namespace(*xs, _use_compat=True):
3569
xp, = namespaces
3670

3771
return xp
72+
73+
74+
def _check_device(xp, device):
75+
if xp == sys.modules.get('numpy'):
76+
if device not in ["cpu", None]:
77+
raise ValueError(f"Unsupported device for NumPy: {device!r}")
78+
79+
# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
80+
# or cupy.ndarray. They are not included in array objects of this library
81+
# because this library just reuses the respective ndarray classes without
82+
# wrapping or subclassing them. These helper functions can be used instead of
83+
# the wrapper functions for libraries that need to support both NumPy/CuPy and
84+
# other libraries that use devices.
85+
def device(x: "Array", /) -> "Device":
86+
"""
87+
Hardware device the array data resides on.
88+
89+
Parameters
90+
----------
91+
x: array
92+
array instance from NumPy or an array API compatible library.
93+
94+
Returns
95+
-------
96+
out: device
97+
a ``device`` object (see the "Device Support" section of the array API specification).
98+
"""
99+
if _is_numpy_array(x):
100+
return "cpu"
101+
return x.device
102+
103+
# Based on cupy.array_api.Array.to_device
104+
def _cupy_to_device(x, device, /, stream=None):
105+
import cupy as cp
106+
from cupy.cuda import Device as _Device
107+
from cupy.cuda import stream as stream_module
108+
from cupy_backends.cuda.api import runtime
109+
110+
if device == x.device:
111+
return x
112+
elif not isinstance(device, _Device):
113+
raise ValueError(f"Unsupported device {device!r}")
114+
else:
115+
# see cupy/cupy#5985 for the reason how we handle device/stream here
116+
prev_device = runtime.getDevice()
117+
prev_stream: stream_module.Stream = None
118+
if stream is not None:
119+
prev_stream = stream_module.get_current_stream()
120+
# stream can be an int as specified in __dlpack__, or a CuPy stream
121+
if isinstance(stream, int):
122+
stream = cp.cuda.ExternalStream(stream)
123+
elif isinstance(stream, cp.cuda.Stream):
124+
pass
125+
else:
126+
raise ValueError('the input stream is not recognized')
127+
stream.use()
128+
try:
129+
runtime.setDevice(device.id)
130+
arr = x.copy()
131+
finally:
132+
runtime.setDevice(prev_device)
133+
if stream is not None:
134+
prev_stream.use()
135+
return arr
136+
137+
def to_device(x: "Array", device: "Device", /, *, stream: Optional[Union[int, Any]] = None) -> "Array":
138+
"""
139+
Copy the array from the device on which it currently resides to the specified ``device``.
140+
141+
Parameters
142+
----------
143+
x: array
144+
array instance from NumPy or an array API compatible library.
145+
device: device
146+
a ``device`` object (see the "Device Support" section of the array API specification).
147+
stream: Optional[Union[int, Any]]
148+
stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.
149+
150+
Returns
151+
-------
152+
out: array
153+
an array with the same data and data type as ``x`` and located on the specified ``device``.
154+
155+
.. note::
156+
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
157+
"""
158+
if _is_numpy_array(x):
159+
if stream is not None:
160+
raise ValueError("The stream argument to to_device() is not supported")
161+
if device == 'cpu':
162+
return x
163+
raise ValueError(f"Unsupported device {device!r}")
164+
elif _is_cupy_array(x):
165+
# cupy does not yet have to_device
166+
return _cupy_to_device(x, device, stream=stream)
167+
168+
return x.to_device(device, stream=stream)
169+
170+
__all__ = ['is_array_api_obj', 'get_namespace', 'device', 'to_device']

numpy_array_api_compat/cupy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
#
1212
# from . import linalg
1313
#
14-
# It doesn't overwrite np.linalg from above. The import is generated
14+
# It doesn't overwrite cupy.linalg from above. The import is generated
1515
# dynamically so that the library can be vendored.
1616
__import__(__package__ + '.linalg')
1717

1818
from .linalg import matrix_transpose, vecdot
1919

20-
from ._helpers import *
20+
from ..common._helpers import *
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from functools import partial
2+
3+
from ..common._aliases import *
4+
from ..common._aliases import _asarray
5+
from ..common._aliases import __all__
6+
7+
asarray = asarray_cupy = partial(_asarray, namespace='cupy')
8+
asarray.__doc__ = _asarray.__doc__
9+
del partial
10+
11+
__all__ = __all__ + ['asarray', 'asarray_cupy']

numpy_array_api_compat/cupy/linalg.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from cupy.linalg import *
2+
# cupy.linalg doesn't have __all__. If it is added, replace this with
3+
#
4+
# from cupy.linalg import __all__ as linalg_all
5+
_n = {}
6+
exec('from cupy.linalg import *', _n)
7+
del _n['__builtins__']
8+
linalg_all = list(_n)
9+
del _n
10+
11+
from ..common.linalg import *
12+
from ..common.linalg import __all__ as common_linalg_all
13+
14+
__all__ = linalg_all + common_linalg_all

numpy_array_api_compat/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717

1818
from .linalg import matrix_transpose, vecdot
1919

20-
from ._helpers import *
20+
from ..common._helpers import *

numpy_array_api_compat/numpy/_helpers.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

0 commit comments

Comments
 (0)