Skip to content

Commit a435e63

Browse files
committed
Test 0d arrays conversion to scalars
1 parent 3fc57be commit a435e63

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

array_api_tests/test_array2scalar.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
from hypothesis import given
3+
from hypothesis import strategies as st
4+
5+
from . import _array_module as xp
6+
from . import dtype_helpers as dh
7+
from . import xps
8+
from .typing import DataType, Param
9+
10+
method_stype = {
11+
"__bool__": bool,
12+
"__int__": int,
13+
"__index__": int,
14+
"__float__": float,
15+
}
16+
17+
18+
def make_param(method_name: str, dtype: DataType) -> Param:
19+
stype = method_stype[method_name]
20+
return pytest.param(
21+
method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})"
22+
)
23+
24+
25+
@pytest.mark.parametrize(
26+
"method_name, dtype, stype",
27+
[make_param("__bool__", xp.bool)]
28+
+ [make_param("__int__", d) for d in dh.all_int_dtypes]
29+
+ [make_param("__index__", d) for d in dh.all_int_dtypes]
30+
+ [make_param("__float__", d) for d in dh.float_dtypes],
31+
)
32+
@given(data=st.data())
33+
def test_0d_array_can_convert_to_scalar(method_name, dtype, stype, data):
34+
x = data.draw(xps.arrays(dtype, shape=()), label="x")
35+
method = getattr(x, method_name)
36+
out = method()
37+
assert isinstance(
38+
out, stype
39+
), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar"

0 commit comments

Comments
 (0)