Skip to content

Commit c1f8a74

Browse files
Add check of computed against expected indices
1 parent aaa1ad7 commit c1f8a74

File tree

1 file changed

+113
-12
lines changed

1 file changed

+113
-12
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,38 @@
2020
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2121

2222

23+
def _expected_largest_inds(inp, n, shift, k):
24+
"Computed expected top_k indices for mode='largest'"
25+
assert k < n
26+
ones_start_id = shift % (2 * n)
27+
28+
alloc_dev = inp.device
29+
30+
if ones_start_id < n:
31+
expected_inds = dpt.arange(
32+
ones_start_id, ones_start_id + k, dtype="i8", device=alloc_dev
33+
)
34+
else:
35+
# wrap-around
36+
ones_end_id = (ones_start_id + n) % (2 * n)
37+
if ones_end_id >= k:
38+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
39+
else:
40+
expected_inds = dpt.concat(
41+
(
42+
dpt.arange(ones_end_id, dtype="i8", device=alloc_dev),
43+
dpt.arange(
44+
ones_start_id,
45+
ones_start_id + k - ones_end_id,
46+
dtype="i8",
47+
device=alloc_dev,
48+
),
49+
)
50+
)
51+
52+
return expected_inds
53+
54+
2355
@pytest.mark.parametrize(
2456
"dtype",
2557
[
@@ -38,25 +70,59 @@
3870
"c16",
3971
],
4072
)
41-
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
42-
def test_topk_1d_largest(dtype, n):
73+
@pytest.mark.parametrize("n", [33, 43, 255, 511, 1021, 8193])
74+
def test_top_k_1d_largest(dtype, n):
4375
q = get_queue_or_skip()
4476
skip_if_dtype_not_supported(dtype, q)
4577

78+
shift, k = 734, 5
4679
o = dpt.ones(n, dtype=dtype)
4780
z = dpt.zeros(n, dtype=dtype)
4881
oz = dpt.concat((o, z))
49-
inp = dpt.roll(oz, 734)
50-
k = 5
82+
inp = dpt.roll(oz, shift)
83+
84+
expected_inds = _expected_largest_inds(oz, n, shift, k)
5185

5286
s = dpt.top_k(inp, k, mode="largest")
5387
assert s.values.shape == (k,)
5488
assert s.values.dtype == inp.dtype
5589
assert s.indices.shape == (k,)
90+
assert dpt.all(s.indices == expected_inds)
5691
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
5792
assert dpt.all(s.values == inp[s.indices]), s.indices
5893

5994

95+
def _expected_smallest_inds(inp, n, shift, k):
96+
"Computed expected top_k indices for mode='smallest'"
97+
assert k < n
98+
zeros_start_id = (n + shift) % (2 * n)
99+
zeros_end_id = (shift) % (2 * n)
100+
101+
alloc_dev = inp.device
102+
103+
if zeros_start_id < zeros_end_id:
104+
expected_inds = dpt.arange(
105+
zeros_start_id, zeros_start_id + k, dtype="i8", device=alloc_dev
106+
)
107+
else:
108+
if zeros_end_id >= k:
109+
expected_inds = dpt.arange(k, dtype="i8", device=alloc_dev)
110+
else:
111+
expected_inds = dpt.concat(
112+
(
113+
dpt.arange(zeros_end_id, dtype="i8", device=alloc_dev),
114+
dpt.arange(
115+
zeros_start_id,
116+
zeros_start_id + k - zeros_end_id,
117+
dtype="i8",
118+
device=alloc_dev,
119+
),
120+
)
121+
)
122+
123+
return expected_inds
124+
125+
60126
@pytest.mark.parametrize(
61127
"dtype",
62128
[
@@ -75,37 +141,70 @@ def test_topk_1d_largest(dtype, n):
75141
"c16",
76142
],
77143
)
78-
@pytest.mark.parametrize("n", [33, 255, 257, 513, 1021, 8193])
79-
def test_topk_1d_smallest(dtype, n):
144+
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
145+
def test_top_k_1d_smallest(dtype, n):
80146
q = get_queue_or_skip()
81147
skip_if_dtype_not_supported(dtype, q)
82148

149+
shift, k = 734, 5
83150
o = dpt.ones(n, dtype=dtype)
84151
z = dpt.zeros(n, dtype=dtype)
85152
oz = dpt.concat((o, z))
86-
inp = dpt.roll(oz, 734)
87-
k = 5
153+
inp = dpt.roll(oz, shift)
154+
155+
expected_inds = _expected_smallest_inds(oz, n, shift, k)
88156

89157
s = dpt.top_k(inp, k, mode="smallest")
90158
assert s.values.shape == (k,)
91159
assert s.values.dtype == inp.dtype
92160
assert s.indices.shape == (k,)
161+
assert dpt.all(s.indices == expected_inds)
93162
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
94163
assert dpt.all(s.values == inp[s.indices]), s.indices
95164

96165

97166
# triage failing top k radix implementation on CPU
98167
# replicates from Python behavior of radix sort topk implementation
99-
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
100-
def test_topk_largest_1d_radix_i1(n):
168+
@pytest.mark.parametrize(
169+
"n",
170+
[
171+
33,
172+
34,
173+
35,
174+
36,
175+
37,
176+
38,
177+
39,
178+
40,
179+
41,
180+
42,
181+
43,
182+
44,
183+
45,
184+
46,
185+
47,
186+
48,
187+
49,
188+
50,
189+
61,
190+
137,
191+
255,
192+
511,
193+
1021,
194+
8193,
195+
],
196+
)
197+
def test_top_k_largest_1d_radix_i1(n):
101198
get_queue_or_skip()
102199
dt = "i1"
103200

201+
shift, k = 734, 5
104202
o = dpt.ones(n, dtype=dt)
105203
z = dpt.zeros(n, dtype=dt)
106204
oz = dpt.concat((o, z))
107-
inp = dpt.roll(oz, 734)
108-
k = 5
205+
inp = dpt.roll(oz, shift)
206+
207+
expected_inds = _expected_largest_inds(oz, n, shift, k)
109208

110209
sorted_v = dpt.sort(inp, descending=True, kind="radixsort")
111210
argsorted = dpt.argsort(inp, descending=True, kind="radixsort")
@@ -116,4 +215,6 @@ def test_topk_largest_1d_radix_i1(n):
116215
topk_inds = dpt.copy(argsorted[:k])
117216

118217
assert dpt.all(topk_vals == dpt.ones(k, dtype=dt))
218+
assert dpt.all(topk_inds == expected_inds)
219+
119220
assert dpt.all(topk_vals == inp[topk_inds]), topk_inds

0 commit comments

Comments
 (0)