10
10
"""
11
11
12
12
import math
13
+ import operator
13
14
from enum import Enum , auto
14
15
from typing import Callable , List , NamedTuple , Optional , Union
15
16
@@ -44,6 +45,18 @@ def isclose(n1: Union[int, float], n2: Union[int, float]) -> bool:
44
45
return math .isclose (n1 , n2 , rel_tol = 0.25 , abs_tol = 1 )
45
46
46
47
48
+ def mock_int_dtype (n : int , dtype : DataType ) -> int :
49
+ """Returns equivalent of `n` that mocks `dtype` behaviour"""
50
+ nbits = dh .dtype_nbits [dtype ]
51
+ mask = (1 << nbits ) - 1
52
+ n &= mask
53
+ if dh .dtype_signed [dtype ]:
54
+ highest_bit = 1 << (nbits - 1 )
55
+ if n & highest_bit :
56
+ n = - ((~ n & mask ) + 1 )
57
+ return n
58
+
59
+
47
60
def unary_assert_against_refimpl (
48
61
func_name : str ,
49
62
in_stype : ScalarType ,
@@ -52,13 +65,16 @@ def unary_assert_against_refimpl(
52
65
refimpl : Callable [[Scalar ], Scalar ],
53
66
expr_template : str ,
54
67
res_stype : Optional [ScalarType ] = None ,
68
+ ignorer : Callable [[Scalar ], bool ] = bool ,
55
69
):
56
70
if in_ .shape != res .shape :
57
71
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
58
72
if res_stype is None :
59
73
res_stype = in_stype
60
74
for idx in sh .ndindex (in_ .shape ):
61
75
scalar_i = in_stype (in_ [idx ])
76
+ if ignorer (scalar_i ):
77
+ continue
62
78
expected = refimpl (scalar_i )
63
79
scalar_o = res_stype (res [idx ])
64
80
f_i = sh .fmt_idx ("x" , idx )
@@ -299,25 +315,22 @@ def assert_binary_param_shape(
299
315
@given (data = st .data ())
300
316
def test_abs (ctx , data ):
301
317
x = data .draw (ctx .strat , label = "x" )
318
+ # abs of the smallest negative integer is out-of-scope
302
319
if x .dtype in dh .int_dtypes :
303
- # abs of the smallest representable negative integer is not defined
304
- mask = xp .not_equal (
305
- x , ah .full (x .shape , dh .dtype_ranges [x .dtype ].min , dtype = x .dtype )
306
- )
307
- x = x [mask ]
320
+ assume (xp .all (x > dh .dtype_ranges [x .dtype ].min ))
321
+
308
322
out = ctx .func (x )
323
+
309
324
ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
310
325
ph .assert_shape (ctx .func_name , out .shape , x .shape )
311
- assert ah .all (
312
- ah .logical_not (ah .negative_mathematical_sign (out ))
313
- ), f"out elements not all positively signed [{ ctx .func_name } ()]\n { out = } "
314
- less_zero = ah .negative_mathematical_sign (x )
315
- negx = ah .negative (x )
316
- # abs(x) = -x for x < 0
317
- ah .assert_exactly_equal (out [less_zero ], negx [less_zero ])
318
- # abs(x) = x for x >= 0
319
- ah .assert_exactly_equal (
320
- out [ah .logical_not (less_zero )], x [ah .logical_not (less_zero )]
326
+ unary_assert_against_refimpl (
327
+ ctx .func_name ,
328
+ dh .get_scalar_type (x .dtype ),
329
+ x ,
330
+ out ,
331
+ abs ,
332
+ "abs({})={}" ,
333
+ ignorer = lambda s : math .isnan (s ) or s is - 0.0 or s == float ("-infinity" ),
321
334
)
322
335
323
336
@@ -518,7 +531,7 @@ def test_bitwise_and(ctx, data):
518
531
# for mypy
519
532
assert isinstance (scalar_l , int )
520
533
assert isinstance (right , int )
521
- expected = ah .int_to_dtype (
534
+ expected = ah .mock_int_dtype (
522
535
scalar_l & right ,
523
536
dh .dtype_nbits [res .dtype ],
524
537
dh .dtype_signed [res .dtype ],
@@ -540,7 +553,7 @@ def test_bitwise_and(ctx, data):
540
553
# for mypy
541
554
assert isinstance (scalar_l , int )
542
555
assert isinstance (scalar_r , int )
543
- expected = ah .int_to_dtype (
556
+ expected = ah .mock_int_dtype (
544
557
scalar_l & scalar_r ,
545
558
dh .dtype_nbits [res .dtype ],
546
559
dh .dtype_signed [res .dtype ],
@@ -574,7 +587,7 @@ def test_bitwise_left_shift(ctx, data):
574
587
if ctx .right_is_scalar :
575
588
for idx in sh .ndindex (res .shape ):
576
589
scalar_l = int (left [idx ])
577
- expected = ah .int_to_dtype (
590
+ expected = ah .mock_int_dtype (
578
591
# We avoid shifting very large ints
579
592
scalar_l << right if right < dh .dtype_nbits [res .dtype ] else 0 ,
580
593
dh .dtype_nbits [res .dtype ],
@@ -591,7 +604,7 @@ def test_bitwise_left_shift(ctx, data):
591
604
for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
592
605
scalar_l = int (left [l_idx ])
593
606
scalar_r = int (right [r_idx ])
594
- expected = ah .int_to_dtype (
607
+ expected = ah .mock_int_dtype (
595
608
# We avoid shifting very large ints
596
609
scalar_l << scalar_r if scalar_r < dh .dtype_nbits [res .dtype ] else 0 ,
597
610
dh .dtype_nbits [res .dtype ],
@@ -608,8 +621,7 @@ def test_bitwise_left_shift(ctx, data):
608
621
609
622
610
623
@pytest .mark .parametrize (
611
- "ctx" ,
612
- make_unary_params ("bitwise_invert" , boolean_and_all_integer_dtypes ()),
624
+ "ctx" , make_unary_params ("bitwise_invert" , boolean_and_all_integer_dtypes ())
613
625
)
614
626
@given (data = st .data ())
615
627
def test_bitwise_invert (ctx , data ):
@@ -619,23 +631,14 @@ def test_bitwise_invert(ctx, data):
619
631
620
632
ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
621
633
ph .assert_shape (ctx .func_name , out .shape , x .shape )
622
- for idx in sh .ndindex (out .shape ):
623
- if out .dtype == xp .bool :
624
- scalar_x = bool (x [idx ])
625
- scalar_o = bool (out [idx ])
626
- expected = not scalar_x
627
- else :
628
- scalar_x = int (x [idx ])
629
- scalar_o = int (out [idx ])
630
- expected = ah .int_to_dtype (
631
- ~ scalar_x , dh .dtype_nbits [out .dtype ], dh .dtype_signed [out .dtype ]
632
- )
633
- f_x = sh .fmt_idx ("x" , idx )
634
- f_o = sh .fmt_idx ("out" , idx )
635
- assert scalar_o == expected , (
636
- f"{ f_o } ={ scalar_o } , but should be ~{ f_x } ={ scalar_x } "
637
- f"[{ ctx .func_name } ()]\n { f_x } ={ scalar_x } "
638
- )
634
+ if x .dtype == xp .bool :
635
+ # invert op for booleans is weird, so use not
636
+ refimpl = lambda s : not s
637
+ else :
638
+ refimpl = lambda s : mock_int_dtype (~ s , x .dtype )
639
+ unary_assert_against_refimpl (
640
+ ctx .func_name , dh .get_scalar_type (x .dtype ), x , out , refimpl , "~{}={}"
641
+ )
639
642
640
643
641
644
@pytest .mark .parametrize (
@@ -659,7 +662,7 @@ def test_bitwise_or(ctx, data):
659
662
else :
660
663
scalar_l = int (left [idx ])
661
664
scalar_o = int (res [idx ])
662
- expected = ah .int_to_dtype (
665
+ expected = ah .mock_int_dtype (
663
666
scalar_l | right ,
664
667
dh .dtype_nbits [res .dtype ],
665
668
dh .dtype_signed [res .dtype ],
@@ -681,7 +684,7 @@ def test_bitwise_or(ctx, data):
681
684
scalar_l = int (left [l_idx ])
682
685
scalar_r = int (right [r_idx ])
683
686
scalar_o = int (res [o_idx ])
684
- expected = ah .int_to_dtype (
687
+ expected = ah .mock_int_dtype (
685
688
scalar_l | scalar_r ,
686
689
dh .dtype_nbits [res .dtype ],
687
690
dh .dtype_signed [res .dtype ],
@@ -714,7 +717,7 @@ def test_bitwise_right_shift(ctx, data):
714
717
if ctx .right_is_scalar :
715
718
for idx in sh .ndindex (res .shape ):
716
719
scalar_l = int (left [idx ])
717
- expected = ah .int_to_dtype (
720
+ expected = ah .mock_int_dtype (
718
721
scalar_l >> right ,
719
722
dh .dtype_nbits [res .dtype ],
720
723
dh .dtype_signed [res .dtype ],
@@ -730,7 +733,7 @@ def test_bitwise_right_shift(ctx, data):
730
733
for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
731
734
scalar_l = int (left [l_idx ])
732
735
scalar_r = int (right [r_idx ])
733
- expected = ah .int_to_dtype (
736
+ expected = ah .mock_int_dtype (
734
737
scalar_l >> scalar_r ,
735
738
dh .dtype_nbits [res .dtype ],
736
739
dh .dtype_signed [res .dtype ],
@@ -766,7 +769,7 @@ def test_bitwise_xor(ctx, data):
766
769
else :
767
770
scalar_l = int (left [idx ])
768
771
scalar_o = int (res [idx ])
769
- expected = ah .int_to_dtype (
772
+ expected = ah .mock_int_dtype (
770
773
scalar_l ^ right ,
771
774
dh .dtype_nbits [res .dtype ],
772
775
dh .dtype_signed [res .dtype ],
@@ -788,7 +791,7 @@ def test_bitwise_xor(ctx, data):
788
791
scalar_l = int (left [l_idx ])
789
792
scalar_r = int (right [r_idx ])
790
793
scalar_o = int (res [o_idx ])
791
- expected = ah .int_to_dtype (
794
+ expected = ah .mock_int_dtype (
792
795
scalar_l ^ scalar_r ,
793
796
dh .dtype_nbits [res .dtype ],
794
797
dh .dtype_signed [res .dtype ],
@@ -1366,25 +1369,17 @@ def test_multiply(ctx, data):
1366
1369
@given (data = st .data ())
1367
1370
def test_negative (ctx , data ):
1368
1371
x = data .draw (ctx .strat , label = "x" )
1372
+ # negative of the smallest negative integer is out-of-scope
1373
+ if x .dtype in dh .int_dtypes :
1374
+ assume (xp .all (x > dh .dtype_ranges [x .dtype ].min ))
1369
1375
1370
1376
out = ctx .func (x )
1371
1377
1372
1378
ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
1373
1379
ph .assert_shape (ctx .func_name , out .shape , x .shape )
1374
-
1375
- # Negation is an involution
1376
- ah .assert_exactly_equal (x , ctx .func (out ))
1377
-
1378
- mask = ah .isfinite (x )
1379
- if dh .is_int_dtype (x .dtype ):
1380
- minval = dh .dtype_ranges [x .dtype ][0 ]
1381
- if minval < 0 :
1382
- # negative of the smallest representable negative integer is not defined
1383
- mask = xp .not_equal (x , ah .full (x .shape , minval , dtype = x .dtype ))
1384
-
1385
- # Additive inverse
1386
- y = xp .add (x [mask ], out [mask ])
1387
- ah .assert_exactly_equal (y , ah .zero (x [mask ].shape , x .dtype ))
1380
+ unary_assert_against_refimpl (
1381
+ ctx .func_name , dh .get_scalar_type (x .dtype ), x , out , operator .neg , "-({})={}"
1382
+ )
1388
1383
1389
1384
1390
1385
@pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps .scalar_dtypes ()))
@@ -1438,8 +1433,7 @@ def test_positive(ctx, data):
1438
1433
1439
1434
ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
1440
1435
ph .assert_shape (ctx .func_name , out .shape , x .shape )
1441
- # Positive does nothing
1442
- ah .assert_exactly_equal (out , x )
1436
+ ph .assert_array (ctx .func_name , out , x )
1443
1437
1444
1438
1445
1439
@pytest .mark .parametrize ("ctx" , make_binary_params ("pow" , xps .numeric_dtypes ()))
0 commit comments