Skip to content

Commit f8c99d5

Browse files
committed
Cover everything in test_unique_counts
1 parent d594ff5 commit f8c99d5

File tree

1 file changed

+49
-5
lines changed

1 file changed

+49
-5
lines changed

array_api_tests/test_set_functions.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from collections import Counter
23

34
from hypothesis import assume, given
45

@@ -8,6 +9,7 @@
89
from . import hypothesis_helpers as hh
910
from . import pytest_helpers as ph
1011
from . import xps
12+
from .test_searching_functions import assert_default_index
1113

1214

1315
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
@@ -16,10 +18,52 @@ def test_unique_all(x):
1618
# TODO
1719

1820

19-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
21+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
2022
def test_unique_counts(x):
21-
xp.unique_counts(x)
22-
# TODO
23+
out = xp.unique_counts(x)
24+
assert hasattr(out, "values")
25+
assert hasattr(out, "counts")
26+
ph.assert_dtype(
27+
"unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype"
28+
)
29+
assert_default_index(
30+
"unique_counts", out.counts.dtype, repr_name="out.counts.dtype"
31+
)
32+
assert (
33+
out.counts.shape == out.values.shape
34+
), f"{out.counts.shape=}, but should be {out.values.shape=}"
35+
scalar_type = dh.get_scalar_type(out.values.dtype)
36+
counts = Counter(scalar_type(x[idx]) for idx in ah.ndindex(x.shape))
37+
vals_idx = {}
38+
nans = 0
39+
for idx in ah.ndindex(out.values.shape):
40+
val = scalar_type(out.values[idx])
41+
count = int(out.counts[idx])
42+
if math.isnan(val):
43+
nans += 1
44+
assert count == 1, (
45+
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
46+
"but count should be 1 as NaNs are distinct"
47+
)
48+
else:
49+
expected = counts[val]
50+
assert (
51+
expected > 0
52+
), f"out.values[{idx}]={val}, but {val} not in input array"
53+
count = int(out.counts[idx])
54+
assert count == expected, (
55+
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
56+
f"but should be {expected}"
57+
)
58+
assert (
59+
val not in vals_idx.keys()
60+
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
61+
vals_idx[val] = idx
62+
if dh.is_float_dtype(out.values.dtype):
63+
assume(x.size <= 128) # may not be representable
64+
expected = sum(v for k, v in counts.items() if math.isnan(k))
65+
print(f"{counts=}")
66+
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
2367

2468

2569
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
@@ -48,5 +92,5 @@ def test_unique_values(x):
4892
vals_idx[val] = idx
4993
if dh.is_float_dtype(out.dtype):
5094
assume(x.size <= 128) # may not be representable
51-
expected_nans = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
52-
assert nans == expected_nans, f"{nans} NaNs in out, expected {expected_nans}"
95+
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
96+
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"

0 commit comments

Comments
 (0)