Skip to content

Commit 429d698

Browse files
authored
add choose (#533)
1 parent 0503ad6 commit 429d698

File tree

6 files changed

+87
-16
lines changed

6 files changed

+87
-16
lines changed

dpnp/dparray.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,18 @@ cdef class dparray:
851851
852852
return min(self, axis)
853853
854+
"""
855+
-------------------------------------------------------------------------
856+
Indexing
857+
-------------------------------------------------------------------------
858+
"""
859+
860+
def choose(input, choices, out=None, mode='raise'):
861+
"""
862+
Construct an array from an index array and a set of arrays to choose from.
863+
"""
864+
return choose(input, choices, out, mode)
865+
854866
"""
855867
-------------------------------------------------------------------------
856868
Sorting

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ from dpnp.dpnp_iface_counting import count_nonzero
3939

4040

4141
__all__ += [
42+
"dpnp_choose",
4243
"dpnp_diag_indices",
4344
"dpnp_diagonal",
4445
"dpnp_fill_diagonal",
@@ -56,6 +57,13 @@ __all__ += [
5657
]
5758

5859

60+
cpdef dparray dpnp_choose(input, choices):
61+
res_array = dparray(len(input), dtype=choices[0].dtype)
62+
for i in range(len(input)):
63+
res_array[i] = (choices[input[i]])[i]
64+
return res_array
65+
66+
5967
cpdef tuple dpnp_diag_indices(n, ndim):
6068
cdef dparray res_item = dpnp.arange(n, dtype=dpnp.int64)
6169

dpnp/dpnp_iface_indexing.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252

5353
__all__ = [
54+
"choose",
5455
"diag_indices",
5556
"diag_indices_from",
5657
"diagonal",
@@ -69,6 +70,59 @@
6970
]
7071

7172

73+
def choose(input, choices, out=None, mode='raise'):
74+
"""
75+
Construct an array from an index array and a set of arrays to choose from.
76+
77+
For full documentation refer to :obj:`numpy.choose`.
78+
79+
See also
80+
--------
81+
:obj:`take_along_axis` : Preferable if choices is an array.
82+
"""
83+
if not use_origin_backend(input):
84+
if not isinstance(input, list) and not isinstance(input, dparray):
85+
pass
86+
elif not isinstance(choices, list):
87+
pass
88+
elif out is not None:
89+
pass
90+
elif mode != 'raise':
91+
pass
92+
elif isinstance(choices, list):
93+
val = True
94+
for i in range(len(choices)):
95+
if not isinstance(choices[i], dparray):
96+
val = False
97+
break
98+
if not val:
99+
pass
100+
else:
101+
val = True
102+
len_ = len(input)
103+
size_ = choices[0].size
104+
for i in range(len(choices)):
105+
if choices[i].size != size_ or choices[i].size != len_:
106+
val = False
107+
break
108+
if not val:
109+
pass
110+
else:
111+
val = True
112+
for i in range(len_):
113+
if input[i] >= size_:
114+
val = False
115+
break
116+
if not val:
117+
pass
118+
else:
119+
return dpnp_choose(input, choices)
120+
else:
121+
return dpnp_choose(input, choices)
122+
123+
return call_origin(numpy.choose, input, choices, out, mode)
124+
125+
72126
def diag_indices(n, ndim=2):
73127
"""
74128
Return the indices to access the main diagonal of an array.

tests/skipped_tests.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -582,14 +582,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_
582582
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_scalar
583583
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
584584
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis
585-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose
586-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast
587-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast2
588-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast_fail
589-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_clip
590-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_wrap
591-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_raise
592-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_unknown_clip
593585
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
594586
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
595587
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -735,14 +735,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_
735735
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_by_scalar
736736
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow
737737
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis
738-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose
739-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast
740-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast2
741-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_broadcast_fail
742-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_clip
743-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_choose_wrap
744-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_raise
745-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestChoose::test_unknown_clip
746738
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
747739
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
748740
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/test_indexing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@
55
import numpy
66

77

8+
def test_choose():
9+
a = numpy.r_[:4]
10+
ia = dpnp.array(a)
11+
b = numpy.r_[-4:0]
12+
ib = dpnp.array(b)
13+
c = numpy.r_[100:500:100]
14+
ic = dpnp.array(c)
15+
16+
expected = numpy.choose([0, 0, 0, 0], [a, b, c])
17+
result = dpnp.choose([0, 0, 0, 0], [ia, ib, ic])
18+
numpy.testing.assert_array_equal(expected, result)
19+
20+
821
@pytest.mark.parametrize("offset",
922
[0, 1],
1023
ids=['0', '1'])

0 commit comments

Comments
 (0)