Skip to content

Commit ab8674f

Browse files
committed
Cover everything in test_unique_inverse
1 parent f8c99d5 commit ab8674f

File tree

1 file changed

+51
-4
lines changed

1 file changed

+51
-4
lines changed

array_api_tests/test_set_functions.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO: disable if opted out
12
import math
23
from collections import Counter
34

@@ -62,14 +63,60 @@ def test_unique_counts(x):
6263
if dh.is_float_dtype(out.values.dtype):
6364
assume(x.size <= 128) # may not be representable
6465
expected = sum(v for k, v in counts.items() if math.isnan(k))
65-
print(f"{counts=}")
6666
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
6767

6868

69-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
69+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
7070
def test_unique_inverse(x):
71-
xp.unique_inverse(x)
72-
# TODO
71+
out = xp.unique_inverse(x)
72+
assert hasattr(out, "values")
73+
assert hasattr(out, "inverse_indices")
74+
ph.assert_dtype(
75+
"unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype"
76+
)
77+
assert_default_index(
78+
"unique_inverse",
79+
out.inverse_indices.dtype,
80+
repr_name="out.inverse_indices.dtype",
81+
)
82+
ph.assert_shape(
83+
"unique_inverse",
84+
out.inverse_indices.shape,
85+
x.shape,
86+
repr_name="out.inverse_indices.shape",
87+
)
88+
scalar_type = dh.get_scalar_type(out.values.dtype)
89+
distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape))
90+
vals_idx = {}
91+
nans = 0
92+
for idx in ah.ndindex(out.values.shape):
93+
val = scalar_type(out.values[idx])
94+
if math.isnan(val):
95+
nans += 1
96+
else:
97+
assert (
98+
val in distinct
99+
), f"out.values[{idx}]={val}, but {val} not in input array"
100+
assert (
101+
val not in vals_idx.keys()
102+
), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
103+
vals_idx[val] = idx
104+
for idx in ah.ndindex(out.inverse_indices.shape):
105+
ridx = int(out.inverse_indices[idx])
106+
val = out.values[ridx]
107+
expected = x[idx]
108+
msg = (
109+
f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
110+
f"but should result in x[{idx}]={expected}"
111+
)
112+
if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
113+
assert xp.isnan(val), msg
114+
else:
115+
assert val == expected, msg
116+
if dh.is_float_dtype(out.values.dtype):
117+
assume(x.size <= 128) # may not be representable
118+
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
119+
assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}"
73120

74121

75122
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))

0 commit comments

Comments
 (0)