Skip to content

Commit f8ef141

Browse files
committed
Changes to tests for dpctl.tensor.any/all
- Tests refactored into more generic tests parametrized by function and identity - Randrange used to make tests more robust - Tests now cover branch in kernel for wide vs. skinny arrays
1 parent eed7e45 commit f8ef141

File tree

1 file changed

+51
-88
lines changed

1 file changed

+51
-88
lines changed
Lines changed: 51 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from random import randrange
2+
13
import numpy as np
24
import pytest
35
from numpy import AxisError
@@ -24,141 +26,102 @@
2426
]
2527

2628

29+
@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
2730
@pytest.mark.parametrize("dtype", _all_dtypes)
28-
def test_all_dtypes_contig(dtype):
29-
q = get_queue_or_skip()
30-
skip_if_dtype_not_supported(dtype, q)
31-
32-
x = dpt.ones(10, dtype=dtype, sycl_queue=q)
33-
res = dpt.all(x)
34-
35-
assert_equal(dpt.asnumpy(res), True)
36-
37-
x[x.size // 2] = 0
38-
res = dpt.all(x)
39-
assert_equal(dpt.asnumpy(res), False)
40-
41-
42-
@pytest.mark.parametrize("dtype", _all_dtypes)
43-
def test_all_dtypes_strided(dtype):
31+
def test_boolean_reduction_dtypes_contig(func, identity, dtype):
4432
q = get_queue_or_skip()
4533
skip_if_dtype_not_supported(dtype, q)
4634

47-
x = dpt.ones(20, dtype=dtype, sycl_queue=q)[::-2]
48-
res = dpt.all(x)
49-
assert_equal(dpt.asnumpy(res), True)
35+
x = dpt.full(10, identity, dtype=dtype, sycl_queue=q)
36+
res = func(x)
5037

51-
x[x.size // 2] = 0
52-
res = dpt.all(x)
53-
assert_equal(dpt.asnumpy(res), False)
38+
assert_equal(dpt.asnumpy(res), identity)
5439

40+
x[randrange(x.size)] = not identity
41+
res = func(x)
42+
assert_equal(dpt.asnumpy(res), not identity)
5543

56-
@pytest.mark.parametrize("dtype", _all_dtypes)
57-
def test_any_dtypes_contig(dtype):
58-
q = get_queue_or_skip()
59-
skip_if_dtype_not_supported(dtype, q)
44+
# test branch in kernel for large arrays
45+
wg_size = 4 * 32
46+
x = dpt.full((wg_size + 1), identity, dtype=dtype, sycl_queue=q)
47+
res = func(x)
48+
assert_equal(dpt.asnumpy(res), identity)
6049

61-
x = dpt.zeros(10, dtype=dtype, sycl_queue=q)
62-
res = dpt.any(x)
63-
64-
assert_equal(dpt.asnumpy(res), False)
65-
66-
x[x.size // 2] = 1
67-
res = dpt.any(x)
68-
assert_equal(dpt.asnumpy(res), True)
50+
x[randrange(x.size)] = not identity
51+
res = func(x)
52+
assert_equal(dpt.asnumpy(res), not identity)
6953

7054

55+
@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
7156
@pytest.mark.parametrize("dtype", _all_dtypes)
72-
def test_any_dtypes_strided(dtype):
57+
def test_boolean_reduction_dtypes_strided(func, identity, dtype):
7358
q = get_queue_or_skip()
7459
skip_if_dtype_not_supported(dtype, q)
7560

76-
x = dpt.zeros(20, dtype=dtype, sycl_queue=q)[::-2]
77-
res = dpt.any(x)
78-
assert_equal(dpt.asnumpy(res), False)
61+
x = dpt.full(20, identity, dtype=dtype, sycl_queue=q)[::-2]
62+
res = func(x)
63+
assert_equal(dpt.asnumpy(res), identity)
7964

80-
x[x.size // 2] = 1
81-
res = dpt.any(x)
82-
assert_equal(dpt.asnumpy(res), True)
65+
x[randrange(x.size)] = not identity
66+
res = func(x)
67+
assert_equal(dpt.asnumpy(res), not identity)
8368

8469

85-
def test_all_axis():
70+
@pytest.mark.parametrize("func,identity", [(dpt.all, True), (dpt.any, False)])
71+
def test_boolean_reduction_axis(func, identity):
8672
get_queue_or_skip()
8773

88-
x = dpt.ones((2, 3, 4, 5, 6), dtype="i4")
89-
res = dpt.all(x, axis=(1, 2, -1))
90-
91-
assert res.shape == (2, 5)
92-
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
93-
94-
# make first row of output false
95-
x[0, 0, 0, ...] = 0
96-
res = dpt.all(x, axis=(1, 2, -1))
97-
assert_array_equal(dpt.asnumpy(res[0]), np.full(res.shape[1], False))
98-
99-
100-
def test_any_axis():
101-
get_queue_or_skip()
102-
103-
x = dpt.zeros((2, 3, 4, 5, 6), dtype="i4")
104-
res = dpt.any(x, axis=(1, 2, -1))
74+
x = dpt.full((2, 3, 4, 5, 6), identity, dtype="i4")
75+
res = func(x, axis=(1, 2, -1))
10576

10677
assert res.shape == (2, 5)
107-
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, False))
78+
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, identity))
10879

109-
# make first row of output true
110-
x[0, 0, 0, ...] = 1
111-
res = dpt.any(x, axis=(1, 2, -1))
112-
assert_array_equal(dpt.asnumpy(res[0]), np.full(res.shape[1], True))
80+
# make first row of output negation of identity
81+
x[0, 0, 0, ...] = not identity
82+
res = func(x, axis=(1, 2, -1))
83+
assert_array_equal(dpt.asnumpy(res[0]), np.full(res.shape[1], not identity))
11384

11485

115-
def test_all_any_keepdims():
86+
@pytest.mark.parametrize("func", [dpt.all, dpt.any])
87+
def test_all_any_keepdims(func):
11688
get_queue_or_skip()
11789

11890
x = dpt.ones((2, 3, 4, 5, 6), dtype="i4")
11991

120-
res = dpt.all(x, axis=(1, 2, -1), keepdims=True)
121-
assert res.shape == (2, 1, 1, 5, 1)
122-
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
123-
124-
res = dpt.any(x, axis=(1, 2, -1), keepdims=True)
92+
res = func(x, axis=(1, 2, -1), keepdims=True)
12593
assert res.shape == (2, 1, 1, 5, 1)
12694
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
12795

12896

12997
# nan, inf, and -inf should evaluate to true
130-
def test_all_any_nan_inf():
98+
@pytest.mark.parametrize("func", [dpt.all, dpt.any])
99+
def test_boolean_reductions_nan_inf(func):
131100
q = get_queue_or_skip()
132101

133-
x = dpt.asarray([dpt.nan, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q)
134-
res = dpt.all(x)
102+
x = dpt.asarray([dpt.nan, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q)[
103+
:, dpt.newaxis
104+
]
105+
res = func(x, axis=1)
135106
assert_equal(dpt.asnumpy(res), True)
136107

137-
x = x[:, dpt.newaxis]
138-
res = dpt.any(x, axis=1)
139-
assert_array_equal(dpt.asnumpy(res), np.full(res.shape, True))
140-
141108

142-
def test_all_any_scalar():
109+
@pytest.mark.parametrize("func", [dpt.all, dpt.any])
110+
def test_boolean_reduction_scalars(func):
143111
get_queue_or_skip()
144112

145113
x = dpt.ones((), dtype="i4")
146-
dpt.all(x)
147-
dpt.any(x)
114+
func(x)
148115

149116

150-
def test_arg_validation_all_any():
117+
@pytest.mark.parametrize("func", [dpt.all, dpt.any])
118+
def test_arg_validation_boolean_reductions(func):
151119
get_queue_or_skip()
152120

153121
x = dpt.ones((4, 5), dtype="i4")
154122
d = dict()
155123

156124
with pytest.raises(TypeError):
157-
dpt.all(d)
158-
with pytest.raises(AxisError):
159-
dpt.all(x, axis=-3)
160-
161-
with pytest.raises(TypeError):
162-
dpt.any(d)
125+
func(d)
163126
with pytest.raises(AxisError):
164-
dpt.any(x, axis=-3)
127+
func(x, axis=-3)

0 commit comments

Comments
 (0)