Skip to content

Commit b8c1e3e

Browse files
authored
indexing module to descriptor 3 (#878)
1 parent c5a4170 commit b8c1e3e

File tree

2 files changed

+11
-28
lines changed

2 files changed

+11
-28
lines changed

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ cpdef dpnp_putmask(object arr, object mask, object values):
269269
arr[i] = values[i % values_size]
270270

271271

272-
cpdef object dpnp_select(condlist, choicelist, default):
272+
cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default):
273273
cdef size_t size_ = condlist[0].size
274274
cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(condlist[0].shape, choicelist[0].dtype, None)
275275

dpnp/dpnp_iface_indexing.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import collections
4444

4545
from dpnp.dpnp_algo import *
46-
from dpnp.dparray import dparray
4746
from dpnp.dpnp_utils import *
4847

4948
import dpnp
@@ -83,44 +82,34 @@ def choose(x1, choices, out=None, mode='raise'):
8382
:obj:`take_along_axis` : Preferable if choices is an array.
8483
"""
8584
if not use_origin_backend(x1):
86-
if not isinstance(x1, list) and not isinstance(x1, dparray):
85+
if not isinstance(x1, list):
8786
pass
8887
elif not isinstance(choices, list):
8988
pass
9089
elif out is not None:
9190
pass
9291
elif mode != 'raise':
9392
pass
94-
elif isinstance(choices, list):
93+
else:
9594
val = True
95+
len_ = len(x1)
96+
size_ = choices[0].size
9697
for i in range(len(choices)):
97-
if not isinstance(choices[i], dparray):
98+
if choices[i].size != size_ or choices[i].size != len_:
9899
val = False
99100
break
100101
if not val:
101102
pass
102103
else:
103104
val = True
104-
len_ = len(x1)
105-
size_ = choices[0].size
106-
for i in range(len(choices)):
107-
if choices[i].size != size_ or choices[i].size != len_:
105+
for i in range(len_):
106+
if x1[i] >= size_:
108107
val = False
109108
break
110109
if not val:
111110
pass
112111
else:
113-
val = True
114-
for i in range(len_):
115-
if x1[i] >= size_:
116-
val = False
117-
break
118-
if not val:
119-
pass
120-
else:
121-
return dpnp_choose(x1, choices).get_pyobj()
122-
else:
123-
return dpnp_choose(x1, choices).get_pyobj()
112+
return dpnp_choose(x1, choices).get_pyobj()
124113

125114
return call_origin(numpy.choose, x1, choices, out, mode)
126115

@@ -456,15 +445,11 @@ def select(condlist, choicelist, default=0):
456445
if not use_origin_backend():
457446
if not isinstance(condlist, list):
458447
pass
459-
elif not isinstance(condlist[0], dparray):
460-
pass
461448
elif not isinstance(choicelist, list):
462449
pass
463-
elif not isinstance(choicelist[0], dparray):
464-
pass
465450
elif len(condlist) != len(choicelist):
466451
pass
467-
elif len(condlist) == len(choicelist):
452+
else:
468453
val = True
469454
size_ = condlist[0].size
470455
for i in range(len(condlist)):
@@ -473,9 +458,7 @@ def select(condlist, choicelist, default=0):
473458
if not val:
474459
pass
475460
else:
476-
return dpnp_select(condlist, choicelist, default)
477-
else:
478-
return dpnp_select(condlist, choicelist, default)
461+
return dpnp_select(condlist, choicelist, default).get_pyobj()
479462

480463
return call_origin(numpy.select, condlist, choicelist, default)
481464

0 commit comments

Comments
 (0)