Skip to content

Commit 0b9cec8

Browse files
authored
Merge pull request #1798 from IntelPython/put-along-axis
Implements `dpctl.tensor.put_along_axis`
2 parents 972fdb6 + f81c107 commit 0b9cec8

File tree

5 files changed

+213
-16
lines changed

5 files changed

+213
-16
lines changed

docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ by either integral arrays of indices or boolean mask arrays.
1414
extract
1515
place
1616
put
17+
put_along_axis
1718
take
1819
take_along_axis

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
nonzero,
6666
place,
6767
put,
68+
put_along_axis,
6869
take,
6970
take_along_axis,
7071
)
@@ -385,4 +386,5 @@
385386
"count_nonzero",
386387
"DLDeviceType",
387388
"take_along_axis",
389+
"put_along_axis",
388390
]

dpctl/tensor/_copy_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -938,13 +938,18 @@ def _place_impl(ary, ary_mask, vals, axis=0):
938938
return
939939

940940

941-
def _put_multi_index(ary, inds, p, vals):
941+
def _put_multi_index(ary, inds, p, vals, mode=0):
942942
if not isinstance(ary, dpt.usm_ndarray):
943943
raise TypeError(
944944
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
945945
)
946946
ary_nd = ary.ndim
947947
p = normalize_axis_index(operator.index(p), ary_nd)
948+
mode = operator.index(mode)
949+
if mode not in [0, 1]:
950+
raise ValueError(
951+
"Invalid value for mode keyword, only 0 or 1 is supported"
952+
)
948953
if isinstance(vals, dpt.usm_ndarray):
949954
queues_ = [ary.sycl_queue, vals.sycl_queue]
950955
usm_types_ = [ary.usm_type, vals.usm_type]
@@ -1018,7 +1023,7 @@ def _put_multi_index(ary, inds, p, vals):
10181023
ind=inds,
10191024
val=rhs,
10201025
axis_start=p,
1021-
mode=0,
1026+
mode=mode,
10221027
sycl_queue=exec_q,
10231028
depends=dep_ev,
10241029
)

dpctl/tensor/_indexing_functions.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.utils
2323

24-
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
24+
from ._copy_utils import (
25+
_extract_impl,
26+
_nonzero_impl,
27+
_put_multi_index,
28+
_take_multi_index,
29+
)
2530
from ._numpy_helper import normalize_axis_index
2631

2732

@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206211
raise TypeError(
207212
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
208213
)
209-
if isinstance(vals, dpt.usm_ndarray):
210-
queues_ = [x.sycl_queue, vals.sycl_queue]
211-
usm_types_ = [x.usm_type, vals.usm_type]
212-
else:
213-
queues_ = [
214-
x.sycl_queue,
215-
]
216-
usm_types_ = [
217-
x.usm_type,
218-
]
219214
if not isinstance(indices, dpt.usm_ndarray):
220215
raise TypeError(
221216
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
222217
type(indices)
223218
)
224219
)
220+
if isinstance(vals, dpt.usm_ndarray):
221+
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
222+
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
223+
else:
224+
queues_ = [x.sycl_queue, indices.sycl_queue]
225+
usm_types_ = [x.usm_type, indices.usm_type]
225226
if indices.ndim != 1:
226227
raise ValueError(
227228
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
@@ -232,8 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232233
indices.dtype
233234
)
234235
)
235-
queues_.append(indices.sycl_queue)
236-
usm_types_.append(indices.usm_type)
237236
exec_q = dpctl.utils.get_execution_queue(queues_)
238237
if exec_q is None:
239238
raise dpctl.utils.ExecutionPlacementError
@@ -502,3 +501,79 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502501
for i in range(x_nd)
503502
)
504503
return _take_multi_index(x, _ind, 0, mode=mode_i)
504+
505+
506+
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
    """
    Writes ``vals`` into ``x`` at the positions selected by the
    one-dimensional indices in ``indices`` taken along ``axis``.

    Args:
        x (usm_ndarray):
            array to put values into. Must be compatible with ``indices``,
            except for the axis (dimension) specified by ``axis``.
        indices (usm_ndarray):
            array of indices. Must have the same rank (i.e., number of
            dimensions) as ``x``.
        vals (usm_ndarray):
            array of values to be put into ``x``.
            Must be broadcastable to the shape of ``indices``.
        axis (int, optional):
            axis along which to put values. If ``axis`` is negative, the
            axis is determined by counting from the last dimension.
            Default: ``-1``.
        mode (str, optional):
            How out-of-bounds indices will be handled. Possible values
            are:

            - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
              negative indices.
            - ``"clip"``: clips indices to (``0 <= i < n``).

            Default: ``"wrap"``.

    Raises:
        TypeError:
            if ``x`` or ``indices`` is not a ``usm_ndarray``.
        ValueError:
            if ``x`` and ``indices`` have different numbers of dimensions.
        ExecutionPlacementError:
            if a single execution queue can not be inferred from the
            array arguments.

    .. note::

        If input array ``indices`` contains duplicates, a race condition
        occurs, and the value written into corresponding positions in ``x``
        may vary from run to run. Preserving sequential semantics in
        handling the duplicates to achieve deterministic behavior requires
        additional work.
    """
    # Both positional array arguments must be USM arrays; `x` is checked
    # first so that the reported type matches the first offending argument.
    for arr in (x, indices):
        if not isinstance(arr, dpt.usm_ndarray):
            raise TypeError(
                f"Expected dpctl.tensor.usm_ndarray, got {type(arr)}"
            )
    ndim = x.ndim
    if indices.ndim != ndim:
        raise ValueError(
            "Number of dimensions in the first and the second "
            "argument arrays must be equal"
        )
    axis_pos = normalize_axis_index(operator.index(axis), ndim)
    # `vals` takes part in queue / USM-type inference only when it is
    # itself a usm_ndarray.
    usm_args = [x, indices]
    if isinstance(vals, dpt.usm_ndarray):
        usm_args.append(vals)
    exec_q = dpctl.utils.get_execution_queue(
        [arr.sycl_queue for arr in usm_args]
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
        )
    coerced_usm_type = dpctl.utils.get_coerced_usm_type(
        [arr.usm_type for arr in usm_args]
    )
    mode_i = _get_indexing_mode(mode)
    index_dt = ti.default_device_index_type(exec_q.sycl_device)
    # One index array per dimension: the caller's `indices` along the
    # requested axis, and the result of `_range` for every other dimension.
    multi_index = []
    for dim in range(ndim):
        if dim == axis_pos:
            multi_index.append(indices)
        else:
            multi_index.append(
                _range(
                    x.shape[dim],
                    dim,
                    ndim,
                    exec_q,
                    coerced_usm_type,
                    index_dt,
                )
            )
    return _put_multi_index(x, tuple(multi_index), 0, vals, mode=mode_i)

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,7 @@ def test_take_along_axis_validation():
15781578
def_dtypes = info_.default_dtypes(device=x_dev)
15791579
ind_dt = def_dtypes["indexing"]
15801580
ind = dpt.zeros(1, dtype=ind_dt)
1581-
# axis valudation
1581+
# axis validation
15821582
with pytest.raises(ValueError):
15831583
dpt.take_along_axis(x, ind, axis=1)
15841584
# mode validation
@@ -1594,6 +1594,116 @@ def test_take_along_axis_validation():
15941594
dpt.take_along_axis(x, ind2)
15951595

15961596

1597+
def test_put_along_axis():
    """
    Round-trip check: values written with ``put_along_axis`` along each
    axis of a 3D array are read back unchanged by ``take_along_axis``.
    """
    get_queue_or_skip()

    shape = (3, 5, 7)
    x = dpt.reshape(dpt.arange(shape[0] * shape[1] * shape[2]), shape)
    default_dtypes = dpt.__array_namespace_info__().default_dtypes(
        device=x.sycl_device
    )
    ind_dt = default_dtypes["indexing"]
    # For every axis, an index array of ones whose extent along that axis
    # is 1 and which matches `x` along all other axes.
    inds = [
        dpt.ones(
            tuple(1 if dim == ax else ext for dim, ext in enumerate(shape)),
            dtype=ind_dt,
        )
        for ax in range(len(shape))
    ]

    for ax, ind in enumerate(inds):
        xc = dpt.copy(x)
        vals = dpt.ones(ind.shape, dtype=x.dtype)
        dpt.put_along_axis(xc, ind, vals, axis=ax)
        assert dpt.all(dpt.take_along_axis(xc, ind, axis=ax) == vals)

    # Same check along the last axis with `vals` passed as a NumPy array
    # rather than a usm_ndarray.
    last_ax = len(shape) - 1
    last_ind = inds[last_ax]
    xc = dpt.copy(x)
    vals = dpt.ones(last_ind.shape, dtype=x.dtype)
    dpt.put_along_axis(xc, last_ind, dpt.asnumpy(vals), axis=last_ax)
    assert dpt.all(
        dpt.take_along_axis(xc, last_ind, axis=last_ax) == vals
    )
1628+
1629+
1630+
def test_put_along_axis_validation():
    """
    Checks that ``put_along_axis`` rejects invalid argument types, axis,
    mode, mismatched ranks, and arrays allocated on different queues.
    """
    # The first-argument type check needs no device, so it runs before
    # the queue availability check.
    with pytest.raises(TypeError):
        dpt.put_along_axis(tuple(), list(), list())

    get_queue_or_skip()
    n_rows, n_cols = 2, 5
    x = dpt.ones(n_rows * n_cols)

    # second argument must be a usm_ndarray as well
    with pytest.raises(TypeError):
        dpt.put_along_axis(x, list(), list())

    dev = x.sycl_device
    ind_dt = dpt.__array_namespace_info__().default_dtypes(device=dev)[
        "indexing"
    ]
    ind = dpt.zeros(1, dtype=ind_dt)
    vals = dpt.zeros(1, dtype=x.dtype)

    # axis=1 is out of range for a 1D array
    with pytest.raises(ValueError):
        dpt.put_along_axis(x, ind, vals, axis=1)

    # unknown mode string
    with pytest.raises(ValueError):
        dpt.put_along_axis(x, ind, vals, axis=0, mode="invalid")

    # 2D array with 1D indices: ranks differ
    x_2d = dpt.reshape(x, (n_rows, n_cols))
    with pytest.raises(ValueError):
        dpt.put_along_axis(x_2d, ind, vals)

    # compute-follows-data: indices allocated on a different queue
    other_q = dpctl.SyclQueue(dev, property="enable_profiling")
    ind_other_q = dpt.zeros(1, dtype=ind_dt, sycl_queue=other_q)
    with pytest.raises(ExecutionPlacementError):
        dpt.put_along_axis(x, ind_other_q, vals)
1660+
1661+
1662+
def test_put_along_axis_application():
    """
    Builds all 24 permutation matrices of size 4x4 with ``put_along_axis``
    and checks that the 12-th power of each one is the identity matrix.
    """
    get_queue_or_skip()
    ind_dt = dpt.__array_namespace_info__().default_dtypes(device=None)[
        "indexing"
    ]
    # Every permutation of (0, 1, 2, 3), one per row.
    perm_rows = [
        [0, 1, 2, 3],
        [0, 2, 1, 3],
        [2, 0, 1, 3],
        [2, 1, 0, 3],
        [1, 0, 2, 3],
        [1, 2, 0, 3],
        [0, 1, 3, 2],
        [0, 2, 3, 1],
        [2, 0, 3, 1],
        [2, 1, 3, 0],
        [1, 0, 3, 2],
        [1, 2, 3, 0],
        [0, 3, 1, 2],
        [0, 3, 2, 1],
        [2, 3, 0, 1],
        [2, 3, 1, 0],
        [1, 3, 0, 2],
        [1, 3, 2, 0],
        [3, 0, 1, 2],
        [3, 0, 2, 1],
        [3, 2, 0, 1],
        [3, 2, 1, 0],
        [3, 1, 0, 2],
        [3, 1, 2, 0],
    ]
    all_perms = dpt.asarray(perm_rows, dtype=ind_dt)
    p_mats = dpt.zeros((24, 4, 4), dtype=dpt.int64)
    ones = dpt.ones((24, 4, 1), dtype=p_mats.dtype)
    # Put a single 1 into each row of each matrix, at the column given by
    # the permutation: this forms the 24 permutation matrices.
    dpt.put_along_axis(p_mats, all_perms[..., dpt.newaxis], ones, axis=2)
    # Repeated squaring: p^2, p^4, p^8, and finally p^12 = p^8 @ p^4.
    pow2 = p_mats @ p_mats
    pow4 = pow2 @ pow2
    pow8 = pow4 @ pow4
    pow12 = pow8 @ pow4
    identity = dpt.eye(4, dtype=p_mats.dtype)[dpt.newaxis, ...]
    assert dpt.all(pow12 == identity)
1705+
1706+
15971707
def check__extract_impl_validation(fn):
15981708
x = dpt.ones(10)
15991709
ind = dpt.ones(10, dtype="?")
@@ -1670,7 +1780,11 @@ def check__put_multi_index_validation(fn):
16701780
with pytest.raises(ValueError):
16711781
fn(x2, (ind1, ind2), 0, x2)
16721782
with pytest.raises(TypeError):
1783+
# invalid index type
16731784
fn(x2, (ind1, list()), 0, x2)
1785+
with pytest.raises(ValueError):
1786+
# invalid mode keyword value
1787+
fn(x, inds, 0, vals, mode=100)
16741788

16751789

16761790
def test__copy_utils():

0 commit comments

Comments
 (0)