|
| 1 | +# TODO: disable if opted out |
1 | 2 | import math
|
2 | 3 | from collections import Counter
|
3 | 4 |
|
@@ -62,14 +63,60 @@ def test_unique_counts(x):
|
62 | 63 | if dh.is_float_dtype(out.values.dtype):
|
63 | 64 | assume(x.size <= 128) # may not be representable
|
64 | 65 | expected = sum(v for k, v in counts.items() if math.isnan(k))
|
65 |
| - print(f"{counts=}") |
66 | 66 | assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
|
67 | 67 |
|
68 | 68 |
|
69 |
| -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes())) |
| 69 | +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) |
70 | 70 | def test_unique_inverse(x):
|
71 |
| - xp.unique_inverse(x) |
72 |
| - # TODO |
| 71 | + out = xp.unique_inverse(x) |
| 72 | + assert hasattr(out, "values") |
| 73 | + assert hasattr(out, "inverse_indices") |
| 74 | + ph.assert_dtype( |
| 75 | + "unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype" |
| 76 | + ) |
| 77 | + assert_default_index( |
| 78 | + "unique_inverse", |
| 79 | + out.inverse_indices.dtype, |
| 80 | + repr_name="out.inverse_indices.dtype", |
| 81 | + ) |
| 82 | + ph.assert_shape( |
| 83 | + "unique_inverse", |
| 84 | + out.inverse_indices.shape, |
| 85 | + x.shape, |
| 86 | + repr_name="out.inverse_indices.shape", |
| 87 | + ) |
| 88 | + scalar_type = dh.get_scalar_type(out.values.dtype) |
| 89 | + distinct = set(scalar_type(x[idx]) for idx in ah.ndindex(x.shape)) |
| 90 | + vals_idx = {} |
| 91 | + nans = 0 |
| 92 | + for idx in ah.ndindex(out.values.shape): |
| 93 | + val = scalar_type(out.values[idx]) |
| 94 | + if math.isnan(val): |
| 95 | + nans += 1 |
| 96 | + else: |
| 97 | + assert ( |
| 98 | + val in distinct |
| 99 | + ), f"out.values[{idx}]={val}, but {val} not in input array" |
| 100 | + assert ( |
| 101 | + val not in vals_idx.keys() |
| 102 | + ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" |
| 103 | + vals_idx[val] = idx |
| 104 | + for idx in ah.ndindex(out.inverse_indices.shape): |
| 105 | + ridx = int(out.inverse_indices[idx]) |
| 106 | + val = out.values[ridx] |
| 107 | + expected = x[idx] |
| 108 | + msg = ( |
| 109 | + f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, " |
| 110 | + f"but should result in x[{idx}]={expected}" |
| 111 | + ) |
| 112 | + if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected): |
| 113 | + assert xp.isnan(val), msg |
| 114 | + else: |
| 115 | + assert val == expected, msg |
| 116 | + if dh.is_float_dtype(out.values.dtype): |
| 117 | + assume(x.size <= 128) # may not be representable |
| 118 | + expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) |
| 119 | + assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" |
73 | 120 |
|
74 | 121 |
|
75 | 122 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
|
|
0 commit comments