Skip to content

Commit f6b8ed5

Browse files
committed
Copy from/to multidimensional buffers
Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent d84cb16 commit f6b8ed5

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

dpctl/_sycl_queue.pyx

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,15 @@ import ctypes
6565
from .enum_types import backend_type
6666

6767
from cpython cimport pycapsule
68-
from cpython.buffer cimport PyObject_CheckBuffer
68+
from cpython.buffer cimport (
69+
PyObject_CheckBuffer,
70+
Py_buffer,
71+
PyObject_GetBuffer,
72+
PyBUF_SIMPLE,
73+
PyBUF_ANY_CONTIGUOUS,
74+
PyBUF_WRITABLE,
75+
PyBuffer_Release
76+
)
6977
from cpython.ref cimport Py_DECREF, Py_INCREF, PyObject
7078
from libc.stdlib cimport free, malloc
7179

@@ -338,14 +346,20 @@ cdef DPCTLSyclEventRef _memcpy_impl(
338346
cdef void *c_dst_ptr = NULL
339347
cdef void *c_src_ptr = NULL
340348
cdef DPCTLSyclEventRef ERef = NULL
341-
cdef const unsigned char[::1] src_host_buf = None
342-
cdef unsigned char[::1] dst_host_buf = None
349+
cdef Py_buffer src_buf_view
350+
cdef Py_buffer dst_buf_view
351+
cdef bint src_is_buf = False
352+
cdef bint dst_is_buf = False
353+
cdef int ret_code = 0
343354

344355
if isinstance(src, _Memory):
345356
c_src_ptr = <void*>(<_Memory>src).get_data_ptr()
346357
elif _is_buffer(src):
347-
src_host_buf = src
348-
c_src_ptr = <void *>&src_host_buf[0]
358+
ret_code = PyObject_GetBuffer(src, &src_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS)
359+
if ret_code != 0:
360+
raise RuntimeError("Could not access buffer")
361+
c_src_ptr = src_buf_view.buf
362+
src_is_buf = True
349363
else:
350364
raise TypeError(
351365
"Parameter `src` should have either type "
@@ -356,8 +370,11 @@ cdef DPCTLSyclEventRef _memcpy_impl(
356370
if isinstance(dst, _Memory):
357371
c_dst_ptr = <void*>(<_Memory>dst).get_data_ptr()
358372
elif _is_buffer(dst):
359-
dst_host_buf = dst
360-
c_dst_ptr = <void *>&dst_host_buf[0]
373+
ret_code = PyObject_GetBuffer(dst, &dst_buf_view, PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS | PyBUF_WRITABLE)
374+
if ret_code != 0:
375+
raise RuntimeError("Could not access buffer")
376+
c_dst_ptr = dst_buf_view.buf
377+
dst_is_buf = True
361378
else:
362379
raise TypeError(
363380
"Parameter `dst` should have either type "
@@ -376,6 +393,12 @@ cdef DPCTLSyclEventRef _memcpy_impl(
376393
dep_events,
377394
dep_events_count
378395
)
396+
397+
if src_is_buf:
398+
PyBuffer_Release(&src_buf_view)
399+
if dst_is_buf:
400+
PyBuffer_Release(&dst_buf_view)
401+
379402
return ERef
380403

381404

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import dpctl
2323
import dpctl.memory
2424

25+
import numpy as np
26+
2527

2628
def _create_memory(q):
2729
nbytes = 1024
@@ -97,6 +99,44 @@ def test_memcpy_copy_host_to_host():
9799
assert dst_buf == src_buf
98100

99101

102+
def test_2D_memcpy_copy_host_to_usm():
103+
try:
104+
q = dpctl.SyclQueue()
105+
except dpctl.SyclQueueCreationError:
106+
pytest.skip("Default constructor for SyclQueue failed")
107+
usm_obj = _create_memory(q)
108+
109+
n = 12
110+
canary = bytearray([i for i in range(n)])
111+
host_obj = np.frombuffer(canary, dtype=np.uint8).reshape(3, 4)
112+
113+
q.memcpy(usm_obj, host_obj, len(canary))
114+
115+
mv2 = memoryview(usm_obj)
116+
117+
assert mv2[: len(canary)] == canary
118+
119+
120+
def test_2D_memcpy_copy_usm_to_host():
121+
try:
122+
q = dpctl.SyclQueue()
123+
except dpctl.SyclQueueCreationError:
124+
pytest.skip("Default constructor for SyclQueue failed")
125+
usm_obj = _create_memory(q)
126+
mv2 = memoryview(usm_obj)
127+
128+
n = 12
129+
shape = (3, 4)
130+
for id in range(n):
131+
mv2[id] = id
132+
133+
host_obj = np.ones(shape, dtype=np.uint8)
134+
135+
q.memcpy(host_obj, usm_obj, n)
136+
137+
assert np.array_equal(host_obj, np.arange(n, dtype=np.uint8).reshape(shape))
138+
139+
100140
def test_memcpy_async():
101141
try:
102142
q = dpctl.SyclQueue()

0 commit comments

Comments
 (0)