|
1 | 1 | # TODO: disable if opted out
|
2 | 2 | import math
|
3 |
| -from collections import Counter |
| 3 | +from collections import Counter, defaultdict |
4 | 4 |
|
5 | 5 | from hypothesis import assume, given
|
6 | 6 |
|
|
13 | 13 | from .test_searching_functions import assert_default_index
|
14 | 14 |
|
15 | 15 |
|
16 |
@given(
    xps.arrays(
        dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1, min_dims=1, max_dims=1)
    )
)  # TODO
def test_unique_all(x):
    """Check ``xp.unique_all(x)`` against a pure-Python tally of ``x``.

    The test verifies, in order:

    * the result exposes ``values``, ``indices``, ``inverse_indices`` and
      ``counts``;
    * ``values`` keeps the dtype of ``x`` and the three index-like outputs
      pass ``assert_default_index``;
    * ``indices`` and ``counts`` have the shape of ``values``, and
      ``inverse_indices`` has the shape of ``x``;
    * every non-NaN unique value points at its first occurrence in ``x``;
    * ``inverse_indices`` reconstructs ``x`` from ``values``;
    * counts match the tally, no non-NaN value is reported twice, and each
      NaN is reported separately with a count of 1.

    The strategy only generates non-empty 1-D arrays, so the enumeration
    position of an element is also its index into ``x``.
    """
    out = xp.unique_all(x)

    assert hasattr(out, "values")
    assert hasattr(out, "indices")
    assert hasattr(out, "inverse_indices")
    assert hasattr(out, "counts")

    ph.assert_dtype(
        "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype"
    )
    assert_default_index("unique_all", out.indices.dtype, repr_name="out.indices.dtype")
    assert_default_index(
        "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype"
    )
    assert_default_index("unique_all", out.counts.dtype, repr_name="out.counts.dtype")

    assert (
        out.indices.shape == out.values.shape
    ), f"{out.indices.shape=}, but should be {out.values.shape=}"
    ph.assert_shape(
        "unique_all",
        out.inverse_indices.shape,
        x.shape,
        repr_name="out.inverse_indices.shape",
    )
    assert (
        out.counts.shape == out.values.shape
    ), f"{out.counts.shape=}, but should be {out.values.shape=}"

    # Reference tally built with Python scalars: how often each value occurs
    # and the position at which it is first seen.
    # NOTE(review): NaN never compares equal to itself, so each NaN element is
    # expected to land under its own key with a count of 1; the final NaN check
    # below relies on that.
    scalar_type = dh.get_scalar_type(out.values.dtype)
    counts = defaultdict(int)
    firsts = {}
    for i, idx in enumerate(ah.ndindex(x.shape)):
        val = scalar_type(x[idx])
        if counts[val] == 0:
            firsts[val] = i
        counts[val] += 1

    for idx in ah.ndindex(out.indices.shape):
        val = scalar_type(out.values[idx])
        if math.isnan(val):
            # NaNs cannot be looked up in `firsts`. Skip just this entry rather
            # than stopping: the order of unique values is not guaranteed, so
            # non-NaN values may still follow and must be checked too.
            continue
        # Fail with a readable message instead of a bare KeyError when the
        # output contains a value that never appears in the input.
        assert val in firsts, f"out.values[{idx}]={val}, but {val} not in input array"
        i = int(out.indices[idx])
        expected = firsts[val]
        assert i == expected, (
            f"out.values[{idx}]={val} and out.indices[{idx}]={i}, "
            f"but first occurrence of {val} is at {expected}"
        )

    for idx in ah.ndindex(out.inverse_indices.shape):
        ridx = int(out.inverse_indices[idx])
        val = out.values[ridx]
        expected = x[idx]
        msg = (
            f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
            f"but should result in x[{idx}]={expected}"
        )
        if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
            # NaN != NaN, so equality cannot be used for NaN inputs.
            assert xp.isnan(val), msg
        else:
            assert val == expected, msg

    vals_idx = {}  # non-NaN unique value -> index where it was reported
    nans = 0
    for idx in ah.ndindex(out.values.shape):
        val = scalar_type(out.values[idx])
        count = int(out.counts[idx])
        if math.isnan(val):
            nans += 1
            assert count == 1, (
                f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
                "but count should be 1 as NaNs are distinct"
            )
        else:
            expected = counts[val]
            assert (
                expected > 0
            ), f"out.values[{idx}]={val}, but {val} not in input array"
            assert count == expected, (
                f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
                f"but should be {expected}"
            )
            # The message is only formatted when the assertion fails, i.e.
            # when `val` is already present in `vals_idx`.
            assert (
                val not in vals_idx
            ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
            vals_idx[val] = idx

    if dh.is_float_dtype(out.values.dtype):
        assume(x.size <= 128)  # may not be representable
        expected = sum(v for k, v in counts.items() if math.isnan(k))
        assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
20 | 114 |
|
21 | 115 |
|
22 | 116 | @given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
|
|
0 commit comments