Skip to content

Commit 54f5f64

Browse files
committed
Fix some compatibility issues with NumPy 1.21
The equal_nan keyword to unique() and np._CopyMode aren't in this version.
1 parent bb5da85 commit 54f5f64

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

array_api_compat/common/_aliases.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from typing import NamedTuple
1313
from types import ModuleType
14+
import inspect
1415

1516
from ._helpers import _check_device, _is_numpy_array, get_namespace
1617

@@ -161,13 +162,23 @@ class UniqueInverseResult(NamedTuple):
161162
inverse_indices: ndarray
162163

163164

165+
def _unique_kwargs(xp):
166+
# Older versions of NumPy and CuPy do not have equal_nan. Rather than
167+
# trying to parse version numbers, just check if equal_nan is in the
168+
# signature.
169+
s = inspect.signature(xp.unique)
170+
if 'equal_nan' in s.parameters:
171+
return {'equal_nan': False}
172+
return {}
173+
164174
def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
175+
kwargs = _unique_kwargs(xp)
165176
values, indices, inverse_indices, counts = xp.unique(
166177
x,
167178
return_counts=True,
168179
return_index=True,
169180
return_inverse=True,
170-
equal_nan=False,
181+
**kwargs,
171182
)
172183
# np.unique() flattens inverse indices, but they need to share x's shape
173184
# See https://github.com/numpy/numpy/issues/20638
@@ -181,24 +192,26 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult:
181192

182193

183194
def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult:
195+
kwargs = _unique_kwargs(xp)
184196
res = xp.unique(
185197
x,
186198
return_counts=True,
187199
return_index=False,
188200
return_inverse=False,
189-
equal_nan=False,
201+
**kwargs
190202
)
191203

192204
return UniqueCountsResult(*res)
193205

194206

195207
def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
208+
kwargs = _unique_kwargs(xp)
196209
values, inverse_indices = xp.unique(
197210
x,
198211
return_counts=False,
199212
return_index=False,
200213
return_inverse=True,
201-
equal_nan=False,
214+
**kwargs,
202215
)
203216
# xp.unique() flattens inverse indices, but they need to share x's shape
204217
# See https://github.com/numpy/numpy/issues/20638
@@ -207,12 +220,13 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult:
207220

208221

209222
def unique_values(x: ndarray, /, xp) -> ndarray:
223+
kwargs = _unique_kwargs(xp)
210224
return xp.unique(
211225
x,
212226
return_counts=False,
213227
return_index=False,
214228
return_inverse=False,
215-
equal_nan=False,
229+
**kwargs,
216230
)
217231

218232
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
@@ -295,8 +309,13 @@ def _asarray(
295309
_check_device(xp, device)
296310
if _is_numpy_array(obj):
297311
import numpy as np
298-
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
299-
COPY_TRUE = (True, np._CopyMode.ALWAYS)
312+
if hasattr(np, '_CopyMode'):
313+
# Not present in older NumPys
314+
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
315+
COPY_TRUE = (True, np._CopyMode.ALWAYS)
316+
else:
317+
COPY_FALSE = (False,)
318+
COPY_TRUE = (True,)
300319
else:
301320
COPY_FALSE = (False,)
302321
COPY_TRUE = (True,)

0 commit comments

Comments
 (0)