Skip to content

Commit 4fc6133

Browse files
committed
Fix the array API unique_*() functions to not compare nans as equal
The spec requires this, but it is only now possible to implement with the new equal_nan flag in np.unique(). Original NumPy Commit: 12f83eb7337b840e3cd9026779b99b1af8033bf3
1 parent 606fdb3 commit 4fc6133

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

array_api_strict/_set_functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def unique_all(x: Array, /) -> UniqueAllResult:
4646
return_counts=True,
4747
return_index=True,
4848
return_inverse=True,
49+
equal_nan=False,
4950
)
5051
# np.unique() flattens inverse indices, but they need to share x's shape
5152
# See https://github.com/numpy/numpy/issues/20638
@@ -64,6 +65,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult:
6465
return_counts=True,
6566
return_index=False,
6667
return_inverse=False,
68+
equal_nan=False,
6769
)
6870

6971
return UniqueCountsResult(*[Array._new(i) for i in res])
@@ -80,6 +82,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
8082
return_counts=False,
8183
return_index=False,
8284
return_inverse=True,
85+
equal_nan=False,
8386
)
8487
# np.unique() flattens inverse indices, but they need to share x's shape
8588
# See https://github.com/numpy/numpy/issues/20638
@@ -98,5 +101,6 @@ def unique_values(x: Array, /) -> Array:
98101
return_counts=False,
99102
return_index=False,
100103
return_inverse=False,
104+
equal_nan=False,
101105
)
102106
return Array._new(res)

0 commit comments

Comments
 (0)