Skip to content

Commit 8267a7c

Browse files
committed
Cover everything in test_all
1 parent 131dd31 commit 8267a7c

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

array_api_tests/test_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# TODO: disable if opted out
1+
# TODO: disable if opted out, refactor things
22
import math
33
from collections import Counter, defaultdict
44

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,48 @@
11
from hypothesis import given
2+
from hypothesis import strategies as st
23

34
from . import _array_module as xp
5+
from . import array_helpers as ah
6+
from . import dtype_helpers as dh
47
from . import hypothesis_helpers as hh
8+
from . import pytest_helpers as ph
59
from . import xps
10+
from .test_statistical_functions import (
11+
assert_equals,
12+
assert_keepdimable_shape,
13+
axes,
14+
axes_ndindex,
15+
normalise_axis,
16+
)
617

718

8-
# TODO: generate kwargs
9-
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
10-
def test_any(x):
11-
xp.any(x)
12-
# TODO
19+
@given(
20+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)),
21+
data=st.data(),
22+
)
23+
def test_all(x, data):
24+
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
25+
26+
out = xp.all(x, **kw)
27+
28+
ph.assert_dtype("all", x.dtype, out.dtype, xp.bool)
29+
_axes = normalise_axis(kw.get("axis", None), x.ndim)
30+
assert_keepdimable_shape(
31+
"all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
32+
)
33+
scalar_type = dh.get_scalar_type(x.dtype)
34+
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
35+
result = bool(out[out_idx])
36+
elements = []
37+
for idx in indices:
38+
s = scalar_type(x[idx])
39+
elements.append(s)
40+
expected = all(elements)
41+
assert_equals("all", scalar_type, out_idx, result, expected)
1342

1443

1544
# TODO: generate kwargs
1645
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()))
17-
def test_all(x):
18-
xp.all(x)
46+
def test_any(x):
47+
xp.any(x)
1948
# TODO

0 commit comments

Comments
 (0)