Skip to content

Commit 41aa579

Browse files
indexing ops: diag indices (#339)
* add indexing tests
1 parent 7b1172b commit 41aa579

File tree

10 files changed

+1883
-1
lines changed

10 files changed

+1883
-1
lines changed

dpnp/backend_indexing.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,22 @@ from dpnp.dpnp_iface_counting import count_nonzero
3939

4040

4141
__all__ += [
42+
"dpnp_diag_indices",
4243
"dpnp_nonzero",
4344
]
4445

4546

47+
cpdef tuple dpnp_diag_indices(n, ndim):
48+
cdef dparray res_item = dpnp.arange(n, dtype=dpnp.int64)
49+
50+
# yes, all are the same item
51+
result = []
52+
for i in range(ndim):
53+
result.append(res_item)
54+
55+
return tuple(result)
56+
57+
4658
cpdef tuple dpnp_nonzero(dparray in_array1):
4759
res_count = in_array1.ndim
4860

dpnp/dpnp_iface_indexing.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,106 @@
4949

5050

5151
__all__ = [
52+
"diag_indices",
53+
"diag_indices_from",
5254
"nonzero",
5355
]
5456

5557

58+
def diag_indices(n, ndim=2):
59+
"""
60+
Return the indices to access the main diagonal of an array.
61+
62+
This returns a tuple of indices that can be used to access the main
63+
diagonal of an array `a` with ``a.ndim >= 2`` dimensions and shape
64+
(n, n, ..., n). For ``a.ndim = 2`` this is the usual diagonal, for
65+
``a.ndim > 2`` this is the set of indices to access ``a[i, i, ..., i]``
66+
for ``i = [0..n-1]``.
67+
68+
For full documentation refer to :obj:`numpy.diag_indices`.
69+
70+
See also
71+
--------
72+
:obj:`diag_indices_from` : Return the indices to access the main
73+
diagonal of an n-dimensional array.
74+
75+
Examples
76+
--------
77+
Create a set of indices to access the diagonal of a (4, 4) array:
78+
79+
>>> import dpnp as np
80+
>>> di = np.diag_indices(4)
81+
>>> di
82+
(array([0, 1, 2, 3]), array([0, 1, 2, 3]))
83+
>>> a = np.arange(16).reshape(4, 4)
84+
>>> a
85+
array([[ 0, 1, 2, 3],
86+
[ 4, 5, 6, 7],
87+
[ 8, 9, 10, 11],
88+
[12, 13, 14, 15]])
89+
>>> a[di] = 100
90+
>>> a
91+
array([[100, 1, 2, 3],
92+
[ 4, 100, 6, 7],
93+
[ 8, 9, 100, 11],
94+
[ 12, 13, 14, 100]])
95+
96+
Now, we create indices to manipulate a 3-D array:
97+
98+
>>> d3 = np.diag_indices(2, 3)
99+
>>> d3
100+
(array([0, 1]), array([0, 1]), array([0, 1]))
101+
102+
And use it to set the diagonal of an array of zeros to 1:
103+
104+
>>> a = np.zeros((2, 2, 2), dtype=int)
105+
>>> a[d3] = 1
106+
>>> a
107+
array([[[1, 0],
108+
[0, 0]],
109+
[[0, 0],
110+
[0, 1]]])
111+
112+
"""
113+
114+
if not use_origin_backend():
115+
return dpnp_diag_indices(n, ndim)
116+
117+
return call_origin(numpy.diag_indices, n, ndim)
118+
119+
120+
def diag_indices_from(arr):
121+
"""
122+
Return the indices to access the main diagonal of an n-dimensional array.
123+
124+
For full documentation refer to :obj:`numpy.diag_indices_from`.
125+
126+
See also
127+
--------
128+
:obj:`diag_indices` : Return the indices to access the main
129+
diagonal of an array.
130+
131+
"""
132+
133+
is_a_dparray = isinstance(arr, dparray)
134+
135+
if (not use_origin_backend(arr) and is_a_dparray):
136+
# original limitation
137+
if not arr.ndim >= 2:
138+
checker_throw_value_error("diag_indices_from", "arr.ndim", arr.ndim, "at least 2-d")
139+
140+
# original limitation
141+
# For more than d=2, the strided formula is only valid for arrays with
142+
# all dimensions equal, so we check first.
143+
if not numpy.alltrue(numpy.diff(arr.shape) == 0): # TODO: replace alltrue and diff funcs with dpnp own ones
144+
checker_throw_value_error("diag_indices_from", "arr.shape", arr.shape,
145+
"All dimensions of input must be of equal length")
146+
147+
return dpnp_diag_indices(arr.shape[0], arr.ndim)
148+
149+
return call_origin(numpy.diag_indices_from, arr)
150+
151+
56152
def nonzero(a):
57153
"""
58154
Return the indices of the elements that are non-zero.

tests/skipped_tests.tbl

Lines changed: 328 additions & 0 deletions
Large diffs are not rendered by default.

tests/skipped_tests_gpu.tbl

Lines changed: 328 additions & 0 deletions
Large diffs are not rendered by default.

tests/third_party/cupy/indexing_tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)