Skip to content

Commit d2b5476

Browse files
committed
Implements put_along_axis
Also makes minor tweaks to `take_along_axis`
1 parent 43f5aea commit d2b5476

File tree

3 files changed

+99
-14
lines changed

3 files changed

+99
-14
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
nonzero,
6666
place,
6767
put,
68+
put_along_axis,
6869
take,
6970
take_along_axis,
7071
)
@@ -384,4 +385,5 @@
384385
"diff",
385386
"count_nonzero",
386387
"take_along_axis",
388+
"put_along_axis",
387389
]

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,13 +938,18 @@ def _place_impl(ary, ary_mask, vals, axis=0):
938938
return
939939

940940

941-
def _put_multi_index(ary, inds, p, vals):
941+
def _put_multi_index(ary, inds, p, vals, mode=0):
942942
if not isinstance(ary, dpt.usm_ndarray):
943943
raise TypeError(
944944
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
945945
)
946946
ary_nd = ary.ndim
947947
p = normalize_axis_index(operator.index(p), ary_nd)
948+
mode = operator.index(mode)
949+
if mode not in [0, 1]:
950+
raise ValueError(
951+
"Invalid value for mode keyword, only 0 or 1 is supported"
952+
)
948953
if isinstance(vals, dpt.usm_ndarray):
949954
queues_ = [ary.sycl_queue, vals.sycl_queue]
950955
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -1018,7 +1023,7 @@ def _put_multi_index(ary, inds, p, vals):
10181023
ind=inds,
10191024
val=rhs,
10201025
axis_start=p,
1021-
mode=0,
1026+
mode=mode,
10221027
sycl_queue=exec_q,
10231028
depends=dep_ev,
10241029
)

dpctl/tensor/_indexing_functions.py

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.utils
2323

24-
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
24+
from ._copy_utils import (
25+
_extract_impl,
26+
_nonzero_impl,
27+
_put_multi_index,
28+
_take_multi_index,
29+
)
2530
from ._numpy_helper import normalize_axis_index
2631

2732

@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206211
raise TypeError(
207212
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
208213
)
209-
if isinstance(vals, dpt.usm_ndarray):
210-
queues_ = [x.sycl_queue, vals.sycl_queue]
211-
usm_types_ = [x.usm_type, vals.usm_type]
212-
else:
213-
queues_ = [
214-
x.sycl_queue,
215-
]
216-
usm_types_ = [
217-
x.usm_type,
218-
]
219214
if not isinstance(indices, dpt.usm_ndarray):
220215
raise TypeError(
221216
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
222217
type(indices)
223218
)
224219
)
220+
if isinstance(vals, dpt.usm_ndarray):
221+
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
222+
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
223+
else:
224+
queues_ = [x.sycl_queue, indices.sycl_queue]
225+
usm_types_ = [x.usm_type, indices.usm_type]
225226
if indices.ndim != 1:
226227
raise ValueError(
227228
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
@@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232233
indices.dtype
233234
)
234235
)
235-
queues_.append(indices.sycl_queue)
236236
usm_types_.append(indices.usm_type)
237237
exec_q = dpctl.utils.get_execution_queue(queues_)
238238
if exec_q is None:
@@ -502,3 +502,81 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502502
for i in range(x_nd)
503503
)
504504
return _take_multi_index(x, _ind, 0, mode=mode_i)
505+
506+
507+
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
508+
"""
509+
Puts elements into an array at the one-dimensional indices specified by
510+
``indices`` along a provided ``axis``.
511+
512+
Args:
513+
x (usm_ndarray):
514+
input array. Must be compatible with ``indices``, except for the
515+
axis (dimension) specified by ``axis``.
516+
indices (usm_ndarray):
517+
array indices. Must have the same rank (i.e., number of dimensions)
518+
as ``x``.
519+
vals (usm_ndarray):
520+
Array of values to be put into ``x``.
521+
Must be broadcastable to the shape of ``indices``.
522+
axis (int, optional):
523+
axis along which to put values. If ``axis`` is negative, the
524+
function determines the axis along which to put values by
525+
counting from the last dimension. Default: ``-1``.
526+
mode (str, optional):
527+
How out-of-bounds indices will be handled. Possible values
528+
are:
529+
530+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
531+
negative indices.
532+
- ``"clip"``: clips indices to (``0 <= i < n``).
533+
534+
Default: ``"wrap"``.
535+
536+
.. note::
537+
538+
If input array ``indices`` contains duplicates, a race condition
539+
occurs, and the value written into corresponding positions in ``x``
540+
may vary from run to run. Preserving sequential semantics in handling
541+
the duplicates to achieve deterministic behavior requires additional
542+
work.
543+
544+
See :func:`dpctl.tensor.put` for an example.
545+
"""
546+
if not isinstance(x, dpt.usm_ndarray):
547+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
548+
if not isinstance(indices, dpt.usm_ndarray):
549+
raise TypeError(
550+
f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
551+
)
552+
x_nd = x.ndim
553+
if x_nd != indices.ndim:
554+
raise ValueError(
555+
"Number of dimensions in the first and the second "
556+
"argument arrays must be equal"
557+
)
558+
pp = normalize_axis_index(operator.index(axis), x_nd)
559+
if isinstance(vals, dpt.usm_ndarray):
560+
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
561+
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
562+
else:
563+
queues_ = [x.sycl_queue, indices.sycl_queue]
564+
usm_types_ = [x.usm_type, indices.usm_type]
565+
exec_q = dpctl.utils.get_execution_queue(queues_)
566+
if exec_q is None:
567+
raise dpctl.utils.ExecutionPlacementError(
568+
"Execution placement can not be unambiguously inferred "
569+
"from input arguments. "
570+
)
571+
out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
572+
mode_i = _get_indexing_mode(mode)
573+
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
574+
_ind = tuple(
575+
(
576+
indices
577+
if i == pp
578+
else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
579+
)
580+
for i in range(x_nd)
581+
)
582+
return _put_multi_index(x, _ind, 0, vals, mode=mode_i)

0 commit comments

Comments
 (0)