24
24
import dpctl .tensor ._tensor_impl as ti
25
25
import dpctl .utils
26
26
from dpctl .tensor ._device import normalize_queue_device
27
+ from dpctl .tensor ._usmarray import _is_object_with_buffer_protocol
27
28
28
29
__doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"
29
30
@@ -66,11 +67,12 @@ def _array_info_dispatch(obj):
66
67
return _empty_tuple , complex , _host_set
67
68
if isinstance (obj , (list , tuple , range )):
68
69
return _array_info_sequence (obj )
69
- if any (
70
- isinstance (obj , s )
71
- for s in [np .integer , np .floating , np .complexfloating , np .bool_ ]
72
- ):
73
- return _empty_tuple , obj .dtype , _host_set
70
+ if _is_object_with_buffer_protocol (obj ):
71
+ np_obj = np .array (obj )
72
+ return np_obj .shape , np_obj .dtype , _host_set
73
+ if hasattr (obj , "__sycl_usm_array_interface__" ):
74
+ usm_ar = _usm_ndarray_from_suai (obj )
75
+ return usm_ar .shape , usm_ar .dtype , frozenset ([usm_ar .sycl_queue ])
74
76
raise ValueError (type (obj ))
75
77
76
78
@@ -219,6 +221,18 @@ def _map_to_device_dtype(dt, q):
219
221
raise RuntimeError (f"Unrecognized data type '{ dt } ' encountered." )
220
222
221
223
224
+ def _usm_ndarray_from_suai (obj ):
225
+ sua_iface = getattr (obj , "__sycl_usm_array_interface__" )
226
+ membuf = dpm .as_usm_memory (obj )
227
+ ary = dpt .usm_ndarray (
228
+ sua_iface ["shape" ],
229
+ dtype = sua_iface ["typestr" ],
230
+ buffer = membuf ,
231
+ strides = sua_iface .get ("strides" , None ),
232
+ )
233
+ return ary
234
+
235
+
222
236
def _asarray_from_numpy_ndarray (
223
237
ary , dtype = None , usm_type = None , sycl_queue = None , order = "K"
224
238
):
@@ -276,17 +290,6 @@ def _asarray_from_numpy_ndarray(
276
290
return res
277
291
278
292
279
- def _is_object_with_buffer_protocol (obj ):
280
- "Returns `True` if object support Python buffer protocol"
281
- try :
282
- # use context manager to ensure
283
- # buffer is instantly released
284
- with memoryview (obj ):
285
- return True
286
- except TypeError :
287
- return False
288
-
289
-
290
293
def _ensure_native_dtype_device_support (dtype , dev ) -> None :
291
294
"""Check that dtype is natively supported by device.
292
295
@@ -318,6 +321,122 @@ def _ensure_native_dtype_device_support(dtype, dev) -> None:
318
321
)
319
322
320
323
324
+ def _usm_types_walker (o , usm_types_list ):
325
+ if isinstance (o , dpt .usm_ndarray ):
326
+ usm_types_list .append (o .usm_type )
327
+ return
328
+ if hasattr (o , "__sycl_usm_array_interface__" ):
329
+ usm_ar = _usm_ndarray_from_suai (o )
330
+ usm_types_list .append (usm_ar .usm_type )
331
+ return
332
+ if isinstance (o , (list , tuple )):
333
+ for el in o :
334
+ _usm_types_walker (el , usm_types_list )
335
+ return
336
+ raise TypeError
337
+
338
+
339
+ def _device_copy_walker (seq_o , res , events ):
340
+ if isinstance (seq_o , dpt .usm_ndarray ):
341
+ exec_q = res .sycl_queue
342
+ ht_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
343
+ src = seq_o , dst = res , sycl_queue = exec_q
344
+ )
345
+ events .append (ht_ev )
346
+ return
347
+ if hasattr (seq_o , "__sycl_usm_array_interface__" ):
348
+ usm_ar = _usm_ndarray_from_suai (seq_o )
349
+ exec_q = res .sycl_queue
350
+ ht_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
351
+ src = usm_ar , dst = res , sycl_queue = exec_q
352
+ )
353
+ events .append (ht_ev )
354
+ return
355
+ if isinstance (seq_o , (list , tuple )):
356
+ for i , el in enumerate (seq_o ):
357
+ _device_copy_walker (el , res [i ], events )
358
+ return
359
+ raise TypeError
360
+
361
+
362
+ def _copy_through_host_walker (seq_o , usm_res ):
363
+ if isinstance (seq_o , dpt .usm_ndarray ):
364
+ usm_res [...] = dpt .asnumpy (seq_o ).copy ()
365
+ return
366
+ if hasattr (seq_o , "__sycl_usm_array_interface__" ):
367
+ usm_ar = _usm_ndarray_from_suai (seq_o )
368
+ usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
369
+ return
370
+ if isinstance (seq_o , (list , tuple )):
371
+ for i , el in enumerate (seq_o ):
372
+ _copy_through_host_walker (el , usm_res [i ])
373
+ return
374
+ usm_res [...] = np .asarray (seq_o )
375
+
376
+
377
+ def _asarray_from_seq (
378
+ seq_obj ,
379
+ seq_shape ,
380
+ seq_dt ,
381
+ seq_dev ,
382
+ dtype = None ,
383
+ usm_type = None ,
384
+ sycl_queue = None ,
385
+ order = "C" ,
386
+ ):
387
+ "`obj` is a sequence"
388
+ if usm_type is None :
389
+ usm_types_in_seq = []
390
+ _usm_types_walker (seq_obj , usm_types_in_seq )
391
+ usm_type = dpctl .utils .get_coerced_usm_type (usm_types_in_seq )
392
+ dpctl .utils .validate_usm_type (usm_type )
393
+ if sycl_queue is None :
394
+ exec_q = seq_dev
395
+ alloc_q = seq_dev
396
+ else :
397
+ exec_q = dpctl .utils .get_execution_queue (
398
+ (
399
+ sycl_queue ,
400
+ seq_dev ,
401
+ )
402
+ )
403
+ alloc_q = sycl_queue
404
+ if dtype is None :
405
+ dtype = _map_to_device_dtype (seq_dt , alloc_q )
406
+ else :
407
+ _mapped_dt = _map_to_device_dtype (dtype , alloc_q )
408
+ if _mapped_dt != dtype :
409
+ raise ValueError (
410
+ f"Device { sycl_queue .sycl_device } "
411
+ f"does not support { dtype } natively."
412
+ )
413
+ dtype = _mapped_dt
414
+ if order in "KA" :
415
+ order = "C"
416
+ if isinstance (exec_q , dpctl .SyclQueue ):
417
+ res = dpt .empty (
418
+ seq_shape ,
419
+ dtype = dtype ,
420
+ usm_type = usm_type ,
421
+ sycl_queue = alloc_q ,
422
+ order = order ,
423
+ )
424
+ ht_events = []
425
+ _device_copy_walker (seq_obj , res , ht_events )
426
+ dpctl .SyclEvent .wait_for (ht_events )
427
+ return res
428
+ else :
429
+ res = dpt .empty (
430
+ seq_shape ,
431
+ dtype = dtype ,
432
+ usm_type = usm_type ,
433
+ sycl_queue = alloc_q ,
434
+ order = order ,
435
+ )
436
+ _copy_through_host_walker (seq_obj , res )
437
+ return res
438
+
439
+
321
440
def asarray (
322
441
obj ,
323
442
dtype = None ,
@@ -327,7 +446,9 @@ def asarray(
327
446
sycl_queue = None ,
328
447
order = "K" ,
329
448
):
330
- """
449
+ """ asarray(obj, dtype=None, copy=None, device=None, \
450
+ usm_type=None, sycl_queue=None, order="K")
451
+
331
452
Converts `obj` to :class:`dpctl.tensor.usm_ndarray`.
332
453
333
454
Args:
@@ -347,7 +468,7 @@ def asarray(
347
468
allocations if possible, but allowed to perform a copy otherwise.
348
469
Default: `None`.
349
470
order ("C","F","A","K", optional): memory layout of the output array.
350
- Default: "C "
471
+ Default: "K "
351
472
device (optional): array API concept of device where the output array
352
473
is created. `device` can be `None`, a oneAPI filter selector string,
353
474
an instance of :class:`dpctl.SyclDevice` corresponding to a
@@ -407,14 +528,7 @@ def asarray(
407
528
order = order ,
408
529
)
409
530
if hasattr (obj , "__sycl_usm_array_interface__" ):
410
- sua_iface = getattr (obj , "__sycl_usm_array_interface__" )
411
- membuf = dpm .as_usm_memory (obj )
412
- ary = dpt .usm_ndarray (
413
- sua_iface ["shape" ],
414
- dtype = sua_iface ["typestr" ],
415
- buffer = membuf ,
416
- strides = sua_iface .get ("strides" , None ),
417
- )
531
+ ary = _usm_ndarray_from_suai (obj )
418
532
return _asarray_from_usm_ndarray (
419
533
ary ,
420
534
dtype = dtype ,
@@ -452,7 +566,7 @@ def asarray(
452
566
raise ValueError (
453
567
"Converting Python sequence to usm_ndarray requires a copy"
454
568
)
455
- _ , _ , devs = _array_info_sequence (obj )
569
+ seq_shape , seq_dt , devs = _array_info_sequence (obj )
456
570
if devs == _host_set :
457
571
return _asarray_from_numpy_ndarray (
458
572
np .asarray (obj , dtype = dtype , order = order ),
@@ -461,7 +575,17 @@ def asarray(
461
575
sycl_queue = sycl_queue ,
462
576
order = order ,
463
577
)
464
- # for sequences
578
+ elif len (devs ) == 1 :
579
+ return _asarray_from_seq (
580
+ obj ,
581
+ seq_shape ,
582
+ seq_dt ,
583
+ list (devs )[0 ],
584
+ dtype = dtype ,
585
+ usm_type = usm_type ,
586
+ sycl_queue = sycl_queue ,
587
+ order = order ,
588
+ )
465
589
raise NotImplementedError (
466
590
"Converting Python sequences is not implemented"
467
591
)
0 commit comments