Skip to content

Commit 4d849f1

Browse files
committed
Finish elwise TODOs
1 parent a4a7e04 commit 4d849f1

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,30 @@ def unary_assert_against_refimpl(
7676
in_stype = dh.get_scalar_type(in_.dtype)
7777
if res_stype is None:
7878
res_stype = in_stype
79+
if res.dtype != xp.bool:
80+
m, M = dh.dtype_ranges[res.dtype]
7981
for idx in sh.ndindex(in_.shape):
8082
scalar_i = in_stype(in_[idx])
8183
if not filter_(scalar_i):
8284
continue
8385
expected = refimpl(scalar_i)
86+
if res.dtype != xp.bool:
87+
if expected <= m or expected >= M:
88+
continue
8489
scalar_o = res_stype(res[idx])
8590
f_i = sh.fmt_idx("x", idx)
8691
f_o = sh.fmt_idx("out", idx)
8792
expr = expr_template.format(f_i, expected)
88-
assert scalar_o == expected, (
89-
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
90-
f"{f_i}={scalar_i}"
91-
)
93+
if dh.is_float_dtype(res.dtype):
94+
assert isclose(scalar_o, expected), (
95+
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
96+
f"{f_i}={scalar_i}"
97+
)
98+
else:
99+
assert scalar_o == expected, (
100+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
101+
f"{f_i}={scalar_i}"
102+
)
92103

93104

94105
def binary_assert_against_refimpl(
@@ -1257,29 +1268,35 @@ def test_sin(x):
12571268
out = xp.sin(x)
12581269
ph.assert_dtype("sin", x.dtype, out.dtype)
12591270
ph.assert_shape("sin", out.shape, x.shape)
1260-
# TODO
1271+
unary_assert_against_refimpl("sin", x, out, math.sin, "sin({})={}")
12611272

12621273

12631274
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
12641275
def test_sinh(x):
12651276
out = xp.sinh(x)
12661277
ph.assert_dtype("sinh", x.dtype, out.dtype)
12671278
ph.assert_shape("sinh", out.shape, x.shape)
1268-
# TODO
1279+
unary_assert_against_refimpl("sinh", x, out, math.sinh, "sinh({})={}")
12691280

12701281

12711282
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
12721283
def test_square(x):
12731284
out = xp.square(x)
12741285
ph.assert_dtype("square", x.dtype, out.dtype)
12751286
ph.assert_shape("square", out.shape, x.shape)
1287+
unary_assert_against_refimpl("square", x, out, lambda s: s ** 2, "{}²={}")
12761288

12771289

1278-
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
1290+
@given(
1291+
xps.arrays(
1292+
dtype=xps.floating_dtypes(), shape=hh.shapes(), elements={"min_value": 0}
1293+
)
1294+
)
12791295
def test_sqrt(x):
12801296
out = xp.sqrt(x)
12811297
ph.assert_dtype("sqrt", x.dtype, out.dtype)
12821298
ph.assert_shape("sqrt", out.shape, x.shape)
1299+
unary_assert_against_refimpl("sqrt", x, out, math.sqrt, "sqrt({})={}")
12831300

12841301

12851302
@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))
@@ -1305,15 +1322,15 @@ def test_tan(x):
13051322
out = xp.tan(x)
13061323
ph.assert_dtype("tan", x.dtype, out.dtype)
13071324
ph.assert_shape("tan", out.shape, x.shape)
1308-
# TODO
1325+
unary_assert_against_refimpl("tan", x, out, math.tan, "tan({})={}")
13091326

13101327

13111328
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
13121329
def test_tanh(x):
13131330
out = xp.tanh(x)
13141331
ph.assert_dtype("tanh", x.dtype, out.dtype)
13151332
ph.assert_shape("tanh", out.shape, x.shape)
1316-
# TODO
1333+
unary_assert_against_refimpl("tanh", x, out, math.tanh, "tanh({})={}")
13171334

13181335

13191336
@given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes()))

0 commit comments

Comments
 (0)