43
43
import collections
44
44
45
45
from dpnp .dpnp_algo import *
46
- from dpnp .dparray import dparray
47
46
from dpnp .dpnp_utils import *
48
47
49
48
import dpnp
@@ -83,44 +82,34 @@ def choose(x1, choices, out=None, mode='raise'):
83
82
:obj:`take_along_axis` : Preferable if choices is an array.
84
83
"""
85
84
if not use_origin_backend (x1 ):
86
- if not isinstance (x1 , list ) and not isinstance ( x1 , dparray ) :
85
+ if not isinstance (x1 , list ):
87
86
pass
88
87
elif not isinstance (choices , list ):
89
88
pass
90
89
elif out is not None :
91
90
pass
92
91
elif mode != 'raise' :
93
92
pass
94
- elif isinstance ( choices , list ) :
93
+ else :
95
94
val = True
95
+ len_ = len (x1 )
96
+ size_ = choices [0 ].size
96
97
for i in range (len (choices )):
97
- if not isinstance ( choices [i ], dparray ) :
98
+ if choices [ i ]. size != size_ or choices [i ]. size != len_ :
98
99
val = False
99
100
break
100
101
if not val :
101
102
pass
102
103
else :
103
104
val = True
104
- len_ = len (x1 )
105
- size_ = choices [0 ].size
106
- for i in range (len (choices )):
107
- if choices [i ].size != size_ or choices [i ].size != len_ :
105
+ for i in range (len_ ):
106
+ if x1 [i ] >= size_ :
108
107
val = False
109
108
break
110
109
if not val :
111
110
pass
112
111
else :
113
- val = True
114
- for i in range (len_ ):
115
- if x1 [i ] >= size_ :
116
- val = False
117
- break
118
- if not val :
119
- pass
120
- else :
121
- return dpnp_choose (x1 , choices ).get_pyobj ()
122
- else :
123
- return dpnp_choose (x1 , choices ).get_pyobj ()
112
+ return dpnp_choose (x1 , choices ).get_pyobj ()
124
113
125
114
return call_origin (numpy .choose , x1 , choices , out , mode )
126
115
@@ -456,15 +445,11 @@ def select(condlist, choicelist, default=0):
456
445
if not use_origin_backend ():
457
446
if not isinstance (condlist , list ):
458
447
pass
459
- elif not isinstance (condlist [0 ], dparray ):
460
- pass
461
448
elif not isinstance (choicelist , list ):
462
449
pass
463
- elif not isinstance (choicelist [0 ], dparray ):
464
- pass
465
450
elif len (condlist ) != len (choicelist ):
466
451
pass
467
- elif len ( condlist ) == len ( choicelist ) :
452
+ else :
468
453
val = True
469
454
size_ = condlist [0 ].size
470
455
for i in range (len (condlist )):
@@ -473,9 +458,7 @@ def select(condlist, choicelist, default=0):
473
458
if not val :
474
459
pass
475
460
else :
476
- return dpnp_select (condlist , choicelist , default )
477
- else :
478
- return dpnp_select (condlist , choicelist , default )
461
+ return dpnp_select (condlist , choicelist , default ).get_pyobj ()
479
462
480
463
return call_origin (numpy .select , condlist , choicelist , default )
481
464
0 commit comments