Skip to content

Accept NumPy arrays in advanced indexing #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 67 additions & 60 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
if not isinstance(ary_mask, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
if isinstance(ary_mask, dpt.usm_ndarray):
dst_usm_type = dpctl.utils.get_coerced_usm_type(
(ary.usm_type, ary_mask.usm_type)
)
dst_usm_type = dpctl.utils.get_coerced_usm_type(
(ary.usm_type, ary_mask.usm_type)
)
exec_q = dpctl.utils.get_execution_queue(
(ary.sycl_queue, ary_mask.sycl_queue)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
exec_q = dpctl.utils.get_execution_queue(
(ary.sycl_queue, ary_mask.sycl_queue)
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
)
elif isinstance(ary_mask, np.ndarray):
dst_usm_type = ary.usm_type
exec_q = ary.sycl_queue
ary_mask = dpt.asarray(
ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q
)
else:
raise TypeError(
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
f"{type(ary_mask)}"
)
ary_nd = ary.ndim
pp = normalize_axis_index(operator.index(axis), ary_nd)
Expand Down Expand Up @@ -837,35 +845,40 @@ def _nonzero_impl(ary):
return res


def _validate_indices(inds, queue_list, usm_type_list):
def _get_indices_queue_usm_type(inds, queue, usm_type):
"""
Utility for validating indices are usm_ndarray of integral dtype or Python
integers. At least one must be an array.
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
dtype or Python integers. At least one must be an array.

For each array, the queue and usm type are appended to `queue_list` and
`usm_type_list`, respectively.
"""
any_usmarray = False
queues = [queue]
usm_types = [usm_type]
any_array = False
for ind in inds:
if isinstance(ind, dpt.usm_ndarray):
any_usmarray = True
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
any_array = True
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) "
"type"
)
queue_list.append(ind.sycl_queue)
usm_type_list.append(ind.usm_type)
if isinstance(ind, dpt.usm_ndarray):
queues.append(ind.sycl_queue)
usm_types.append(ind.usm_type)
elif not isinstance(ind, Integral):
raise TypeError(
"all elements of `ind` expected to be usm_ndarrays "
f"or integers, found {type(ind)}"
"all elements of `ind` expected to be usm_ndarrays, "
f"NumPy arrays, or integers, found {type(ind)}"
)
if not any_usmarray:
if not any_array:
raise TypeError(
"at least one element of `inds` expected to be a usm_ndarray"
"at least one element of `inds` expected to be an array"
)
return inds
usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
q = dpctl.utils.get_execution_queue(queues)
return q, usm_type


def _prepare_indices_arrays(inds, q, usm_type):
Expand Down Expand Up @@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
queues_ = [
ary.sycl_queue,
]
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)

_validate_indices(inds, queues_, usm_types_)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
exec_q, res_usm_type = _get_indices_queue_usm_type(
inds, ary.sycl_queue, ary.usm_type
)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"Can not automatically determine where to allocate the "
Expand All @@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0):
"be associated with the same queue."
)

if len(inds) > 1:
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
Expand Down Expand Up @@ -976,16 +982,28 @@ def _place_impl(ary, ary_mask, vals, axis=0):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
if not isinstance(ary_mask, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
if isinstance(ary_mask, dpt.usm_ndarray):
exec_q = dpctl.utils.get_execution_queue(
(
ary.sycl_queue,
ary_mask.sycl_queue,
)
)
exec_q = dpctl.utils.get_execution_queue(
(
ary.sycl_queue,
ary_mask.sycl_queue,
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError(
"arrays have different associated queues. "
"Use `y.to_device(x.device)` to migrate."
)
elif isinstance(ary_mask, np.ndarray):
exec_q = ary.sycl_queue
ary_mask = dpt.asarray(
ary_mask, usm_type=ary.usm_type, sycl_queue=exec_q
)
else:
raise TypeError(
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
f"{type(ary_mask)}"
)
)
if exec_q is not None:
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
Expand Down Expand Up @@ -1048,23 +1066,13 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [ary.sycl_queue, vals.sycl_queue]
usm_types_ = [ary.usm_type, vals.usm_type]
else:
queues_ = [
ary.sycl_queue,
]
usm_types_ = [
ary.usm_type,
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)

_validate_indices(inds, queues_, usm_types_)
exec_q, vals_usm_type = _get_indices_queue_usm_type(
inds, ary.sycl_queue, ary.usm_type
)

vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
if not isinstance(vals, dpt.usm_ndarray):
vals = dpt.asarray(
Expand All @@ -1080,8 +1088,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
"be associated with the same queue."
)

if len(inds) > 1:
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)

ind0 = inds[0]
ary_sh = ary.shape
Expand Down
13 changes: 7 additions & 6 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numbers
from operator import index
from cpython.buffer cimport PyObject_CheckBuffer
from numpy import ndarray


cdef bint _is_buffer(object o):
Expand Down Expand Up @@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len(

cdef bint _is_integral(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "ui":
Expand Down Expand Up @@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *:

cdef bint _is_boolean(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "b":
Expand Down Expand Up @@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
raise IndexError(
"Index {0} is out of range for axes 0 with "
"size {1}".format(ind, shape[0]))
elif isinstance(ind, usm_ndarray):
elif isinstance(ind, (ndarray, usm_ndarray)):
return (shape, strides, offset, (ind,), 0)
elif isinstance(ind, tuple):
axes_referenced = 0
Expand Down Expand Up @@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
axes_referenced += 1
if not array_streak_started and array_streak_interrupted:
explicit_index += 1
elif isinstance(i, usm_ndarray):
elif isinstance(i, (ndarray, usm_ndarray)):
if not seen_arrays_yet:
seen_arrays_yet = True
array_streak_started = True
Expand Down Expand Up @@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
array_streak = False
elif _is_integral(ind_i):
if array_streak:
if not isinstance(ind_i, usm_ndarray):
if not isinstance(ind_i, (ndarray, usm_ndarray)):
ind_i = index(ind_i)
# integer will be converted to an array,
# still raise if OOB
Expand Down Expand Up @@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
"Index {0} is out of range for axes "
"{1} with size {2}".format(ind_i, k, shape[k])
)
elif isinstance(ind_i, usm_ndarray):
elif isinstance(ind_i, (ndarray, usm_ndarray)):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
Expand Down
22 changes: 22 additions & 0 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,28 @@ def test_advanced_slice16():
assert isinstance(y, dpt.usm_ndarray)


def test_integer_indexing_numpy_array():
q = get_queue_or_skip()
ii = np.asarray([1, 2])
x = dpt.arange(10, dtype="i4", sycl_queue=q)
y = x[ii]
assert isinstance(y, dpt.usm_ndarray)
assert y.shape == ii.shape
assert dpt.all(dpt.asarray(ii, sycl_queue=q) == y)


def test_boolean_indexing_numpy_array():
q = get_queue_or_skip()
ii = np.asarray(
[False, True, True, False, False, False, False, False, False, False]
)
x = dpt.arange(10, dtype="i4", sycl_queue=q)
y = x[ii]
assert isinstance(y, dpt.usm_ndarray)
assert y.shape == (2,)
assert dpt.all(x[1:3] == y)


def test_boolean_indexing_validation():
get_queue_or_skip()
x = dpt.zeros(10, dtype="i4")
Expand Down