@@ -329,6 +329,10 @@ def _usm_types_walker(o, usm_types_list):
329
329
usm_ar = _usm_ndarray_from_suai (o )
330
330
usm_types_list .append (usm_ar .usm_type )
331
331
return
332
+ if _is_object_with_buffer_protocol (o ):
333
+ return
334
+ if isinstance (o , (int , bool , float , complex )):
335
+ return
332
336
if isinstance (o , (list , tuple )):
333
337
for el in o :
334
338
_usm_types_walker (el , usm_types_list )
@@ -361,11 +365,37 @@ def _device_copy_walker(seq_o, res, events):
361
365
362
366
def _copy_through_host_walker (seq_o , usm_res ):
363
367
if isinstance (seq_o , dpt .usm_ndarray ):
364
- usm_res [...] = dpt .asnumpy (seq_o ).copy ()
365
- return
368
+ if (
369
+ dpctl .utils .get_execution_queue (
370
+ (
371
+ usm_res .sycl_queue ,
372
+ seq_o .sycl_queue ,
373
+ )
374
+ )
375
+ is None
376
+ ):
377
+ usm_res [...] = dpt .asnumpy (seq_o ).copy ()
378
+ return
379
+ else :
380
+ usm_res [...] = seq_o
366
381
if hasattr (seq_o , "__sycl_usm_array_interface__" ):
367
382
usm_ar = _usm_ndarray_from_suai (seq_o )
368
- usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
383
+ if (
384
+ dpctl .utils .get_execution_queue (
385
+ (
386
+ usm_res .sycl_queue ,
387
+ usm_ar .sycl_queue ,
388
+ )
389
+ )
390
+ is None
391
+ ):
392
+ usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
393
+ else :
394
+ usm_res [...] = usm_ar
395
+ return
396
+ if _is_object_with_buffer_protocol (seq_o ):
397
+ np_ar = np .asarray (seq_o )
398
+ usm_res [...] = np_ar
369
399
return
370
400
if isinstance (seq_o , (list , tuple )):
371
401
for i , el in enumerate (seq_o ):
@@ -378,10 +408,10 @@ def _asarray_from_seq(
378
408
seq_obj ,
379
409
seq_shape ,
380
410
seq_dt ,
381
- seq_dev ,
411
+ alloc_q ,
412
+ exec_q ,
382
413
dtype = None ,
383
414
usm_type = None ,
384
- sycl_queue = None ,
385
415
order = "C" ,
386
416
):
387
417
"`obj` is a sequence"
@@ -390,24 +420,13 @@ def _asarray_from_seq(
390
420
_usm_types_walker (seq_obj , usm_types_in_seq )
391
421
usm_type = dpctl .utils .get_coerced_usm_type (usm_types_in_seq )
392
422
dpctl .utils .validate_usm_type (usm_type )
393
- if sycl_queue is None :
394
- exec_q = seq_dev
395
- alloc_q = seq_dev
396
- else :
397
- exec_q = dpctl .utils .get_execution_queue (
398
- (
399
- sycl_queue ,
400
- seq_dev ,
401
- )
402
- )
403
- alloc_q = sycl_queue
404
423
if dtype is None :
405
424
dtype = _map_to_device_dtype (seq_dt , alloc_q )
406
425
else :
407
426
_mapped_dt = _map_to_device_dtype (dtype , alloc_q )
408
427
if _mapped_dt != dtype :
409
428
raise ValueError (
410
- f"Device { sycl_queue .sycl_device } "
429
+ f"Device { alloc_q .sycl_device } "
411
430
f"does not support { dtype } natively."
412
431
)
413
432
dtype = _mapped_dt
@@ -437,6 +456,39 @@ def _asarray_from_seq(
437
456
return res
438
457
439
458
459
+ def _asarray_from_seq_single_device (
460
+ obj ,
461
+ seq_shape ,
462
+ seq_dt ,
463
+ seq_dev ,
464
+ dtype = None ,
465
+ usm_type = None ,
466
+ sycl_queue = None ,
467
+ order = "C" ,
468
+ ):
469
+ if sycl_queue is None :
470
+ exec_q = seq_dev
471
+ alloc_q = seq_dev
472
+ else :
473
+ exec_q = dpctl .utils .get_execution_queue (
474
+ (
475
+ sycl_queue ,
476
+ seq_dev ,
477
+ )
478
+ )
479
+ alloc_q = sycl_queue
480
+ return _asarray_from_seq (
481
+ obj ,
482
+ seq_shape ,
483
+ seq_dt ,
484
+ alloc_q ,
485
+ exec_q ,
486
+ dtype = dtype ,
487
+ usm_type = usm_type ,
488
+ order = order ,
489
+ )
490
+
491
+
440
492
def asarray (
441
493
obj ,
442
494
dtype = None ,
@@ -576,16 +628,51 @@ def asarray(
576
628
order = order ,
577
629
)
578
630
elif len (devs ) == 1 :
579
- return _asarray_from_seq (
631
+ seq_dev = list (devs )[0 ]
632
+ return _asarray_from_seq_single_device (
580
633
obj ,
581
634
seq_shape ,
582
635
seq_dt ,
583
- list ( devs )[ 0 ] ,
636
+ seq_dev ,
584
637
dtype = dtype ,
585
638
usm_type = usm_type ,
586
639
sycl_queue = sycl_queue ,
587
640
order = order ,
588
641
)
642
+ elif len (devs ) > 1 :
643
+ devs = [dev for dev in devs if dev is not None ]
644
+ if len (devs ) == 1 :
645
+ seq_dev = devs [0 ]
646
+ return _asarray_from_seq_single_device (
647
+ obj ,
648
+ seq_shape ,
649
+ seq_dt ,
650
+ seq_dev ,
651
+ dtype = dtype ,
652
+ usm_type = usm_type ,
653
+ sycl_queue = sycl_queue ,
654
+ order = order ,
655
+ )
656
+ else :
657
+ if sycl_queue is None :
658
+ raise dpctl .utils .ExecutionPlacementError (
659
+ "Please specify `device` or `sycl_queue` keyword "
660
+ "argument to determine where to allocated the "
661
+ "resulting array."
662
+ )
663
+ alloc_q = sycl_queue
664
+ exec_q = None # force copying via host
665
+ return _asarray_from_seq (
666
+ obj ,
667
+ seq_shape ,
668
+ seq_dt ,
669
+ alloc_q ,
670
+ exec_q ,
671
+ dtype = dtype ,
672
+ usm_type = usm_type ,
673
+ order = order ,
674
+ )
675
+
589
676
raise NotImplementedError (
590
677
"Converting Python sequences is not implemented"
591
678
)
0 commit comments