Skip to content

Commit e2b69df

Browse files
committed
Op/elwise fixes and improvements
- Fix old usage of `mock_int_dtype` - Infer `in_stype` - Allow scalar `right` for `binary_assert_against_refimpl()` - Use util in `test_add`
1 parent 799b4e6 commit e2b69df

File tree

1 file changed

+107
-110
lines changed

1 file changed

+107
-110
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 107 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,18 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
5959

6060
def unary_assert_against_refimpl(
6161
func_name: str,
62-
in_stype: ScalarType,
6362
in_: Array,
6463
res: Array,
6564
refimpl: Callable[[Scalar], Scalar],
6665
expr_template: str,
66+
in_stype: Optional[ScalarType] = None,
6767
res_stype: Optional[ScalarType] = None,
6868
ignorer: Callable[[Scalar], bool] = bool,
6969
):
7070
if in_.shape != res.shape:
7171
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
72+
if in_stype is None:
73+
in_stype = dh.get_scalar_type(in_.dtype)
7274
if res_stype is None:
7375
res_stype = in_stype
7476
for idx in sh.ndindex(in_.shape):
@@ -88,32 +90,77 @@ def unary_assert_against_refimpl(
8890

8991
def binary_assert_against_refimpl(
9092
func_name: str,
91-
in_stype: ScalarType,
9293
left: Array,
93-
right: Array,
94+
right: Union[Scalar, Array],
9495
res: Array,
9596
refimpl: Callable[[Scalar, Scalar], Scalar],
9697
expr_template: str,
98+
in_stype: Optional[ScalarType] = None,
9799
res_stype: Optional[ScalarType] = None,
98100
left_sym: str = "x1",
99101
right_sym: str = "x2",
100-
res_sym: str = "out",
102+
right_is_scalar: bool = False,
103+
res_name: str = "out",
104+
ignorer: Callable[[Scalar, Scalar], bool] = bool,
101105
):
106+
if in_stype is None:
107+
in_stype = dh.get_scalar_type(left.dtype)
102108
if res_stype is None:
103109
res_stype = in_stype
104-
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
105-
scalar_l = in_stype(left[l_idx])
106-
scalar_r = in_stype(right[r_idx])
107-
expected = refimpl(scalar_l, scalar_r)
108-
scalar_o = res_stype(res[o_idx])
109-
f_l = sh.fmt_idx(left_sym, l_idx)
110-
f_r = sh.fmt_idx(right_sym, r_idx)
111-
f_o = sh.fmt_idx(res_sym, o_idx)
112-
expr = expr_template.format(scalar_l, scalar_r, expected)
113-
assert scalar_o == expected, (
114-
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
115-
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
116-
)
110+
if right_is_scalar:
111+
if left.dtype != xp.bool:
112+
m, M = dh.dtype_ranges[left.dtype]
113+
for idx in sh.ndindex(res.shape):
114+
scalar_l = in_stype(left[idx])
115+
if any(ignorer(s) for s in [scalar_l, right]):
116+
continue
117+
expected = refimpl(scalar_l, right)
118+
if left.dtype != xp.bool:
119+
if expected <= m or expected >= M:
120+
continue
121+
scalar_o = res_stype(res[idx])
122+
f_l = sh.fmt_idx(left_sym, idx)
123+
f_o = sh.fmt_idx(res_name, idx)
124+
expr = expr_template.format(scalar_l, right, expected)
125+
if dh.is_float_dtype(left.dtype):
126+
assert isclose(scalar_o, expected), (
127+
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
128+
f"{f_l}={scalar_l}"
129+
)
130+
131+
else:
132+
assert scalar_o == expected, (
133+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
134+
f"{f_l}={scalar_l}"
135+
)
136+
else:
137+
result_dtype = dh.result_type(left.dtype, right.dtype)
138+
if result_dtype != xp.bool:
139+
m, M = dh.dtype_ranges[result_dtype]
140+
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
141+
scalar_l = in_stype(left[l_idx])
142+
scalar_r = in_stype(right[r_idx])
143+
if any(ignorer(s) for s in [scalar_l, scalar_r]):
144+
continue
145+
expected = refimpl(scalar_l, scalar_r)
146+
if result_dtype != xp.bool:
147+
if expected <= m or expected >= M:
148+
continue
149+
scalar_o = res_stype(res[o_idx])
150+
f_l = sh.fmt_idx(left_sym, l_idx)
151+
f_r = sh.fmt_idx(right_sym, r_idx)
152+
f_o = sh.fmt_idx(res_name, o_idx)
153+
expr = expr_template.format(scalar_l, scalar_r, expected)
154+
if dh.is_float_dtype(result_dtype):
155+
assert isclose(scalar_o, expected), (
156+
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
157+
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
158+
)
159+
else:
160+
assert scalar_o == expected, (
161+
f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
162+
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
163+
)
117164

118165

119166
# When appropiate, this module tests operators alongside their respective
@@ -325,7 +372,6 @@ def test_abs(ctx, data):
325372
ph.assert_shape(ctx.func_name, out.shape, x.shape)
326373
unary_assert_against_refimpl(
327374
ctx.func_name,
328-
dh.get_scalar_type(x.dtype),
329375
x,
330376
out,
331377
abs,
@@ -379,37 +425,34 @@ def test_add(ctx, data):
379425

380426
assert_binary_param_dtype(ctx, left, right, res)
381427
assert_binary_param_shape(ctx, left, right, res)
382-
m, M = dh.dtype_ranges[res.dtype]
383-
scalar_type = dh.get_scalar_type(res.dtype)
384428
if ctx.right_is_scalar:
385-
for idx in sh.ndindex(res.shape):
386-
scalar_l = scalar_type(left[idx])
387-
expected = scalar_l + right
388-
if not math.isfinite(expected) or expected <= m or expected >= M:
389-
continue
390-
scalar_o = scalar_type(res[idx])
391-
f_l = sh.fmt_idx(ctx.left_sym, idx)
392-
f_o = sh.fmt_idx(ctx.res_name, idx)
393-
assert isclose(scalar_o, expected), (
394-
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {right})={expected} "
395-
f"[{ctx.func_name}()]\n{f_l}={scalar_l}"
396-
)
429+
binary_assert_against_refimpl(
430+
func_name=ctx.func_name,
431+
left_sym=ctx.left_sym,
432+
left=left,
433+
right_sym=ctx.right_sym,
434+
right=right,
435+
right_is_scalar=True,
436+
res_name=ctx.res_name,
437+
res=res,
438+
refimpl=operator.add,
439+
expr_template="({} + {})={}",
440+
ignorer=lambda s: not math.isfinite(s),
441+
)
397442
else:
398443
ph.assert_array(ctx.func_name, res, ctx.func(right, left)) # cumulative
399-
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
400-
scalar_l = scalar_type(left[l_idx])
401-
scalar_r = scalar_type(right[r_idx])
402-
expected = scalar_l + scalar_r
403-
if not math.isfinite(expected) or expected <= m or expected >= M:
404-
continue
405-
scalar_o = scalar_type(res[o_idx])
406-
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
407-
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
408-
f_o = sh.fmt_idx(ctx.res_name, o_idx)
409-
assert isclose(scalar_o, expected), (
410-
f"{f_o}={scalar_o}, but should be roughly ({f_l} + {f_r})={expected} "
411-
f"[{ctx.func_name}()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
412-
)
444+
binary_assert_against_refimpl(
445+
func_name=ctx.func_name,
446+
left_sym=ctx.left_sym,
447+
left=left,
448+
right_sym=ctx.right_sym,
449+
right=right,
450+
res_name=ctx.res_name,
451+
res=res,
452+
refimpl=operator.add,
453+
expr_template="({} + {})={}",
454+
ignorer=lambda s: not math.isfinite(s),
455+
)
413456

414457

415458
@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
@@ -531,11 +574,7 @@ def test_bitwise_and(ctx, data):
531574
# for mypy
532575
assert isinstance(scalar_l, int)
533576
assert isinstance(right, int)
534-
expected = ah.mock_int_dtype(
535-
scalar_l & right,
536-
dh.dtype_nbits[res.dtype],
537-
dh.dtype_signed[res.dtype],
538-
)
577+
expected = mock_int_dtype(scalar_l & right, res.dtype)
539578
scalar_o = scalar_type(res[idx])
540579
f_l = sh.fmt_idx(ctx.left_sym, idx)
541580
f_o = sh.fmt_idx(ctx.res_name, idx)
@@ -553,11 +592,7 @@ def test_bitwise_and(ctx, data):
553592
# for mypy
554593
assert isinstance(scalar_l, int)
555594
assert isinstance(scalar_r, int)
556-
expected = ah.mock_int_dtype(
557-
scalar_l & scalar_r,
558-
dh.dtype_nbits[res.dtype],
559-
dh.dtype_signed[res.dtype],
560-
)
595+
expected = mock_int_dtype(scalar_l & scalar_r, res.dtype)
561596
scalar_o = scalar_type(res[o_idx])
562597
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
563598
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
@@ -587,11 +622,10 @@ def test_bitwise_left_shift(ctx, data):
587622
if ctx.right_is_scalar:
588623
for idx in sh.ndindex(res.shape):
589624
scalar_l = int(left[idx])
590-
expected = ah.mock_int_dtype(
625+
expected = mock_int_dtype(
591626
# We avoid shifting very large ints
592627
scalar_l << right if right < dh.dtype_nbits[res.dtype] else 0,
593-
dh.dtype_nbits[res.dtype],
594-
dh.dtype_signed[res.dtype],
628+
res.dtype,
595629
)
596630
scalar_o = int(res[idx])
597631
f_l = sh.fmt_idx(ctx.left_sym, idx)
@@ -604,11 +638,10 @@ def test_bitwise_left_shift(ctx, data):
604638
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
605639
scalar_l = int(left[l_idx])
606640
scalar_r = int(right[r_idx])
607-
expected = ah.mock_int_dtype(
641+
expected = mock_int_dtype(
608642
# We avoid shifting very large ints
609643
scalar_l << scalar_r if scalar_r < dh.dtype_nbits[res.dtype] else 0,
610-
dh.dtype_nbits[res.dtype],
611-
dh.dtype_signed[res.dtype],
644+
res.dtype,
612645
)
613646
scalar_o = int(res[o_idx])
614647
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
@@ -636,9 +669,7 @@ def test_bitwise_invert(ctx, data):
636669
refimpl = lambda s: not s
637670
else:
638671
refimpl = lambda s: mock_int_dtype(~s, x.dtype)
639-
unary_assert_against_refimpl(
640-
ctx.func_name, dh.get_scalar_type(x.dtype), x, out, refimpl, "~{}={}"
641-
)
672+
unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, "~{}={}")
642673

643674

644675
@pytest.mark.parametrize(
@@ -662,11 +693,7 @@ def test_bitwise_or(ctx, data):
662693
else:
663694
scalar_l = int(left[idx])
664695
scalar_o = int(res[idx])
665-
expected = ah.mock_int_dtype(
666-
scalar_l | right,
667-
dh.dtype_nbits[res.dtype],
668-
dh.dtype_signed[res.dtype],
669-
)
696+
expected = mock_int_dtype(scalar_l | right, res.dtype)
670697
f_l = sh.fmt_idx(ctx.left_sym, idx)
671698
f_o = sh.fmt_idx(ctx.res_name, idx)
672699
assert scalar_o == expected, (
@@ -684,11 +711,7 @@ def test_bitwise_or(ctx, data):
684711
scalar_l = int(left[l_idx])
685712
scalar_r = int(right[r_idx])
686713
scalar_o = int(res[o_idx])
687-
expected = ah.mock_int_dtype(
688-
scalar_l | scalar_r,
689-
dh.dtype_nbits[res.dtype],
690-
dh.dtype_signed[res.dtype],
691-
)
714+
expected = mock_int_dtype(scalar_l | scalar_r, res.dtype)
692715
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
693716
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
694717
f_o = sh.fmt_idx(ctx.res_name, o_idx)
@@ -717,11 +740,7 @@ def test_bitwise_right_shift(ctx, data):
717740
if ctx.right_is_scalar:
718741
for idx in sh.ndindex(res.shape):
719742
scalar_l = int(left[idx])
720-
expected = ah.mock_int_dtype(
721-
scalar_l >> right,
722-
dh.dtype_nbits[res.dtype],
723-
dh.dtype_signed[res.dtype],
724-
)
743+
expected = mock_int_dtype(scalar_l >> right, res.dtype)
725744
scalar_o = int(res[idx])
726745
f_l = sh.fmt_idx(ctx.left_sym, idx)
727746
f_o = sh.fmt_idx(ctx.res_name, idx)
@@ -733,11 +752,7 @@ def test_bitwise_right_shift(ctx, data):
733752
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
734753
scalar_l = int(left[l_idx])
735754
scalar_r = int(right[r_idx])
736-
expected = ah.mock_int_dtype(
737-
scalar_l >> scalar_r,
738-
dh.dtype_nbits[res.dtype],
739-
dh.dtype_signed[res.dtype],
740-
)
755+
expected = mock_int_dtype(scalar_l >> scalar_r, res.dtype)
741756
scalar_o = int(res[o_idx])
742757
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
743758
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
@@ -769,11 +784,7 @@ def test_bitwise_xor(ctx, data):
769784
else:
770785
scalar_l = int(left[idx])
771786
scalar_o = int(res[idx])
772-
expected = ah.mock_int_dtype(
773-
scalar_l ^ right,
774-
dh.dtype_nbits[res.dtype],
775-
dh.dtype_signed[res.dtype],
776-
)
787+
expected = mock_int_dtype(scalar_l ^ right, res.dtype)
777788
f_l = sh.fmt_idx(ctx.left_sym, idx)
778789
f_o = sh.fmt_idx(ctx.res_name, idx)
779790
assert scalar_o == expected, (
@@ -791,11 +802,7 @@ def test_bitwise_xor(ctx, data):
791802
scalar_l = int(left[l_idx])
792803
scalar_r = int(right[r_idx])
793804
scalar_o = int(res[o_idx])
794-
expected = ah.mock_int_dtype(
795-
scalar_l ^ scalar_r,
796-
dh.dtype_nbits[res.dtype],
797-
dh.dtype_signed[res.dtype],
798-
)
805+
expected = mock_int_dtype(scalar_l ^ scalar_r, res.dtype)
799806
f_l = sh.fmt_idx(ctx.left_sym, l_idx)
800807
f_r = sh.fmt_idx(ctx.right_sym, r_idx)
801808
f_o = sh.fmt_idx(ctx.res_name, o_idx)
@@ -1309,13 +1316,7 @@ def test_logical_and(x1, x2):
13091316
ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype)
13101317
ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape)
13111318
binary_assert_against_refimpl(
1312-
"logical_and",
1313-
bool,
1314-
x1,
1315-
x2,
1316-
out,
1317-
lambda l, r: l and r,
1318-
"({} and {})={}",
1319+
"logical_and", x1, x2, out, lambda l, r: l and r, "({} and {})={}"
13191320
)
13201321

13211322

@@ -1324,9 +1325,7 @@ def test_logical_not(x):
13241325
out = ah.logical_not(x)
13251326
ph.assert_dtype("logical_not", x.dtype, out.dtype)
13261327
ph.assert_shape("logical_not", out.shape, x.shape)
1327-
unary_assert_against_refimpl(
1328-
"logical_not", bool, x, out, lambda i: not i, "(not {})={}"
1329-
)
1328+
unary_assert_against_refimpl("logical_not", x, out, lambda i: not i, "(not {})={}")
13301329

13311330

13321331
@given(*hh.two_mutual_arrays([xp.bool]))
@@ -1335,7 +1334,7 @@ def test_logical_or(x1, x2):
13351334
ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype)
13361335
ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape)
13371336
binary_assert_against_refimpl(
1338-
"logical_or", bool, x1, x2, out, lambda l, r: l or r, "({} or {})={}"
1337+
"logical_or", x1, x2, out, lambda l, r: l or r, "({} or {})={}"
13391338
)
13401339

13411340

@@ -1345,7 +1344,7 @@ def test_logical_xor(x1, x2):
13451344
ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype)
13461345
ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape)
13471346
binary_assert_against_refimpl(
1348-
"logical_xor", bool, x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}"
1347+
"logical_xor", x1, x2, out, lambda l, r: l ^ r, "({} ^ {})={}"
13491348
)
13501349

13511350

@@ -1377,9 +1376,7 @@ def test_negative(ctx, data):
13771376

13781377
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
13791378
ph.assert_shape(ctx.func_name, out.shape, x.shape)
1380-
unary_assert_against_refimpl(
1381-
ctx.func_name, dh.get_scalar_type(x.dtype), x, out, operator.neg, "-({})={}"
1382-
)
1379+
unary_assert_against_refimpl(ctx.func_name, x, out, operator.neg, "-({})={}")
13831380

13841381

13851382
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))

0 commit comments

Comments
 (0)