Skip to content

Commit 3fc57be

Browse files
committed
Cover everything in test_any
1 parent 8267a7c commit 3fc57be

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

array_api_tests/test_utility_functions.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,26 @@ def test_all(x, data):
4141
assert_equals("all", scalar_type, out_idx, result, expected)
4242

4343

44-
# TODO: generate kwargs
45-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
46-
def test_any(x):
47-
xp.any(x)
48-
# TODO
44+
@given(
45+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
46+
data=st.data(),
47+
)
48+
def test_any(x, data):
49+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
50+
51+
out = xp.any(x, **kw)
52+
53+
ph.assert_dtype("any", x.dtype, out.dtype, xp.bool)
54+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
55+
assert_keepdimable_shape(
56+
"any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
57+
)
58+
scalar_type = dh.get_scalar_type(x.dtype)
59+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
60+
result = bool(out[out_idx])
61+
elements = []
62+
for idx in indices:
63+
s = scalar_type(x[idx])
64+
elements.append(s)
65+
expected = any(elements)
66+
assert_equals("any", scalar_type, out_idx, result, expected)

0 commit comments

Comments
 (0)