11
11
12
12
from typing import NamedTuple
13
13
from types import ModuleType
14
+ import inspect
14
15
15
16
from ._helpers import _check_device , _is_numpy_array , get_namespace
16
17
@@ -161,13 +162,23 @@ class UniqueInverseResult(NamedTuple):
161
162
inverse_indices : ndarray
162
163
163
164
165
+ def _unique_kwargs (xp ):
166
+ # Older versions of NumPy and CuPy do not have equal_nan. Rather than
167
+ # trying to parse version numbers, just check if equal_nan is in the
168
+ # signature.
169
+ s = inspect .signature (xp .unique )
170
+ if 'equal_nan' in s .parameters :
171
+ return {'equal_nan' : False }
172
+ return {}
173
+
164
174
def unique_all (x : ndarray , / , xp ) -> UniqueAllResult :
175
+ kwargs = _unique_kwargs (xp )
165
176
values , indices , inverse_indices , counts = xp .unique (
166
177
x ,
167
178
return_counts = True ,
168
179
return_index = True ,
169
180
return_inverse = True ,
170
- equal_nan = False ,
181
+ ** kwargs ,
171
182
)
172
183
# np.unique() flattens inverse indices, but they need to share x's shape
173
184
# See https://github.com/numpy/numpy/issues/20638
@@ -181,24 +192,26 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
181
192
182
193
183
194
def unique_counts (x : ndarray , / , xp ) -> UniqueCountsResult :
195
+ kwargs = _unique_kwargs (xp )
184
196
res = xp .unique (
185
197
x ,
186
198
return_counts = True ,
187
199
return_index = False ,
188
200
return_inverse = False ,
189
- equal_nan = False ,
201
+ ** kwargs
190
202
)
191
203
192
204
return UniqueCountsResult (* res )
193
205
194
206
195
207
def unique_inverse (x : ndarray , / , xp ) -> UniqueInverseResult :
208
+ kwargs = _unique_kwargs (xp )
196
209
values , inverse_indices = xp .unique (
197
210
x ,
198
211
return_counts = False ,
199
212
return_index = False ,
200
213
return_inverse = True ,
201
- equal_nan = False ,
214
+ ** kwargs ,
202
215
)
203
216
# xp.unique() flattens inverse indices, but they need to share x's shape
204
217
# See https://github.com/numpy/numpy/issues/20638
@@ -207,12 +220,13 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
207
220
208
221
209
222
def unique_values (x : ndarray , / , xp ) -> ndarray :
223
+ kwargs = _unique_kwargs (xp )
210
224
return xp .unique (
211
225
x ,
212
226
return_counts = False ,
213
227
return_index = False ,
214
228
return_inverse = False ,
215
- equal_nan = False ,
229
+ ** kwargs ,
216
230
)
217
231
218
232
def astype (x : ndarray , dtype : Dtype , / , * , copy : bool = True ) -> ndarray :
@@ -295,8 +309,13 @@ def _asarray(
295
309
_check_device (xp , device )
296
310
if _is_numpy_array (obj ):
297
311
import numpy as np
298
- COPY_FALSE = (False , np ._CopyMode .IF_NEEDED )
299
- COPY_TRUE = (True , np ._CopyMode .ALWAYS )
312
+ if hasattr (np , '_CopyMode' ):
313
+ # Not present in older NumPys
314
+ COPY_FALSE = (False , np ._CopyMode .IF_NEEDED )
315
+ COPY_TRUE = (True , np ._CopyMode .ALWAYS )
316
+ else :
317
+ COPY_FALSE = (False ,)
318
+ COPY_TRUE = (True ,)
300
319
else :
301
320
COPY_FALSE = (False ,)
302
321
COPY_TRUE = (True ,)
0 commit comments