Skip to content

Commit 600df5e

Browse files
committed
Add "multi device" support
Having more than one device is useful during testing to allow you to find bugs related to how arrays on different devices are handled.
1 parent 718f15b commit 600df5e

File tree

5 files changed

+44
-29
lines changed

5 files changed

+44
-29
lines changed

array_api_strict/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@
309309

310310
__all__ += ["all", "any"]
311311

312+
from ._array_object import Device
313+
__all__ += ["Device"]
314+
312315
# Helper functions that are not part of the standard
313316

314317
from ._flags import (

array_api_strict/_array_object.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,17 @@
4343

4444
import numpy as np
4545

46-
# Placeholder object to represent the "cpu" device (the only device NumPy
47-
# supports).
48-
class _cpu_device:
46+
class Device:
47+
def __init__(self, device="CPU_DEVICE"):
48+
self._device = device
49+
4950
def __repr__(self):
50-
return "CPU_DEVICE"
51+
return f"Device('{self._device}')"
52+
53+
def __eq__(self, other):
54+
return self._device == other._device
5155

52-
CPU_DEVICE = _cpu_device()
56+
CPU_DEVICE = Device()
5357

5458
_default = object()
5559

@@ -73,7 +77,7 @@ class Array:
7377
# Use a custom constructor instead of __init__, as manually initializing
7478
# this class is not supported API.
7579
@classmethod
76-
def _new(cls, x, /):
80+
def _new(cls, x, /, device=CPU_DEVICE):
7781
"""
7882
This is a private method for initializing the array API Array
7983
object.
@@ -95,6 +99,9 @@ def _new(cls, x, /):
9599
)
96100
obj._array = x
97101
obj._dtype = _dtype
102+
if device is None:
103+
device = CPU_DEVICE
104+
obj._device = device
98105
return obj
99106

100107
# Prevent Array() from working
@@ -134,6 +141,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
134141
will be present in other implementations.
135142
136143
"""
144+
if self._device != CPU_DEVICE:
145+
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
137146
# copy keyword is new in 2.0.0; for older versions don't use it
138147
# retry without that keyword.
139148
if np.__version__[0] < '2':
@@ -1154,8 +1163,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
11541163
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
11551164
if stream is not None:
11561165
raise ValueError("The stream argument to to_device() is not supported")
1157-
if device == CPU_DEVICE:
1166+
if device == self._device:
11581167
return self
1168+
elif isinstance(device, Device):
1169+
arr = np.asarray(self._array, copy=True)
1170+
return self.__class__._new(arr, device=device)
11591171
raise ValueError(f"Unsupported device {device!r}")
11601172

11611173
@property
@@ -1169,7 +1181,7 @@ def dtype(self) -> Dtype:
11691181

11701182
@property
11711183
def device(self) -> Device:
1172-
return CPU_DEVICE
1184+
return self._device
11731185

11741186
# Note: mT is new in array API spec (see matrix_transpose)
11751187
@property

array_api_strict/_creation_functions.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def _supports_buffer_protocol(obj):
3232
def _check_device(device):
3333
# _array_object imports in this file are inside the functions to avoid
3434
# circular imports
35-
from ._array_object import CPU_DEVICE
35+
from ._array_object import Device
3636

37-
if device not in [CPU_DEVICE, None]:
37+
if device is not None and not isinstance(device, Device):
3838
raise ValueError(f"Unsupported device {device!r}")
3939

4040
def asarray(
@@ -79,7 +79,7 @@ def asarray(
7979
return Array._new(new_array)
8080
elif _supports_buffer_protocol(obj):
8181
# Buffer protocol will always support no-copy
82-
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype))
82+
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device)
8383
else:
8484
# No-copy is unsupported for Python built-in types.
8585
raise ValueError("Unable to avoid copy while creating an array from given object.")
@@ -89,13 +89,13 @@ def asarray(
8989
copy = False
9090

9191
if isinstance(obj, Array):
92-
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype))
92+
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
9393
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
9494
# Give a better error message in this case. NumPy would convert this
9595
# to an object array. TODO: This won't handle large integers in lists.
9696
raise OverflowError("Integer out of bounds for array dtypes")
9797
res = np.array(obj, dtype=_np_dtype, copy=copy)
98-
return Array._new(res)
98+
return Array._new(res, device=device)
9999

100100

101101
def arange(
@@ -119,7 +119,7 @@ def arange(
119119

120120
if dtype is not None:
121121
dtype = dtype._np_dtype
122-
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
122+
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype), device=device)
123123

124124

125125
def empty(
@@ -140,7 +140,7 @@ def empty(
140140

141141
if dtype is not None:
142142
dtype = dtype._np_dtype
143-
return Array._new(np.empty(shape, dtype=dtype))
143+
return Array._new(np.empty(shape, dtype=dtype), device=device)
144144

145145

146146
def empty_like(
@@ -158,7 +158,7 @@ def empty_like(
158158

159159
if dtype is not None:
160160
dtype = dtype._np_dtype
161-
return Array._new(np.empty_like(x._array, dtype=dtype))
161+
return Array._new(np.empty_like(x._array, dtype=dtype), device=device)
162162

163163

164164
def eye(
@@ -182,7 +182,7 @@ def eye(
182182

183183
if dtype is not None:
184184
dtype = dtype._np_dtype
185-
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
185+
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype), device=device)
186186

187187

188188
_default = object()
@@ -237,7 +237,7 @@ def full(
237237
# This will happen if the fill value is not something that NumPy
238238
# coerces to one of the acceptable dtypes.
239239
raise TypeError("Invalid input to full")
240-
return Array._new(res)
240+
return Array._new(res, device=device)
241241

242242

243243
def full_like(
@@ -265,7 +265,7 @@ def full_like(
265265
# This will happen if the fill value is not something that NumPy
266266
# coerces to one of the acceptable dtypes.
267267
raise TypeError("Invalid input to full_like")
268-
return Array._new(res)
268+
return Array._new(res, device=device)
269269

270270

271271
def linspace(
@@ -290,7 +290,7 @@ def linspace(
290290

291291
if dtype is not None:
292292
dtype = dtype._np_dtype
293-
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
293+
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint), device=device)
294294

295295

296296
def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
@@ -308,7 +308,7 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
308308
raise ValueError("meshgrid inputs must all have the same dtype")
309309

310310
return [
311-
Array._new(array)
311+
Array._new(array, device=device)
312312
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
313313
]
314314

@@ -331,7 +331,7 @@ def ones(
331331

332332
if dtype is not None:
333333
dtype = dtype._np_dtype
334-
return Array._new(np.ones(shape, dtype=dtype))
334+
return Array._new(np.ones(shape, dtype=dtype), device=device)
335335

336336

337337
def ones_like(
@@ -349,7 +349,7 @@ def ones_like(
349349

350350
if dtype is not None:
351351
dtype = dtype._np_dtype
352-
return Array._new(np.ones_like(x._array, dtype=dtype))
352+
return Array._new(np.ones_like(x._array, dtype=dtype), device=device)
353353

354354

355355
def tril(x: Array, /, *, k: int = 0) -> Array:
@@ -377,7 +377,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
377377
if x.ndim < 2:
378378
# Note: Unlike np.triu, x must be at least 2-D
379379
raise ValueError("x must be at least 2-dimensional for triu")
380-
return Array._new(np.triu(x._array, k=k))
380+
return Array._new(np.triu(x._array, k=k), device=device)
381381

382382

383383
def zeros(
@@ -398,7 +398,7 @@ def zeros(
398398

399399
if dtype is not None:
400400
dtype = dtype._np_dtype
401-
return Array._new(np.zeros(shape, dtype=dtype))
401+
return Array._new(np.zeros(shape, dtype=dtype), device=device)
402402

403403

404404
def zeros_like(
@@ -416,4 +416,4 @@ def zeros_like(
416416

417417
if dtype is not None:
418418
dtype = dtype._np_dtype
419-
return Array._new(np.zeros_like(x._array, dtype=dtype))
419+
return Array._new(np.zeros_like(x._array, dtype=dtype), device=device)

array_api_strict/_typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
Protocol,
2828
)
2929

30-
from ._array_object import Array, _cpu_device
30+
from ._array_object import Array, _device
3131
from ._dtypes import _DType
3232

3333
_T_co = TypeVar("_T_co", covariant=True)
@@ -37,7 +37,7 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
3737
def __len__(self, /) -> int: ...
3838

3939

40-
Device = _cpu_device
40+
Device = _device
4141

4242
Dtype = _DType
4343

array_api_strict/tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def test_python_scalar_construtors():
319319
def test_device_property():
320320
a = ones((3, 4))
321321
assert a.device == CPU_DEVICE
322-
assert a.device != 'cpu'
322+
assert not isinstance(a.device, str)
323323

324324
assert all(equal(a.to_device(CPU_DEVICE), a))
325325
assert_raises(ValueError, lambda: a.to_device('cpu'))

0 commit comments

Comments
 (0)