@@ -65,7 +65,13 @@ def _array_info_dispatch(obj):
65
65
return _empty_tuple , int , _host_set
66
66
if isinstance (obj , complex ):
67
67
return _empty_tuple , complex , _host_set
68
- if isinstance (obj , (list , tuple , range )):
68
+ if isinstance (
69
+ obj ,
70
+ (
71
+ list ,
72
+ tuple ,
73
+ ),
74
+ ):
69
75
return _array_info_sequence (obj )
70
76
if _is_object_with_buffer_protocol (obj ):
71
77
np_obj = np .array (obj )
@@ -329,7 +335,11 @@ def _usm_types_walker(o, usm_types_list):
329
335
usm_ar = _usm_ndarray_from_suai (o )
330
336
usm_types_list .append (usm_ar .usm_type )
331
337
return
332
- if isinstance (o , (list , tuple )):
338
+ if _is_object_with_buffer_protocol (o ):
339
+ return
340
+ if isinstance (o , (int , bool , float , complex )):
341
+ return
342
+ if isinstance (o , (list , tuple , range )):
333
343
for el in o :
334
344
_usm_types_walker (el , usm_types_list )
335
345
return
@@ -361,11 +371,37 @@ def _device_copy_walker(seq_o, res, events):
361
371
362
372
def _copy_through_host_walker (seq_o , usm_res ):
363
373
if isinstance (seq_o , dpt .usm_ndarray ):
364
- usm_res [...] = dpt .asnumpy (seq_o ).copy ()
365
- return
374
+ if (
375
+ dpctl .utils .get_execution_queue (
376
+ (
377
+ usm_res .sycl_queue ,
378
+ seq_o .sycl_queue ,
379
+ )
380
+ )
381
+ is None
382
+ ):
383
+ usm_res [...] = dpt .asnumpy (seq_o ).copy ()
384
+ return
385
+ else :
386
+ usm_res [...] = seq_o
366
387
if hasattr (seq_o , "__sycl_usm_array_interface__" ):
367
388
usm_ar = _usm_ndarray_from_suai (seq_o )
368
- usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
389
+ if (
390
+ dpctl .utils .get_execution_queue (
391
+ (
392
+ usm_res .sycl_queue ,
393
+ usm_ar .sycl_queue ,
394
+ )
395
+ )
396
+ is None
397
+ ):
398
+ usm_res [...] = dpt .asnumpy (usm_ar ).copy ()
399
+ else :
400
+ usm_res [...] = usm_ar
401
+ return
402
+ if _is_object_with_buffer_protocol (seq_o ):
403
+ np_ar = np .asarray (seq_o )
404
+ usm_res [...] = np_ar
369
405
return
370
406
if isinstance (seq_o , (list , tuple )):
371
407
for i , el in enumerate (seq_o ):
@@ -378,10 +414,10 @@ def _asarray_from_seq(
378
414
seq_obj ,
379
415
seq_shape ,
380
416
seq_dt ,
381
- seq_dev ,
417
+ alloc_q ,
418
+ exec_q ,
382
419
dtype = None ,
383
420
usm_type = None ,
384
- sycl_queue = None ,
385
421
order = "C" ,
386
422
):
387
423
"`obj` is a sequence"
@@ -390,24 +426,13 @@ def _asarray_from_seq(
390
426
_usm_types_walker (seq_obj , usm_types_in_seq )
391
427
usm_type = dpctl .utils .get_coerced_usm_type (usm_types_in_seq )
392
428
dpctl .utils .validate_usm_type (usm_type )
393
- if sycl_queue is None :
394
- exec_q = seq_dev
395
- alloc_q = seq_dev
396
- else :
397
- exec_q = dpctl .utils .get_execution_queue (
398
- (
399
- sycl_queue ,
400
- seq_dev ,
401
- )
402
- )
403
- alloc_q = sycl_queue
404
429
if dtype is None :
405
430
dtype = _map_to_device_dtype (seq_dt , alloc_q )
406
431
else :
407
432
_mapped_dt = _map_to_device_dtype (dtype , alloc_q )
408
433
if _mapped_dt != dtype :
409
434
raise ValueError (
410
- f"Device { sycl_queue .sycl_device } "
435
+ f"Device { alloc_q .sycl_device } "
411
436
f"does not support { dtype } natively."
412
437
)
413
438
dtype = _mapped_dt
@@ -437,6 +462,39 @@ def _asarray_from_seq(
437
462
return res
438
463
439
464
465
+ def _asarray_from_seq_single_device (
466
+ obj ,
467
+ seq_shape ,
468
+ seq_dt ,
469
+ seq_dev ,
470
+ dtype = None ,
471
+ usm_type = None ,
472
+ sycl_queue = None ,
473
+ order = "C" ,
474
+ ):
475
+ if sycl_queue is None :
476
+ exec_q = seq_dev
477
+ alloc_q = seq_dev
478
+ else :
479
+ exec_q = dpctl .utils .get_execution_queue (
480
+ (
481
+ sycl_queue ,
482
+ seq_dev ,
483
+ )
484
+ )
485
+ alloc_q = sycl_queue
486
+ return _asarray_from_seq (
487
+ obj ,
488
+ seq_shape ,
489
+ seq_dt ,
490
+ alloc_q ,
491
+ exec_q ,
492
+ dtype = dtype ,
493
+ usm_type = usm_type ,
494
+ order = order ,
495
+ )
496
+
497
+
440
498
def asarray (
441
499
obj ,
442
500
dtype = None ,
@@ -576,16 +634,42 @@ def asarray(
576
634
order = order ,
577
635
)
578
636
elif len (devs ) == 1 :
579
- return _asarray_from_seq (
637
+ seq_dev = list (devs )[0 ]
638
+ return _asarray_from_seq_single_device (
580
639
obj ,
581
640
seq_shape ,
582
641
seq_dt ,
583
- list ( devs )[ 0 ] ,
642
+ seq_dev ,
584
643
dtype = dtype ,
585
644
usm_type = usm_type ,
586
645
sycl_queue = sycl_queue ,
587
646
order = order ,
588
647
)
648
+ elif len (devs ) > 1 :
649
+ devs = [dev for dev in devs if dev is not None ]
650
+ if sycl_queue is None :
651
+ if len (devs ) == 1 :
652
+ alloc_q = devs [0 ]
653
+ else :
654
+ raise dpctl .utils .ExecutionPlacementError (
655
+ "Please specify `device` or `sycl_queue` keyword "
656
+ "argument to determine where to allocate the "
657
+ "resulting array."
658
+ )
659
+ else :
660
+ alloc_q = sycl_queue
661
+ return _asarray_from_seq (
662
+ obj ,
663
+ seq_shape ,
664
+ seq_dt ,
665
+ alloc_q ,
666
+ # force copying via host
667
+ None ,
668
+ dtype = dtype ,
669
+ usm_type = usm_type ,
670
+ order = order ,
671
+ )
672
+
589
673
raise NotImplementedError (
590
674
"Converting Python sequences is not implemented"
591
675
)
0 commit comments