Skip to content

Commit 799b4e6

Browse files
committed
Refactor parametrized unary tests
Also moves `ah.int_to_dtype()` and renames it `mock_int_dtype()`
1 parent af6d150 commit 799b4e6

File tree

4 files changed

+69
-86
lines changed

4 files changed

+69
-86
lines changed

array_api_tests/array_helpers.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,3 @@ def same_sign(x, y):
306306
def assert_same_sign(x, y):
307307
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
308308

309-
def int_to_dtype(x, n, signed):
310-
"""
311-
Convert the Python integer x into an n bit signed or unsigned number.
312-
"""
313-
mask = (1 << n) - 1
314-
x &= mask
315-
if signed:
316-
highest_bit = 1 << (n-1)
317-
if x & highest_bit:
318-
x = -((~x & mask) + 1)
319-
return x
Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
from hypothesis import given, assume
2-
from hypothesis.strategies import integers
3-
4-
from ..array_helpers import exactly_equal, notequal, int_to_dtype
5-
from ..hypothesis_helpers import integer_dtypes
6-
from ..dtype_helpers import dtype_nbits, dtype_signed
71
from .. import _array_module as xp
2+
from ..array_helpers import exactly_equal, notequal
83

94
# TODO: These meta-tests currently only work with NumPy
105

@@ -22,12 +17,3 @@ def test_notequal():
2217
res = xp.asarray([False, True, False, False, False, True, False, True])
2318
assert xp.all(xp.equal(notequal(a, b), res))
2419

25-
@given(integers(), integer_dtypes)
26-
def test_int_to_dtype(x, dtype):
27-
n = dtype_nbits[dtype]
28-
signed = dtype_signed[dtype]
29-
try:
30-
d = xp.asarray(x, dtype=dtype)
31-
except OverflowError:
32-
assume(False)
33-
assert int_to_dtype(x, n, signed) == d

array_api_tests/meta/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import pytest
2+
from hypothesis import given, reject
3+
from hypothesis import strategies as st
24

5+
from .. import _array_module as xp
6+
from .. import xps
37
from .. import shape_helpers as sh
48
from ..test_creation_functions import frange
59
from ..test_manipulation_functions import roll_ndindex
10+
from ..test_operators_and_elementwise_functions import mock_int_dtype
611
from ..test_signatures import extension_module
712

813

@@ -101,3 +106,12 @@ def test_roll_ndindex(shape, shifts, axes, expected):
101106
)
102107
def test_fmt_idx(idx, expected):
103108
assert sh.fmt_idx("x", idx) == expected
109+
110+
111+
@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
112+
def test_int_to_dtype(x, dtype):
113+
try:
114+
d = xp.asarray(x, dtype=dtype)
115+
except OverflowError:
116+
reject()
117+
assert mock_int_dtype(x, dtype) == d

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 54 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
import math
13+
import operator
1314
from enum import Enum, auto
1415
from typing import Callable, List, NamedTuple, Optional, Union
1516

@@ -44,6 +45,18 @@ def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool:
4445
return math.isclose(n1, n2, rel_tol=0.25, abs_tol=1)
4546

4647

48+
def mock_int_dtype(n: int, dtype: DataType) -> int:
49+
"""Returns equivalent of `n` that mocks `dtype` behaviour"""
50+
nbits = dh.dtype_nbits[dtype]
51+
mask = (1 << nbits) - 1
52+
n &= mask
53+
if dh.dtype_signed[dtype]:
54+
highest_bit = 1 << (nbits - 1)
55+
if n & highest_bit:
56+
n = -((~n & mask) + 1)
57+
return n
58+
59+
4760
def unary_assert_against_refimpl(
4861
func_name: str,
4962
in_stype: ScalarType,
@@ -52,13 +65,16 @@ def unary_assert_against_refimpl(
5265
refimpl: Callable[[Scalar], Scalar],
5366
expr_template: str,
5467
res_stype: Optional[ScalarType] = None,
68+
ignorer: Callable[[Scalar], bool] = bool,
5569
):
5670
if in_.shape != res.shape:
5771
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
5872
if res_stype is None:
5973
res_stype = in_stype
6074
for idx in sh.ndindex(in_.shape):
6175
scalar_i = in_stype(in_[idx])
76+
if ignorer(scalar_i):
77+
continue
6278
expected = refimpl(scalar_i)
6379
scalar_o = res_stype(res[idx])
6480
f_i = sh.fmt_idx("x", idx)
@@ -299,25 +315,22 @@ def assert_binary_param_shape(
299315
@given(data=st.data())
300316
def test_abs(ctx, data):
301317
x = data.draw(ctx.strat, label="x")
318+
# abs of the smallest negative integer is out-of-scope
302319
if x.dtype in dh.int_dtypes:
303-
# abs of the smallest representable negative integer is not defined
304-
mask = xp.not_equal(
305-
x, ah.full(x.shape, dh.dtype_ranges[x.dtype].min, dtype=x.dtype)
306-
)
307-
x = x[mask]
320+
assume(xp.all(x > dh.dtype_ranges[x.dtype].min))
321+
308322
out = ctx.func(x)
323+
309324
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
310325
ph.assert_shape(ctx.func_name, out.shape, x.shape)
311-
assert ah.all(
312-
ah.logical_not(ah.negative_mathematical_sign(out))
313-
), f"out elements not all positively signed [{ctx.func_name}()]\n{out=}"
314-
less_zero = ah.negative_mathematical_sign(x)
315-
negx = ah.negative(x)
316-
# abs(x) = -x for x < 0
317-
ah.assert_exactly_equal(out[less_zero], negx[less_zero])
318-
# abs(x) = x for x >= 0
319-
ah.assert_exactly_equal(
320-
out[ah.logical_not(less_zero)], x[ah.logical_not(less_zero)]
326+
unary_assert_against_refimpl(
327+
ctx.func_name,
328+
dh.get_scalar_type(x.dtype),
329+
x,
330+
out,
331+
abs,
332+
"abs({})={}",
333+
ignorer=lambda s: math.isnan(s) or s is -0.0 or s == float("-infinity"),
321334
)
322335

323336

@@ -518,7 +531,7 @@ def test_bitwise_and(ctx, data):
518531
# for mypy
519532
assert isinstance(scalar_l, int)
520533
assert isinstance(right, int)
521-
expected = ah.int_to_dtype(
534+
expected = ah.mock_int_dtype(
522535
scalar_l & right,
523536
dh.dtype_nbits[res.dtype],
524537
dh.dtype_signed[res.dtype],
@@ -540,7 +553,7 @@ def test_bitwise_and(ctx, data):
540553
# for mypy
541554
assert isinstance(scalar_l, int)
542555
assert isinstance(scalar_r, int)
543-
expected = ah.int_to_dtype(
556+
expected = ah.mock_int_dtype(
544557
scalar_l & scalar_r,
545558
dh.dtype_nbits[res.dtype],
546559
dh.dtype_signed[res.dtype],
@@ -574,7 +587,7 @@ def test_bitwise_left_shift(ctx, data):
574587
if ctx.right_is_scalar:
575588
for idx in sh.ndindex(res.shape):
576589
scalar_l = int(left[idx])
577-
expected = ah.int_to_dtype(
590+
expected = ah.mock_int_dtype(
578591
# We avoid shifting very large ints
579592
scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0,
580593
dh.dtype_nbits[res.dtype],
@@ -591,7 +604,7 @@ def test_bitwise_left_shift(ctx, data):
591604
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
592605
scalar_l = int(left[l_idx])
593606
scalar_r = int(right[r_idx])
594-
expected = ah.int_to_dtype(
607+
expected = ah.mock_int_dtype(
595608
# We avoid shifting very large ints
596609
scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0,
597610
dh.dtype_nbits[res.dtype],
@@ -608,8 +621,7 @@ def test_bitwise_left_shift(ctx, data):
608621

609622

610623
@pytest.mark.parametrize(
611-
"ctx",
612-
make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes()),
624+
"ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes())
613625
)
614626
@given(data=st.data())
615627
def test_bitwise_invert(ctx, data):
@@ -619,23 +631,14 @@ def test_bitwise_invert(ctx, data):
619631

620632
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
621633
ph.assert_shape(ctx.func_name, out.shape, x.shape)
622-
for idx in sh.ndindex(out.shape):
623-
if out.dtype == xp.bool:
624-
scalar_x = bool(x[idx])
625-
scalar_o = bool(out[idx])
626-
expected = not scalar_x
627-
else:
628-
scalar_x = int(x[idx])
629-
scalar_o = int(out[idx])
630-
expected = ah.int_to_dtype(
631-
~scalar_x, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]
632-
)
633-
f_x = sh.fmt_idx("x", idx)
634-
f_o = sh.fmt_idx("out", idx)
635-
assert scalar_o == expected, (
636-
f"{f_o}={scalar_o}, but should be ~{f_x}={scalar_x} "
637-
f"[{ctx.func_name}()]\n{f_x}={scalar_x}"
638-
)
634+
if x.dtype == xp.bool:
635+
# invert op for booleans is weird, so use not
636+
refimpl = lambda s: not s
637+
else:
638+
refimpl = lambda s: mock_int_dtype(~s, x.dtype)
639+
unary_assert_against_refimpl(
640+
ctx.func_name, dh.get_scalar_type(x.dtype), x, out, refimpl, "~{}={}"
641+
)
639642

640643

641644
@pytest.mark.parametrize(
@@ -659,7 +662,7 @@ def test_bitwise_or(ctx, data):
659662
else:
660663
scalar_l = int(left[idx])
661664
scalar_o = int(res[idx])
662-
expected = ah.int_to_dtype(
665+
expected = ah.mock_int_dtype(
663666
scalar_l | right,
664667
dh.dtype_nbits[res.dtype],
665668
dh.dtype_signed[res.dtype],
@@ -681,7 +684,7 @@ def test_bitwise_or(ctx, data):
681684
scalar_l = int(left[l_idx])
682685
scalar_r = int(right[r_idx])
683686
scalar_o = int(res[o_idx])
684-
expected = ah.int_to_dtype(
687+
expected = ah.mock_int_dtype(
685688
scalar_l | scalar_r,
686689
dh.dtype_nbits[res.dtype],
687690
dh.dtype_signed[res.dtype],
@@ -714,7 +717,7 @@ def test_bitwise_right_shift(ctx, data):
714717
if ctx.right_is_scalar:
715718
for idx in sh.ndindex(res.shape):
716719
scalar_l = int(left[idx])
717-
expected = ah.int_to_dtype(
720+
expected = ah.mock_int_dtype(
718721
scalar_l >> right,
719722
dh.dtype_nbits[res.dtype],
720723
dh.dtype_signed[res.dtype],
@@ -730,7 +733,7 @@ def test_bitwise_right_shift(ctx, data):
730733
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
731734
scalar_l = int(left[l_idx])
732735
scalar_r = int(right[r_idx])
733-
expected = ah.int_to_dtype(
736+
expected = ah.mock_int_dtype(
734737
scalar_l >> scalar_r,
735738
dh.dtype_nbits[res.dtype],
736739
dh.dtype_signed[res.dtype],
@@ -766,7 +769,7 @@ def test_bitwise_xor(ctx, data):
766769
else:
767770
scalar_l = int(left[idx])
768771
scalar_o = int(res[idx])
769-
expected = ah.int_to_dtype(
772+
expected = ah.mock_int_dtype(
770773
scalar_l ^ right,
771774
dh.dtype_nbits[res.dtype],
772775
dh.dtype_signed[res.dtype],
@@ -788,7 +791,7 @@ def test_bitwise_xor(ctx, data):
788791
scalar_l = int(left[l_idx])
789792
scalar_r = int(right[r_idx])
790793
scalar_o = int(res[o_idx])
791-
expected = ah.int_to_dtype(
794+
expected = ah.mock_int_dtype(
792795
scalar_l ^ scalar_r,
793796
dh.dtype_nbits[res.dtype],
794797
dh.dtype_signed[res.dtype],
@@ -1366,25 +1369,17 @@ def test_multiply(ctx, data):
13661369
@given(data=st.data())
13671370
def test_negative(ctx, data):
13681371
x = data.draw(ctx.strat, label="x")
1372+
# negative of the smallest negative integer is out-of-scope
1373+
if x.dtype in dh.int_dtypes:
1374+
assume(xp.all(x > dh.dtype_ranges[x.dtype].min))
13691375

13701376
out = ctx.func(x)
13711377

13721378
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
13731379
ph.assert_shape(ctx.func_name, out.shape, x.shape)
1374-
1375-
# Negation is an involution
1376-
ah.assert_exactly_equal(x, ctx.func(out))
1377-
1378-
mask = ah.isfinite(x)
1379-
if dh.is_int_dtype(x.dtype):
1380-
minval = dh.dtype_ranges[x.dtype][0]
1381-
if minval < 0:
1382-
# negative of the smallest representable negative integer is not defined
1383-
mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype))
1384-
1385-
# Additive inverse
1386-
y = xp.add(x[mask], out[mask])
1387-
ah.assert_exactly_equal(y, ah.zero(x[mask].shape, x.dtype))
1380+
unary_assert_against_refimpl(
1381+
ctx.func_name, dh.get_scalar_type(x.dtype), x, out, operator.neg, "-({})={}"
1382+
)
13881383

13891384

13901385
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))
@@ -1438,8 +1433,7 @@ def test_positive(ctx, data):
14381433

14391434
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
14401435
ph.assert_shape(ctx.func_name, out.shape, x.shape)
1441-
# Positive does nothing
1442-
ah.assert_exactly_equal(out, x)
1436+
ph.assert_array(ctx.func_name, out, x)
14431437

14441438

14451439
@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes()))

0 commit comments

Comments
 (0)