Skip to content

Commit 131dd31

Browse files
committed
Cover everything in test_unique_all (if messily)
1 parent ab8674f commit 131dd31

File tree

1 file changed

+98
-4
lines changed

1 file changed

+98
-4
lines changed

array_api_tests/test_set_functions.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TODO: disable if opted out
22
import math
3-
from collections import Counter
3+
from collections import Counter, defaultdict
44

55
from hypothesis import assume, given
66

@@ -13,10 +13,104 @@
1313
from .test_searching_functions import assert_default_index
1414

1515

16-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
16+
@given(
17+
xps.arrays(
18+
dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1, min_dims=1, max_dims=1)
19+
)
20+
) # TODO
1721
def test_unique_all(x):
18-
xp.unique_all(x)
19-
# TODO
22+
out = xp.unique_all(x)
23+
24+
assert hasattr(out, "values")
25+
assert hasattr(out, "indices")
26+
assert hasattr(out, "inverse_indices")
27+
assert hasattr(out, "counts")
28+
29+
ph.assert_dtype(
30+
"unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype"
31+
)
32+
assert_default_index("unique_all", out.indices.dtype, repr_name="out.indices.dtype")
33+
assert_default_index(
34+
"unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype"
35+
)
36+
assert_default_index("unique_all", out.counts.dtype, repr_name="out.counts.dtype")
37+
38+
assert (
39+
out.indices.shape == out.values.shape
40+
), f"{out.indices.shape=}, but should be {out.values.shape=}"
41+
ph.assert_shape(
42+
"unique_all",
43+
out.inverse_indices.shape,
44+
x.shape,
45+
repr_name="out.inverse_indices.shape",
46+
)
47+
assert (
48+
out.counts.shape == out.values.shape
49+
), f"{out.counts.shape=}, but should be {out.values.shape=}"
50+
51+
scalar_type = dh.get_scalar_type(out.values.dtype)
52+
counts = defaultdict(int)
53+
firsts = {}
54+
for i, idx in enumerate(ah.ndindex(x.shape)):
55+
val = scalar_type(x[idx])
56+
if counts[val] == 0:
57+
firsts[val] = i
58+
counts[val] += 1
59+
60+
for idx in ah.ndindex(out.indices.shape):
61+
val = scalar_type(out.values[idx])
62+
if math.isnan(val):
63+
break
64+
i = int(out.indices[idx])
65+
expected = firsts[val]
66+
assert i == expected, (
67+
f"out.values[{idx}]={val} and out.indices[{idx}]={i}, "
68+
f"but first occurence of {val} is at {expected}"
69+
)
70+
71+
for idx in ah.ndindex(out.inverse_indices.shape):
72+
ridx = int(out.inverse_indices[idx])
73+
val = out.values[ridx]
74+
expected = x[idx]
75+
msg = (
76+
f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
77+
f"but should result in x[{idx}]={expected}"
78+
)
79+
if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
80+
assert xp.isnan(val), msg
81+
else:
82+
assert val == expected, msg
83+
84+
vals_idx = {}
85+
nans = 0
86+
for idx in ah.ndindex(out.values.shape):
87+
val = scalar_type(out.values[idx])
88+
count = int(out.counts[idx])
89+
if math.isnan(val):
90+
nans += 1
91+
assert count == 1, (
92+
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
93+
"but count should be 1 as NaNs are distinct"
94+
)
95+
else:
96+
expected = counts[val]
97+
assert (
98+
expected > 0
99+
), f"out.values[{idx}]={val}, but {val} not in input array"
100+
count = int(out.counts[idx])
101+
assert count == expected, (
102+
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
103+
f"but should be {expected}"
104+
)
105+
assert (
106+
val not in vals_idx.keys()
107+
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
108+
vals_idx[val] = idx
109+
110+
if dh.is_float_dtype(out.values.dtype):
111+
assume(x.size <= 128) # may not be representable
112+
expected = sum(v for k, v in counts.items() if math.isnan(k))
113+
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
20114

21115

22116
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))

0 commit comments

Comments
 (0)