Skip to content

Commit 750fc8b

Browse files
Added support for asarray on sequences with arrays from incop. queues
``` In [1]: import dpctl, dpctl.tensor as dpt In [2]: m = dpt.zeros((2, 4), dtype="i2", device="cpu") In [3]: w = dpt.full(4, -1, device="opencl:gpu") In [4]: res = dpt.asarray([m, [w, [0,] * 4 ]], dtype="f4", device="cpu") In [5]: res Out[5]: usm_ndarray([[[ 0., 0., 0., 0.], [ 0., 0., 0., 0.]], [[-1., -1., -1., -1.], [ 0., 0., 0., 0.]]], dtype=float32) In [6]: res.device Out[6]: Device(opencl:cpu:0) ``` This usage requires user to specify sycl_queue or device keyword to indicate where the result is created, or ExecitionPlacementError is raised.
1 parent d945d95 commit 750fc8b

File tree

1 file changed

+106
-19
lines changed

1 file changed

+106
-19
lines changed

dpctl/tensor/_ctors.py

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ def _usm_types_walker(o, usm_types_list):
329329
usm_ar = _usm_ndarray_from_suai(o)
330330
usm_types_list.append(usm_ar.usm_type)
331331
return
332+
if _is_object_with_buffer_protocol(o):
333+
return
334+
if isinstance(o, (int, bool, float, complex)):
335+
return
332336
if isinstance(o, (list, tuple)):
333337
for el in o:
334338
_usm_types_walker(el, usm_types_list)
@@ -361,11 +365,37 @@ def _device_copy_walker(seq_o, res, events):
361365

362366
def _copy_through_host_walker(seq_o, usm_res):
363367
if isinstance(seq_o, dpt.usm_ndarray):
364-
usm_res[...] = dpt.asnumpy(seq_o).copy()
365-
return
368+
if (
369+
dpctl.utils.get_execution_queue(
370+
(
371+
usm_res.sycl_queue,
372+
seq_o.sycl_queue,
373+
)
374+
)
375+
is None
376+
):
377+
usm_res[...] = dpt.asnumpy(seq_o).copy()
378+
return
379+
else:
380+
usm_res[...] = seq_o
366381
if hasattr(seq_o, "__sycl_usm_array_interface__"):
367382
usm_ar = _usm_ndarray_from_suai(seq_o)
368-
usm_res[...] = dpt.asnumpy(usm_ar).copy()
383+
if (
384+
dpctl.utils.get_execution_queue(
385+
(
386+
usm_res.sycl_queue,
387+
usm_ar.sycl_queue,
388+
)
389+
)
390+
is None
391+
):
392+
usm_res[...] = dpt.asnumpy(usm_ar).copy()
393+
else:
394+
usm_res[...] = usm_ar
395+
return
396+
if _is_object_with_buffer_protocol(seq_o):
397+
np_ar = np.asarray(seq_o)
398+
usm_res[...] = np_ar
369399
return
370400
if isinstance(seq_o, (list, tuple)):
371401
for i, el in enumerate(seq_o):
@@ -378,10 +408,10 @@ def _asarray_from_seq(
378408
seq_obj,
379409
seq_shape,
380410
seq_dt,
381-
seq_dev,
411+
alloc_q,
412+
exec_q,
382413
dtype=None,
383414
usm_type=None,
384-
sycl_queue=None,
385415
order="C",
386416
):
387417
"`obj` is a sequence"
@@ -390,24 +420,13 @@ def _asarray_from_seq(
390420
_usm_types_walker(seq_obj, usm_types_in_seq)
391421
usm_type = dpctl.utils.get_coerced_usm_type(usm_types_in_seq)
392422
dpctl.utils.validate_usm_type(usm_type)
393-
if sycl_queue is None:
394-
exec_q = seq_dev
395-
alloc_q = seq_dev
396-
else:
397-
exec_q = dpctl.utils.get_execution_queue(
398-
(
399-
sycl_queue,
400-
seq_dev,
401-
)
402-
)
403-
alloc_q = sycl_queue
404423
if dtype is None:
405424
dtype = _map_to_device_dtype(seq_dt, alloc_q)
406425
else:
407426
_mapped_dt = _map_to_device_dtype(dtype, alloc_q)
408427
if _mapped_dt != dtype:
409428
raise ValueError(
410-
f"Device {sycl_queue.sycl_device} "
429+
f"Device {alloc_q.sycl_device} "
411430
f"does not support {dtype} natively."
412431
)
413432
dtype = _mapped_dt
@@ -437,6 +456,39 @@ def _asarray_from_seq(
437456
return res
438457

439458

459+
def _asarray_from_seq_single_device(
460+
obj,
461+
seq_shape,
462+
seq_dt,
463+
seq_dev,
464+
dtype=None,
465+
usm_type=None,
466+
sycl_queue=None,
467+
order="C",
468+
):
469+
if sycl_queue is None:
470+
exec_q = seq_dev
471+
alloc_q = seq_dev
472+
else:
473+
exec_q = dpctl.utils.get_execution_queue(
474+
(
475+
sycl_queue,
476+
seq_dev,
477+
)
478+
)
479+
alloc_q = sycl_queue
480+
return _asarray_from_seq(
481+
obj,
482+
seq_shape,
483+
seq_dt,
484+
alloc_q,
485+
exec_q,
486+
dtype=dtype,
487+
usm_type=usm_type,
488+
order=order,
489+
)
490+
491+
440492
def asarray(
441493
obj,
442494
dtype=None,
@@ -576,16 +628,51 @@ def asarray(
576628
order=order,
577629
)
578630
elif len(devs) == 1:
579-
return _asarray_from_seq(
631+
seq_dev = list(devs)[0]
632+
return _asarray_from_seq_single_device(
580633
obj,
581634
seq_shape,
582635
seq_dt,
583-
list(devs)[0],
636+
seq_dev,
584637
dtype=dtype,
585638
usm_type=usm_type,
586639
sycl_queue=sycl_queue,
587640
order=order,
588641
)
642+
elif len(devs) > 1:
643+
devs = [dev for dev in devs if dev is not None]
644+
if len(devs) == 1:
645+
seq_dev = devs[0]
646+
return _asarray_from_seq_single_device(
647+
obj,
648+
seq_shape,
649+
seq_dt,
650+
seq_dev,
651+
dtype=dtype,
652+
usm_type=usm_type,
653+
sycl_queue=sycl_queue,
654+
order=order,
655+
)
656+
else:
657+
if sycl_queue is None:
658+
raise dpctl.utils.ExecutionPlacementError(
659+
"Please specify `device` or `sycl_queue` keyword "
660+
"argument to determine where to allocated the "
661+
"resulting array."
662+
)
663+
alloc_q = sycl_queue
664+
exec_q = None # force copying via host
665+
return _asarray_from_seq(
666+
obj,
667+
seq_shape,
668+
seq_dt,
669+
alloc_q,
670+
exec_q,
671+
dtype=dtype,
672+
usm_type=usm_type,
673+
order=order,
674+
)
675+
589676
raise NotImplementedError(
590677
"Converting Python sequences is not implemented"
591678
)

0 commit comments

Comments
 (0)