Skip to content

Commit 9edcfcc

Browse files
committed
Favour use of operator for refimpl
1 parent e72184e commit 9edcfcc

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def test_bitwise_and(ctx, data):
561561
binary_param_assert_dtype(ctx, left, right, res)
562562
binary_param_assert_shape(ctx, left, right, res)
563563
if left.dtype == xp.bool:
564-
refimpl = lambda l, r: l and r
564+
refimpl = operator.and_
565565
else:
566566
refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype)
567567
binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl)
@@ -583,15 +583,9 @@ def test_bitwise_left_shift(ctx, data):
583583

584584
binary_param_assert_dtype(ctx, left, right, res)
585585
binary_param_assert_shape(ctx, left, right, res)
586+
nbits = res.dtype
586587
binary_param_assert_against_refimpl(
587-
ctx,
588-
left,
589-
right,
590-
res,
591-
"<<",
592-
lambda l, r: (
593-
mock_int_dtype(l << r, res.dtype) if r < dh.dtype_nbits[res.dtype] else 0
594-
),
588+
ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0
595589
)
596590

597591

@@ -607,7 +601,7 @@ def test_bitwise_invert(ctx, data):
607601
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
608602
ph.assert_shape(ctx.func_name, out.shape, x.shape)
609603
if x.dtype == xp.bool:
610-
refimpl = lambda s: not s
604+
refimpl = operator.not_
611605
else:
612606
refimpl = lambda s: mock_int_dtype(~s, x.dtype)
613607
unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}")
@@ -626,7 +620,7 @@ def test_bitwise_or(ctx, data):
626620
binary_param_assert_dtype(ctx, left, right, res)
627621
binary_param_assert_shape(ctx, left, right, res)
628622
if left.dtype == xp.bool:
629-
refimpl = lambda l, r: l or r
623+
refimpl = operator.or_
630624
else:
631625
refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype)
632626
binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl)
@@ -649,12 +643,7 @@ def test_bitwise_right_shift(ctx, data):
649643
binary_param_assert_dtype(ctx, left, right, res)
650644
binary_param_assert_shape(ctx, left, right, res)
651645
binary_param_assert_against_refimpl(
652-
ctx,
653-
left,
654-
right,
655-
res,
656-
">>",
657-
lambda l, r: mock_int_dtype(l >> r, res.dtype),
646+
ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype)
658647
)
659648

660649

@@ -943,14 +932,16 @@ def test_log10(x):
943932
)
944933

945934

935+
def logaddexp(l: float, r: float) -> float:
936+
return math.log(math.exp(l) + math.exp(r))
937+
938+
946939
@given(*hh.two_mutual_arrays(dh.float_dtypes))
947940
def test_logaddexp(x1, x2):
948941
out = xp.logaddexp(x1, x2)
949942
ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype)
950943
ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape)
951-
binary_assert_against_refimpl(
952-
"logaddexp", x1, x2, out, lambda l, r: math.log(math.exp(l) + math.exp(r))
953-
)
944+
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp)
954945

955946

956947
@given(*hh.two_mutual_arrays([xp.bool]))
@@ -959,7 +950,7 @@ def test_logical_and(x1, x2):
959950
ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype)
960951
ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape)
961952
binary_assert_against_refimpl(
962-
"logical_and", x1, x2, out, lambda l, r: l and r, expr_template="({} and {})={}"
953+
"logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}"
963954
)
964955

965956

@@ -969,7 +960,7 @@ def test_logical_not(x):
969960
ph.assert_dtype("logical_not", x.dtype, out.dtype)
970961
ph.assert_shape("logical_not", out.shape, x.shape)
971962
unary_assert_against_refimpl(
972-
"logical_not", x, out, lambda i: not i, expr_template="(not {})={}"
963+
"logical_not", x, out, operator.not_, expr_template="(not {})={}"
973964
)
974965

975966

@@ -979,7 +970,7 @@ def test_logical_or(x1, x2):
979970
ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype)
980971
ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape)
981972
binary_assert_against_refimpl(
982-
"logical_or", x1, x2, out, lambda l, r: l or r, expr_template="({} or {})={}"
973+
"logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}"
983974
)
984975

985976

0 commit comments

Comments
 (0)