Skip to content

Commit 399d74d

Browse files
authored
Select (#527)
* add select
1 parent d6c7cb6 commit 399d74d

File tree

5 files changed

+73
-2
lines changed

5 files changed

+73
-2
lines changed

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ __all__ += [
4747
"dpnp_place",
4848
"dpnp_put",
4949
"dpnp_putmask",
50+
"dpnp_select",
5051
"dpnp_take",
5152
"dpnp_tril_indices",
5253
"dpnp_tril_indices_from",
@@ -232,6 +233,22 @@ cpdef dpnp_putmask(dparray arr, dparray mask, dparray values):
232233
arr[i] = values[i % values_size]
233234

234235

236+
cpdef dparray dpnp_select(condlist, choicelist, default):
237+
size_ = condlist[0].size
238+
res_array = dparray(size_, dtype=choicelist[0].dtype)
239+
pass_val = {a: default for a in range(size_)}
240+
for i in range(len(condlist)):
241+
for j in range(size_):
242+
if (condlist[i])[j]:
243+
res_array[j] = (choicelist[i])[j]
244+
pass_val.pop(j)
245+
246+
for ind, val in pass_val.items():
247+
res_array[ind] = val
248+
249+
return res_array.reshape(condlist[0].shape)
250+
251+
235252
cpdef dparray dpnp_take(dparray input, dparray indices):
236253
indices_size = indices.size
237254
res_array = dparray(indices_size, dtype=input.dtype)

dpnp/dpnp_iface_indexing.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"place",
6161
"put",
6262
"putmask",
63+
"select",
6364
"take",
6465
"tril_indices",
6566
"tril_indices_from",
@@ -371,6 +372,43 @@ def putmask(arr, mask, values):
371372
return call_origin(numpy.putmask, arr, mask, values)
372373

373374

375+
def select(condlist, choicelist, default=0):
376+
"""
377+
Return an array drawn from elements in choicelist, depending on conditions.
378+
For full documentation refer to :obj:`numpy.select`.
379+
380+
Limitations
381+
-----------
382+
Arrays of input lists are supported as :obj:`dpnp.ndarray`.
383+
Parameter ``default`` are supported only with default values.
384+
"""
385+
if not use_origin_backend():
386+
if not isinstance(condlist, list):
387+
pass
388+
elif not isinstance(condlist[0], dparray):
389+
pass
390+
elif not isinstance(choicelist, list):
391+
pass
392+
elif not isinstance(choicelist[0], dparray):
393+
pass
394+
elif len(condlist) != len(choicelist):
395+
pass
396+
elif len(condlist) == len(choicelist):
397+
val = True
398+
size_ = condlist[0].size
399+
for i in range(len(condlist)):
400+
if condlist[i].size != size_ or choicelist[i].size != size_:
401+
val = False
402+
if not val:
403+
pass
404+
else:
405+
return dpnp_select(condlist, choicelist, default)
406+
else:
407+
return dpnp_select(condlist, choicelist, default)
408+
409+
return call_origin(numpy.select, condlist, choicelist, default)
410+
411+
374412
def take(input, indices, axis=None, out=None, mode='raise'):
375413
"""
376414
Take elements from an array.

tests/skipped_tests.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_
599599
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
600600
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
601601
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
602-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_choicelist
603602
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist
604603
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises
605604
tests/third_party/cupy/indexing_tests/test_insert.py::TestPlace_param_0_{n_vals=0, shape=(7,)}::test_place

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_
752752
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_length_error
753753
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable
754754
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
755-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_choicelist
756755
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist
757756
tests/third_party/cupy/indexing_tests/test_insert.py::TestPutmask::test_putmask_non_equal_shape_raises
758757
tests/third_party/cupy/indexing_tests/test_insert.py::TestPlaceRaises_param_0_{shape=(7,)}::test_place_empty_value_error

tests/test_indexing.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,24 @@ def test_putmask3(arr, mask, vals):
284284
numpy.testing.assert_array_equal(a, ia)
285285

286286

287+
def test_select():
288+
cond_val1 = numpy.array([True, True, True, False, False, False, False, False, False, False])
289+
cond_val2 = numpy.array([False, False, False, False, False, True, True, True, True, True])
290+
icond_val1 = dpnp.array(cond_val1)
291+
icond_val2 = dpnp.array(cond_val2)
292+
condlist = [cond_val1, cond_val2]
293+
icondlist = [icond_val1, icond_val2]
294+
choice_val1 = numpy.full(10, -2)
295+
choice_val2 = numpy.full(10, -1)
296+
ichoice_val1 = dpnp.array(choice_val1)
297+
ichoice_val2 = dpnp.array(choice_val2)
298+
choicelist = [choice_val1, choice_val2]
299+
ichoicelist = [ichoice_val1, ichoice_val2]
300+
expected = numpy.select(condlist, choicelist)
301+
result = dpnp.select(icondlist, ichoicelist)
302+
numpy.testing.assert_array_equal(expected, result)
303+
304+
287305
@pytest.mark.parametrize("indices",
288306
[[[0, 0], [0, 0]],
289307
[[1, 2], [1, 2]],

0 commit comments

Comments
 (0)