Skip to content

Commit 2a3558c

Browse files
Merge pull request #1151 from IntelPython/as-type-supports-sequence-of-arrays-from-multiple-devices
As type supports sequence of arrays from multiple devices
2 parents 09acd20 + ce5425b commit 2a3558c

File tree

2 files changed

+139
-21
lines changed

2 files changed

+139
-21
lines changed

dpctl/tensor/_ctors.py

Lines changed: 105 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ def _array_info_dispatch(obj):
6565
return _empty_tuple, int, _host_set
6666
if isinstance(obj, complex):
6767
return _empty_tuple, complex, _host_set
68-
if isinstance(obj, (list, tuple, range)):
68+
if isinstance(
69+
obj,
70+
(
71+
list,
72+
tuple,
73+
),
74+
):
6975
return _array_info_sequence(obj)
7076
if _is_object_with_buffer_protocol(obj):
7177
np_obj = np.array(obj)
@@ -329,7 +335,11 @@ def _usm_types_walker(o, usm_types_list):
329335
usm_ar = _usm_ndarray_from_suai(o)
330336
usm_types_list.append(usm_ar.usm_type)
331337
return
332-
if isinstance(o, (list, tuple)):
338+
if _is_object_with_buffer_protocol(o):
339+
return
340+
if isinstance(o, (int, bool, float, complex)):
341+
return
342+
if isinstance(o, (list, tuple, range)):
333343
for el in o:
334344
_usm_types_walker(el, usm_types_list)
335345
return
@@ -361,11 +371,37 @@ def _device_copy_walker(seq_o, res, events):
361371

362372
def _copy_through_host_walker(seq_o, usm_res):
363373
if isinstance(seq_o, dpt.usm_ndarray):
364-
usm_res[...] = dpt.asnumpy(seq_o).copy()
365-
return
374+
if (
375+
dpctl.utils.get_execution_queue(
376+
(
377+
usm_res.sycl_queue,
378+
seq_o.sycl_queue,
379+
)
380+
)
381+
is None
382+
):
383+
usm_res[...] = dpt.asnumpy(seq_o).copy()
384+
return
385+
else:
386+
usm_res[...] = seq_o
366387
if hasattr(seq_o, "__sycl_usm_array_interface__"):
367388
usm_ar = _usm_ndarray_from_suai(seq_o)
368-
usm_res[...] = dpt.asnumpy(usm_ar).copy()
389+
if (
390+
dpctl.utils.get_execution_queue(
391+
(
392+
usm_res.sycl_queue,
393+
usm_ar.sycl_queue,
394+
)
395+
)
396+
is None
397+
):
398+
usm_res[...] = dpt.asnumpy(usm_ar).copy()
399+
else:
400+
usm_res[...] = usm_ar
401+
return
402+
if _is_object_with_buffer_protocol(seq_o):
403+
np_ar = np.asarray(seq_o)
404+
usm_res[...] = np_ar
369405
return
370406
if isinstance(seq_o, (list, tuple)):
371407
for i, el in enumerate(seq_o):
@@ -378,10 +414,10 @@ def _asarray_from_seq(
378414
seq_obj,
379415
seq_shape,
380416
seq_dt,
381-
seq_dev,
417+
alloc_q,
418+
exec_q,
382419
dtype=None,
383420
usm_type=None,
384-
sycl_queue=None,
385421
order="C",
386422
):
387423
"`obj` is a sequence"
@@ -390,24 +426,13 @@ def _asarray_from_seq(
390426
_usm_types_walker(seq_obj, usm_types_in_seq)
391427
usm_type = dpctl.utils.get_coerced_usm_type(usm_types_in_seq)
392428
dpctl.utils.validate_usm_type(usm_type)
393-
if sycl_queue is None:
394-
exec_q = seq_dev
395-
alloc_q = seq_dev
396-
else:
397-
exec_q = dpctl.utils.get_execution_queue(
398-
(
399-
sycl_queue,
400-
seq_dev,
401-
)
402-
)
403-
alloc_q = sycl_queue
404429
if dtype is None:
405430
dtype = _map_to_device_dtype(seq_dt, alloc_q)
406431
else:
407432
_mapped_dt = _map_to_device_dtype(dtype, alloc_q)
408433
if _mapped_dt != dtype:
409434
raise ValueError(
410-
f"Device {sycl_queue.sycl_device} "
435+
f"Device {alloc_q.sycl_device} "
411436
f"does not support {dtype} natively."
412437
)
413438
dtype = _mapped_dt
@@ -437,6 +462,39 @@ def _asarray_from_seq(
437462
return res
438463

439464

465+
def _asarray_from_seq_single_device(
466+
obj,
467+
seq_shape,
468+
seq_dt,
469+
seq_dev,
470+
dtype=None,
471+
usm_type=None,
472+
sycl_queue=None,
473+
order="C",
474+
):
475+
if sycl_queue is None:
476+
exec_q = seq_dev
477+
alloc_q = seq_dev
478+
else:
479+
exec_q = dpctl.utils.get_execution_queue(
480+
(
481+
sycl_queue,
482+
seq_dev,
483+
)
484+
)
485+
alloc_q = sycl_queue
486+
return _asarray_from_seq(
487+
obj,
488+
seq_shape,
489+
seq_dt,
490+
alloc_q,
491+
exec_q,
492+
dtype=dtype,
493+
usm_type=usm_type,
494+
order=order,
495+
)
496+
497+
440498
def asarray(
441499
obj,
442500
dtype=None,
@@ -576,16 +634,42 @@ def asarray(
576634
order=order,
577635
)
578636
elif len(devs) == 1:
579-
return _asarray_from_seq(
637+
seq_dev = list(devs)[0]
638+
return _asarray_from_seq_single_device(
580639
obj,
581640
seq_shape,
582641
seq_dt,
583-
list(devs)[0],
642+
seq_dev,
584643
dtype=dtype,
585644
usm_type=usm_type,
586645
sycl_queue=sycl_queue,
587646
order=order,
588647
)
648+
elif len(devs) > 1:
649+
devs = [dev for dev in devs if dev is not None]
650+
if sycl_queue is None:
651+
if len(devs) == 1:
652+
alloc_q = devs[0]
653+
else:
654+
raise dpctl.utils.ExecutionPlacementError(
655+
"Please specify `device` or `sycl_queue` keyword "
656+
"argument to determine where to allocate the "
657+
"resulting array."
658+
)
659+
else:
660+
alloc_q = sycl_queue
661+
return _asarray_from_seq(
662+
obj,
663+
seq_shape,
664+
seq_dt,
665+
alloc_q,
666+
# force copying via host
667+
None,
668+
dtype=dtype,
669+
usm_type=usm_type,
670+
order=order,
671+
)
672+
589673
raise NotImplementedError(
590674
"Converting Python sequences is not implemented"
591675
)

dpctl/tests/test_tensor_asarray.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,37 @@ def shape(self):
321321
x = dpt.asarray([d, d], sycl_queue=q)
322322
assert x.sycl_queue == q
323323
assert x.shape == (2,) + d.shape
324+
325+
326+
def test_asarray_seq_of_arrays_on_different_queues():
327+
q = get_queue_or_skip()
328+
329+
m = dpt.empty((2, 4), dtype="i2", sycl_queue=q)
330+
q2 = dpctl.SyclQueue()
331+
w = dpt.empty(4, dtype="i1", sycl_queue=q2)
332+
q3 = dpctl.SyclQueue()
333+
py_seq = [
334+
0,
335+
] * w.shape[0]
336+
res = dpt.asarray([m, [w, py_seq]], sycl_queue=q3)
337+
assert res.sycl_queue == q3
338+
assert dpt.isdtype(res.dtype, "integral")
339+
340+
res = dpt.asarray([m, [w, range(w.shape[0])]], sycl_queue=q3)
341+
assert res.sycl_queue == q3
342+
assert dpt.isdtype(res.dtype, "integral")
343+
344+
res = dpt.asarray([m, [w, w]], sycl_queue=q)
345+
assert res.sycl_queue == q
346+
assert dpt.isdtype(res.dtype, "integral")
347+
348+
res = dpt.asarray([m, [w, dpt.asnumpy(w)]], sycl_queue=q2)
349+
assert res.sycl_queue == q2
350+
assert dpt.isdtype(res.dtype, "integral")
351+
352+
res = dpt.asarray([w, dpt.asnumpy(w)])
353+
assert res.sycl_queue == w.sycl_queue
354+
assert dpt.isdtype(res.dtype, "integral")
355+
356+
with pytest.raises(dpctl.utils.ExecutionPlacementError):
357+
dpt.asarray([m, [w, py_seq]])

0 commit comments

Comments
 (0)