@@ -32,9 +32,9 @@ def _supports_buffer_protocol(obj):
32
32
def _check_device (device ):
33
33
# _array_object imports in this file are inside the functions to avoid
34
34
# circular imports
35
- from ._array_object import CPU_DEVICE
35
+ from ._array_object import Device
36
36
37
- if device not in [ CPU_DEVICE , None ] :
37
+ if device is not None and not isinstance ( device , Device ) :
38
38
raise ValueError (f"Unsupported device { device !r} " )
39
39
40
40
def asarray (
@@ -79,7 +79,7 @@ def asarray(
79
79
return Array ._new (new_array )
80
80
elif _supports_buffer_protocol (obj ):
81
81
# Buffer protocol will always support no-copy
82
- return Array ._new (np .array (obj , copy = copy , dtype = _np_dtype ))
82
+ return Array ._new (np .array (obj , copy = copy , dtype = _np_dtype ), device = device )
83
83
else :
84
84
# No-copy is unsupported for Python built-in types.
85
85
raise ValueError ("Unable to avoid copy while creating an array from given object." )
@@ -89,13 +89,13 @@ def asarray(
89
89
copy = False
90
90
91
91
if isinstance (obj , Array ):
92
- return Array ._new (np .array (obj ._array , copy = copy , dtype = _np_dtype ))
92
+ return Array ._new (np .array (obj ._array , copy = copy , dtype = _np_dtype ), device = device )
93
93
if dtype is None and isinstance (obj , int ) and (obj > 2 ** 64 or obj < - (2 ** 63 )):
94
94
# Give a better error message in this case. NumPy would convert this
95
95
# to an object array. TODO: This won't handle large integers in lists.
96
96
raise OverflowError ("Integer out of bounds for array dtypes" )
97
97
res = np .array (obj , dtype = _np_dtype , copy = copy )
98
- return Array ._new (res )
98
+ return Array ._new (res , device = device )
99
99
100
100
101
101
def arange (
@@ -119,7 +119,7 @@ def arange(
119
119
120
120
if dtype is not None :
121
121
dtype = dtype ._np_dtype
122
- return Array ._new (np .arange (start , stop = stop , step = step , dtype = dtype ))
122
+ return Array ._new (np .arange (start , stop = stop , step = step , dtype = dtype ), device = device )
123
123
124
124
125
125
def empty (
@@ -140,7 +140,7 @@ def empty(
140
140
141
141
if dtype is not None :
142
142
dtype = dtype ._np_dtype
143
- return Array ._new (np .empty (shape , dtype = dtype ))
143
+ return Array ._new (np .empty (shape , dtype = dtype ), device = device )
144
144
145
145
146
146
def empty_like (
@@ -158,7 +158,7 @@ def empty_like(
158
158
159
159
if dtype is not None :
160
160
dtype = dtype ._np_dtype
161
- return Array ._new (np .empty_like (x ._array , dtype = dtype ))
161
+ return Array ._new (np .empty_like (x ._array , dtype = dtype ), device = device )
162
162
163
163
164
164
def eye (
@@ -182,7 +182,7 @@ def eye(
182
182
183
183
if dtype is not None :
184
184
dtype = dtype ._np_dtype
185
- return Array ._new (np .eye (n_rows , M = n_cols , k = k , dtype = dtype ))
185
+ return Array ._new (np .eye (n_rows , M = n_cols , k = k , dtype = dtype ), device = device )
186
186
187
187
188
188
_default = object ()
@@ -237,7 +237,7 @@ def full(
237
237
# This will happen if the fill value is not something that NumPy
238
238
# coerces to one of the acceptable dtypes.
239
239
raise TypeError ("Invalid input to full" )
240
- return Array ._new (res )
240
+ return Array ._new (res , device = device )
241
241
242
242
243
243
def full_like (
@@ -265,7 +265,7 @@ def full_like(
265
265
# This will happen if the fill value is not something that NumPy
266
266
# coerces to one of the acceptable dtypes.
267
267
raise TypeError ("Invalid input to full_like" )
268
- return Array ._new (res )
268
+ return Array ._new (res , device = device )
269
269
270
270
271
271
def linspace (
@@ -290,7 +290,7 @@ def linspace(
290
290
291
291
if dtype is not None :
292
292
dtype = dtype ._np_dtype
293
- return Array ._new (np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint ))
293
+ return Array ._new (np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint ), device = device )
294
294
295
295
296
296
def meshgrid (* arrays : Array , indexing : str = "xy" ) -> List [Array ]:
@@ -308,7 +308,7 @@ def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
308
308
raise ValueError ("meshgrid inputs must all have the same dtype" )
309
309
310
310
return [
311
- Array ._new (array )
311
+ Array ._new (array , device = device )
312
312
for array in np .meshgrid (* [a ._array for a in arrays ], indexing = indexing )
313
313
]
314
314
@@ -331,7 +331,7 @@ def ones(
331
331
332
332
if dtype is not None :
333
333
dtype = dtype ._np_dtype
334
- return Array ._new (np .ones (shape , dtype = dtype ))
334
+ return Array ._new (np .ones (shape , dtype = dtype ), device = device )
335
335
336
336
337
337
def ones_like (
@@ -349,7 +349,7 @@ def ones_like(
349
349
350
350
if dtype is not None :
351
351
dtype = dtype ._np_dtype
352
- return Array ._new (np .ones_like (x ._array , dtype = dtype ))
352
+ return Array ._new (np .ones_like (x ._array , dtype = dtype ), device = device )
353
353
354
354
355
355
def tril (x : Array , / , * , k : int = 0 ) -> Array :
@@ -377,7 +377,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
377
377
if x .ndim < 2 :
378
378
# Note: Unlike np.triu, x must be at least 2-D
379
379
raise ValueError ("x must be at least 2-dimensional for triu" )
380
- return Array ._new (np .triu (x ._array , k = k ))
380
+ return Array ._new (np .triu (x ._array , k = k ), device = device )
381
381
382
382
383
383
def zeros (
@@ -398,7 +398,7 @@ def zeros(
398
398
399
399
if dtype is not None :
400
400
dtype = dtype ._np_dtype
401
- return Array ._new (np .zeros (shape , dtype = dtype ))
401
+ return Array ._new (np .zeros (shape , dtype = dtype ), device = device )
402
402
403
403
404
404
def zeros_like (
@@ -416,4 +416,4 @@ def zeros_like(
416
416
417
417
if dtype is not None :
418
418
dtype = dtype ._np_dtype
419
- return Array ._new (np .zeros_like (x ._array , dtype = dtype ))
419
+ return Array ._new (np .zeros_like (x ._array , dtype = dtype ), device = device )
0 commit comments