21
21
import dpctl .tensor ._tensor_impl as ti
22
22
import dpctl .utils
23
23
24
- from ._copy_utils import _extract_impl , _nonzero_impl , _take_multi_index
24
+ from ._copy_utils import (
25
+ _extract_impl ,
26
+ _nonzero_impl ,
27
+ _put_multi_index ,
28
+ _take_multi_index ,
29
+ )
25
30
from ._numpy_helper import normalize_axis_index
26
31
27
32
@@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
206
211
raise TypeError (
207
212
"Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
208
213
)
209
- if isinstance (vals , dpt .usm_ndarray ):
210
- queues_ = [x .sycl_queue , vals .sycl_queue ]
211
- usm_types_ = [x .usm_type , vals .usm_type ]
212
- else :
213
- queues_ = [
214
- x .sycl_queue ,
215
- ]
216
- usm_types_ = [
217
- x .usm_type ,
218
- ]
219
214
if not isinstance (indices , dpt .usm_ndarray ):
220
215
raise TypeError (
221
216
"`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
222
217
type (indices )
223
218
)
224
219
)
220
+ if isinstance (vals , dpt .usm_ndarray ):
221
+ queues_ = [x .sycl_queue , indices .sycl_queue , vals .sycl_queue ]
222
+ usm_types_ = [x .usm_type , indices .usm_type , vals .usm_type ]
223
+ else :
224
+ queues_ = [x .sycl_queue , indices .sycl_queue ]
225
+ usm_types_ = [x .usm_type , indices .usm_type ]
225
226
if indices .ndim != 1 :
226
227
raise ValueError (
227
228
"`indices` expected a 1D array, got `{}`" .format (indices .ndim )
@@ -232,7 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
232
233
indices .dtype
233
234
)
234
235
)
235
- queues_ .append (indices .sycl_queue )
236
236
usm_types_ .append (indices .usm_type )
237
237
exec_q = dpctl .utils .get_execution_queue (queues_ )
238
238
if exec_q is None :
@@ -502,3 +502,81 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
502
502
for i in range (x_nd )
503
503
)
504
504
return _take_multi_index (x , _ind , 0 , mode = mode_i )
505
+
506
+
507
def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
    """
    Puts elements into an array at the one-dimensional indices specified by
    ``indices`` along a provided ``axis``.

    Args:
        x (usm_ndarray):
            input array. Must be compatible with ``indices``, except for the
            axis (dimension) specified by ``axis``.
        indices (usm_ndarray):
            array indices. Must have the same rank (i.e., number of dimensions)
            as ``x``.
        vals (usm_ndarray):
            Array of values to be put into ``x``.
            Must be broadcastable to the shape of ``indices``.
            When ``vals`` is not a ``usm_ndarray``, it takes no part in
            inferring the execution queue or the USM allocation type; it is
            forwarded unchanged to the underlying put implementation.
        axis (int, optional):
            axis along which to put values. If ``axis`` is negative, the
            function determines the axis along which to put values by
            counting from the last dimension. Default: ``-1``.
        mode (str, optional):
            How out-of-bounds indices will be handled. Possible values
            are:

            - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
              negative indices.
            - ``"clip"``: clips indices to (``0 <= i < n``).

            Default: ``"wrap"``.

    Returns:
        Whatever the internal ``_put_multi_index`` helper returns; ``x`` is
        the array being written into.

    Raises:
        TypeError:
            if ``x`` or ``indices`` is not a ``usm_ndarray``, or if ``axis``
            cannot be interpreted as an integer.
        ValueError:
            if ``x`` and ``indices`` have a different number of dimensions.
        dpctl.utils.ExecutionPlacementError:
            if a common execution queue can not be determined from the
            array arguments.

    .. note::

        If input array ``indices`` contains duplicates, a race condition
        occurs, and the value written into corresponding positions in ``x``
        may vary from run to run. Preserving sequential semantics in handling
        the duplicates to achieve deterministic behavior requires additional
        work.

        See :func:`dpctl.tensor.put` for an example.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    if not isinstance(indices, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}"
        )
    x_nd = x.ndim
    if x_nd != indices.ndim:
        raise ValueError(
            "Number of dimensions in the first and the second "
            "argument arrays must be equal"
        )
    # operator.index rejects non-integer ``axis`` before normalization
    pp = normalize_axis_index(operator.index(axis), x_nd)
    # Only array arguments carry a queue / USM type, so ``vals`` contributes
    # to placement inference only when it is itself a usm_ndarray.
    if isinstance(vals, dpt.usm_ndarray):
        queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
        usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
    else:
        queues_ = [x.sycl_queue, indices.sycl_queue]
        usm_types_ = [x.usm_type, indices.usm_type]
    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
        )
    out_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    mode_i = _get_indexing_mode(mode)
    indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
    # One index array per dimension: the user-supplied ``indices`` along the
    # target axis, and a ``_range``-generated index array (built from that
    # dimension's extent) for every other axis.
    _ind = tuple(
        (
            indices
            if i == pp
            else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt)
        )
        for i in range(x_nd)
    )
    return _put_multi_index(x, _ind, 0, vals, mode=mode_i)
0 commit comments