Skip to content

Commit 3352c3d

Browse files
Merge pull request #1139 from IntelPython/fix-gh-1134-asarray
2 parents df25aa0 + d294fb6 commit 3352c3d

File tree

5 files changed

+276
-61
lines changed

5 files changed

+276
-61
lines changed

dpctl/memory/_memory.pyx

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -679,17 +679,13 @@ cdef class _Memory:
679679
cdef class MemoryUSMShared(_Memory):
680680
"""
681681
MemoryUSMShared(nbytes, alignment=0, queue=None, copy=False)
682-
allocates nbytes of USM shared memory.
683682
684-
Non-positive alignments are not used (malloc_shared is used instead).
685-
For the queue=None case the ``dpctl.SyclQueue()`` is used to allocate
686-
memory.
683+
An object representing allocation of SYCL USM-shared memory.
687684
688-
MemoryUSMShared(usm_obj) constructor creates instance from `usm_obj`
689-
expected to implement `__sycl_usm_array_interface__` protocol and to expose
690-
a contiguous block of USM shared allocation. Use `copy=True` to
691-
perform a copy if USM type of the allocation represented by the argument
692-
is other than 'shared'.
685+
Non-positive ``alignment`` values are not ignored and
686+
the allocator ``malloc_shared`` is used for allocation instead.
687+
If ``queue`` is ``None`` a cached default-constructed
688+
:class:`dpctl.SyclQueue` is used to allocate memory.
693689
"""
694690
def __cinit__(self, other, *, Py_ssize_t alignment=0,
695691
SyclQueue queue=None, int copy=False):
@@ -720,17 +716,13 @@ cdef class MemoryUSMShared(_Memory):
720716
cdef class MemoryUSMHost(_Memory):
721717
"""
722718
MemoryUSMHost(nbytes, alignment=0, queue=None, copy=False)
723-
allocates nbytes of USM host memory.
724719
725-
Non-positive alignments are not used (malloc_host is used instead).
726-
For the queue=None case the ``dpctl.SyclQueue()`` is used to allocate
727-
memory.
720+
An object representing allocation of SYCL USM-host memory.
728721
729-
MemoryUSMDevice(usm_obj) constructor create instance from `usm_obj`
730-
expected to implement `__sycl_usm_array_interface__` protocol and to expose
731-
a contiguous block of USM host allocation. Use `copy=True` to
732-
perform a copy if USM type of the allocation represented by the argument
733-
is other than 'host'.
722+
Non-positive ``alignment`` values are not ignored and
723+
the allocator ``malloc_host`` is used for allocation instead.
724+
If ``queue`` is ``None`` a cached default-constructed
725+
:class:`dpctl.SyclQueue` is used to allocate memory.
734726
"""
735727
def __cinit__(self, other, *, Py_ssize_t alignment=0,
736728
SyclQueue queue=None, int copy=False):
@@ -762,17 +754,13 @@ cdef class MemoryUSMHost(_Memory):
762754
cdef class MemoryUSMDevice(_Memory):
763755
"""
764756
MemoryUSMDevice(nbytes, alignment=0, queue=None, copy=False)
765-
allocates nbytes of USM device memory.
766757
767-
Non-positive alignments are not used (malloc_device is used instead).
768-
For the queue=None case the ``dpctl.SyclQueue()`` is used to allocate
769-
memory.
758+
An object representing allocation of SYCL USM-device memory.
770759
771-
MemoryUSMDevice(usm_obj) constructor create instance from `usm_obj`
772-
expected to implement `__sycl_usm_array_interface__` protocol and exposing
773-
a contiguous block of USM device allocation. Use `copy=True` to
774-
perform a copy if USM type of the allocation represented by the argument
775-
is other than 'device'.
760+
Non-positive ``alignment`` values are not ignored and
761+
the allocator ``malloc_device`` is used for allocation instead.
762+
If ``queue`` is ``None`` a cached default-constructed
763+
:class:`dpctl.SyclQueue` is used to allocate memory.
776764
"""
777765
def __cinit__(self, other, *, Py_ssize_t alignment=0,
778766
SyclQueue queue=None, int copy=False):

dpctl/memory/_sycl_usm_array_interface_utils.pxi

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,27 +88,38 @@ cdef object _pointers_from_shape_and_stride(
8888
8989
Returns: tuple(min_disp, nbytes)
9090
"""
91+
cdef Py_ssize_t nelems = 1
92+
cdef Py_ssize_t min_disp = 0
93+
cdef Py_ssize_t max_disp = 0
94+
cdef int i
95+
cdef Py_ssize_t sh_i = 0
96+
cdef Py_ssize_t str_i = 0
9197
if (nd > 0):
9298
if (ary_strides is None):
9399
nelems = 1
94100
for si in ary_shape:
95101
sh_i = int(si)
96-
if (sh_i <= 0):
102+
if (sh_i < 0):
97103
raise ValueError("Array shape elements need to be positive")
98104
nelems = nelems * sh_i
99-
return (ary_offset, nelems * itemsize)
105+
return (ary_offset, max(nelems, 1) * itemsize)
100106
else:
101107
min_disp = ary_offset
102108
max_disp = ary_offset
103109
for i in range(nd):
104110
str_i = int(ary_strides[i])
105111
sh_i = int(ary_shape[i])
106-
if (sh_i <= 0):
112+
if (sh_i < 0):
107113
raise ValueError("Array shape elements need to be positive")
108-
if (str_i > 0):
109-
max_disp += str_i * (sh_i - 1)
114+
if (sh_i > 0):
115+
if (str_i > 0):
116+
max_disp += str_i * (sh_i - 1)
117+
else:
118+
min_disp += str_i * (sh_i - 1)
110119
else:
111-
min_disp += str_i * (sh_i - 1);
120+
nelems = 0
121+
if nelems == 0:
122+
return (ary_offset, itemsize)
112123
return (min_disp, (max_disp - min_disp + 1) * itemsize)
113124
elif (nd == 0):
114125
return (ary_offset, itemsize)

dpctl/tensor/_ctors.py

Lines changed: 152 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import dpctl.tensor._tensor_impl as ti
2525
import dpctl.utils
2626
from dpctl.tensor._device import normalize_queue_device
27+
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol
2728

2829
__doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"
2930

@@ -66,11 +67,12 @@ def _array_info_dispatch(obj):
6667
return _empty_tuple, complex, _host_set
6768
if isinstance(obj, (list, tuple, range)):
6869
return _array_info_sequence(obj)
69-
if any(
70-
isinstance(obj, s)
71-
for s in [np.integer, np.floating, np.complexfloating, np.bool_]
72-
):
73-
return _empty_tuple, obj.dtype, _host_set
70+
if _is_object_with_buffer_protocol(obj):
71+
np_obj = np.array(obj)
72+
return np_obj.shape, np_obj.dtype, _host_set
73+
if hasattr(obj, "__sycl_usm_array_interface__"):
74+
usm_ar = _usm_ndarray_from_suai(obj)
75+
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
7476
raise ValueError(type(obj))
7577

7678

@@ -219,6 +221,18 @@ def _map_to_device_dtype(dt, q):
219221
raise RuntimeError(f"Unrecognized data type '{dt}' encountered.")
220222

221223

224+
def _usm_ndarray_from_suai(obj):
225+
sua_iface = getattr(obj, "__sycl_usm_array_interface__")
226+
membuf = dpm.as_usm_memory(obj)
227+
ary = dpt.usm_ndarray(
228+
sua_iface["shape"],
229+
dtype=sua_iface["typestr"],
230+
buffer=membuf,
231+
strides=sua_iface.get("strides", None),
232+
)
233+
return ary
234+
235+
222236
def _asarray_from_numpy_ndarray(
223237
ary, dtype=None, usm_type=None, sycl_queue=None, order="K"
224238
):
@@ -276,17 +290,6 @@ def _asarray_from_numpy_ndarray(
276290
return res
277291

278292

279-
def _is_object_with_buffer_protocol(obj):
280-
"Returns `True` if object support Python buffer protocol"
281-
try:
282-
# use context manager to ensure
283-
# buffer is instantly released
284-
with memoryview(obj):
285-
return True
286-
except TypeError:
287-
return False
288-
289-
290293
def _ensure_native_dtype_device_support(dtype, dev) -> None:
291294
"""Check that dtype is natively supported by device.
292295
@@ -318,6 +321,122 @@ def _ensure_native_dtype_device_support(dtype, dev) -> None:
318321
)
319322

320323

324+
def _usm_types_walker(o, usm_types_list):
325+
if isinstance(o, dpt.usm_ndarray):
326+
usm_types_list.append(o.usm_type)
327+
return
328+
if hasattr(o, "__sycl_usm_array_interface__"):
329+
usm_ar = _usm_ndarray_from_suai(o)
330+
usm_types_list.append(usm_ar.usm_type)
331+
return
332+
if isinstance(o, (list, tuple)):
333+
for el in o:
334+
_usm_types_walker(el, usm_types_list)
335+
return
336+
raise TypeError
337+
338+
339+
def _device_copy_walker(seq_o, res, events):
340+
if isinstance(seq_o, dpt.usm_ndarray):
341+
exec_q = res.sycl_queue
342+
ht_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
343+
src=seq_o, dst=res, sycl_queue=exec_q
344+
)
345+
events.append(ht_ev)
346+
return
347+
if hasattr(seq_o, "__sycl_usm_array_interface__"):
348+
usm_ar = _usm_ndarray_from_suai(seq_o)
349+
exec_q = res.sycl_queue
350+
ht_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
351+
src=usm_ar, dst=res, sycl_queue=exec_q
352+
)
353+
events.append(ht_ev)
354+
return
355+
if isinstance(seq_o, (list, tuple)):
356+
for i, el in enumerate(seq_o):
357+
_device_copy_walker(el, res[i], events)
358+
return
359+
raise TypeError
360+
361+
362+
def _copy_through_host_walker(seq_o, usm_res):
363+
if isinstance(seq_o, dpt.usm_ndarray):
364+
usm_res[...] = dpt.asnumpy(seq_o).copy()
365+
return
366+
if hasattr(seq_o, "__sycl_usm_array_interface__"):
367+
usm_ar = _usm_ndarray_from_suai(seq_o)
368+
usm_res[...] = dpt.asnumpy(usm_ar).copy()
369+
return
370+
if isinstance(seq_o, (list, tuple)):
371+
for i, el in enumerate(seq_o):
372+
_copy_through_host_walker(el, usm_res[i])
373+
return
374+
usm_res[...] = np.asarray(seq_o)
375+
376+
377+
def _asarray_from_seq(
378+
seq_obj,
379+
seq_shape,
380+
seq_dt,
381+
seq_dev,
382+
dtype=None,
383+
usm_type=None,
384+
sycl_queue=None,
385+
order="C",
386+
):
387+
"`obj` is a sequence"
388+
if usm_type is None:
389+
usm_types_in_seq = []
390+
_usm_types_walker(seq_obj, usm_types_in_seq)
391+
usm_type = dpctl.utils.get_coerced_usm_type(usm_types_in_seq)
392+
dpctl.utils.validate_usm_type(usm_type)
393+
if sycl_queue is None:
394+
exec_q = seq_dev
395+
alloc_q = seq_dev
396+
else:
397+
exec_q = dpctl.utils.get_execution_queue(
398+
(
399+
sycl_queue,
400+
seq_dev,
401+
)
402+
)
403+
alloc_q = sycl_queue
404+
if dtype is None:
405+
dtype = _map_to_device_dtype(seq_dt, alloc_q)
406+
else:
407+
_mapped_dt = _map_to_device_dtype(dtype, alloc_q)
408+
if _mapped_dt != dtype:
409+
raise ValueError(
410+
f"Device {sycl_queue.sycl_device} "
411+
f"does not support {dtype} natively."
412+
)
413+
dtype = _mapped_dt
414+
if order in "KA":
415+
order = "C"
416+
if isinstance(exec_q, dpctl.SyclQueue):
417+
res = dpt.empty(
418+
seq_shape,
419+
dtype=dtype,
420+
usm_type=usm_type,
421+
sycl_queue=alloc_q,
422+
order=order,
423+
)
424+
ht_events = []
425+
_device_copy_walker(seq_obj, res, ht_events)
426+
dpctl.SyclEvent.wait_for(ht_events)
427+
return res
428+
else:
429+
res = dpt.empty(
430+
seq_shape,
431+
dtype=dtype,
432+
usm_type=usm_type,
433+
sycl_queue=alloc_q,
434+
order=order,
435+
)
436+
_copy_through_host_walker(seq_obj, res)
437+
return res
438+
439+
321440
def asarray(
322441
obj,
323442
dtype=None,
@@ -327,7 +446,9 @@ def asarray(
327446
sycl_queue=None,
328447
order="K",
329448
):
330-
"""
449+
""" asarray(obj, dtype=None, copy=None, device=None, \
450+
usm_type=None, sycl_queue=None, order="K")
451+
331452
Converts `obj` to :class:`dpctl.tensor.usm_ndarray`.
332453
333454
Args:
@@ -347,7 +468,7 @@ def asarray(
347468
allocations if possible, but allowed to perform a copy otherwise.
348469
Default: `None`.
349470
order ("C","F","A","K", optional): memory layout of the output array.
350-
Default: "C"
471+
Default: "K"
351472
device (optional): array API concept of device where the output array
352473
is created. `device` can be `None`, a oneAPI filter selector string,
353474
an instance of :class:`dpctl.SyclDevice` corresponding to a
@@ -407,14 +528,7 @@ def asarray(
407528
order=order,
408529
)
409530
if hasattr(obj, "__sycl_usm_array_interface__"):
410-
sua_iface = getattr(obj, "__sycl_usm_array_interface__")
411-
membuf = dpm.as_usm_memory(obj)
412-
ary = dpt.usm_ndarray(
413-
sua_iface["shape"],
414-
dtype=sua_iface["typestr"],
415-
buffer=membuf,
416-
strides=sua_iface.get("strides", None),
417-
)
531+
ary = _usm_ndarray_from_suai(obj)
418532
return _asarray_from_usm_ndarray(
419533
ary,
420534
dtype=dtype,
@@ -452,7 +566,7 @@ def asarray(
452566
raise ValueError(
453567
"Converting Python sequence to usm_ndarray requires a copy"
454568
)
455-
_, _, devs = _array_info_sequence(obj)
569+
seq_shape, seq_dt, devs = _array_info_sequence(obj)
456570
if devs == _host_set:
457571
return _asarray_from_numpy_ndarray(
458572
np.asarray(obj, dtype=dtype, order=order),
@@ -461,7 +575,17 @@ def asarray(
461575
sycl_queue=sycl_queue,
462576
order=order,
463577
)
464-
# for sequences
578+
elif len(devs) == 1:
579+
return _asarray_from_seq(
580+
obj,
581+
seq_shape,
582+
seq_dt,
583+
list(devs)[0],
584+
dtype=dtype,
585+
usm_type=usm_type,
586+
sycl_queue=sycl_queue,
587+
order=order,
588+
)
465589
raise NotImplementedError(
466590
"Converting Python sequences is not implemented"
467591
)

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,3 +1629,8 @@ cdef api object UsmNDArray_MakeFromPtr(
16291629
offset=offset
16301630
)
16311631
return arr
1632+
1633+
1634+
def _is_object_with_buffer_protocol(o):
1635+
"Returns True if object support Python buffer protocol"
1636+
return _is_buffer(o)

0 commit comments

Comments
 (0)