@@ -59,16 +59,18 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
59
59
60
60
def unary_assert_against_refimpl (
61
61
func_name : str ,
62
- in_stype : ScalarType ,
63
62
in_ : Array ,
64
63
res : Array ,
65
64
refimpl : Callable [[Scalar ], Scalar ],
66
65
expr_template : str ,
66
+ in_stype : Optional [ScalarType ] = None ,
67
67
res_stype : Optional [ScalarType ] = None ,
68
68
ignorer : Callable [[Scalar ], bool ] = bool ,
69
69
):
70
70
if in_ .shape != res .shape :
71
71
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
72
+ if in_stype is None :
73
+ in_stype = dh .get_scalar_type (in_ .dtype )
72
74
if res_stype is None :
73
75
res_stype = in_stype
74
76
for idx in sh .ndindex (in_ .shape ):
@@ -88,32 +90,77 @@ def unary_assert_against_refimpl(
88
90
89
91
def binary_assert_against_refimpl (
90
92
func_name : str ,
91
- in_stype : ScalarType ,
92
93
left : Array ,
93
- right : Array ,
94
+ right : Union [ Scalar , Array ] ,
94
95
res : Array ,
95
96
refimpl : Callable [[Scalar , Scalar ], Scalar ],
96
97
expr_template : str ,
98
+ in_stype : Optional [ScalarType ] = None ,
97
99
res_stype : Optional [ScalarType ] = None ,
98
100
left_sym : str = "x1" ,
99
101
right_sym : str = "x2" ,
100
- res_sym : str = "out" ,
102
+ right_is_scalar : bool = False ,
103
+ res_name : str = "out" ,
104
+ ignorer : Callable [[Scalar , Scalar ], bool ] = bool ,
101
105
):
106
+ if in_stype is None :
107
+ in_stype = dh .get_scalar_type (left .dtype )
102
108
if res_stype is None :
103
109
res_stype = in_stype
104
- for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
105
- scalar_l = in_stype (left [l_idx ])
106
- scalar_r = in_stype (right [r_idx ])
107
- expected = refimpl (scalar_l , scalar_r )
108
- scalar_o = res_stype (res [o_idx ])
109
- f_l = sh .fmt_idx (left_sym , l_idx )
110
- f_r = sh .fmt_idx (right_sym , r_idx )
111
- f_o = sh .fmt_idx (res_sym , o_idx )
112
- expr = expr_template .format (scalar_l , scalar_r , expected )
113
- assert scalar_o == expected , (
114
- f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
115
- f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
116
- )
110
+ if right_is_scalar :
111
+ if left .dtype != xp .bool :
112
+ m , M = dh .dtype_ranges [left .dtype ]
113
+ for idx in sh .ndindex (res .shape ):
114
+ scalar_l = in_stype (left [idx ])
115
+ if any (ignorer (s ) for s in [scalar_l , right ]):
116
+ continue
117
+ expected = refimpl (scalar_l , right )
118
+ if left .dtype != xp .bool :
119
+ if expected <= m or expected >= M :
120
+ continue
121
+ scalar_o = res_stype (res [idx ])
122
+ f_l = sh .fmt_idx (left_sym , idx )
123
+ f_o = sh .fmt_idx (res_name , idx )
124
+ expr = expr_template .format (scalar_l , right , expected )
125
+ if dh .is_float_dtype (left .dtype ):
126
+ assert isclose (scalar_o , expected ), (
127
+ f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
128
+ f"{ f_l } ={ scalar_l } "
129
+ )
130
+
131
+ else :
132
+ assert scalar_o == expected , (
133
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
134
+ f"{ f_l } ={ scalar_l } "
135
+ )
136
+ else :
137
+ result_dtype = dh .result_type (left .dtype , right .dtype )
138
+ if result_dtype != xp .bool :
139
+ m , M = dh .dtype_ranges [result_dtype ]
140
+ for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
141
+ scalar_l = in_stype (left [l_idx ])
142
+ scalar_r = in_stype (right [r_idx ])
143
+ if any (ignorer (s ) for s in [scalar_l , scalar_r ]):
144
+ continue
145
+ expected = refimpl (scalar_l , scalar_r )
146
+ if result_dtype != xp .bool :
147
+ if expected <= m or expected >= M :
148
+ continue
149
+ scalar_o = res_stype (res [o_idx ])
150
+ f_l = sh .fmt_idx (left_sym , l_idx )
151
+ f_r = sh .fmt_idx (right_sym , r_idx )
152
+ f_o = sh .fmt_idx (res_name , o_idx )
153
+ expr = expr_template .format (scalar_l , scalar_r , expected )
154
+ if dh .is_float_dtype (result_dtype ):
155
+ assert isclose (scalar_o , expected ), (
156
+ f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
157
+ f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
158
+ )
159
+ else :
160
+ assert scalar_o == expected , (
161
+ f"{ f_o } ={ scalar_o } , but should be { expr } [{ func_name } ()]\n "
162
+ f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
163
+ )
117
164
118
165
119
166
# When appropiate, this module tests operators alongside their respective
@@ -325,7 +372,6 @@ def test_abs(ctx, data):
325
372
ph .assert_shape (ctx .func_name , out .shape , x .shape )
326
373
unary_assert_against_refimpl (
327
374
ctx .func_name ,
328
- dh .get_scalar_type (x .dtype ),
329
375
x ,
330
376
out ,
331
377
abs ,
@@ -379,37 +425,34 @@ def test_add(ctx, data):
379
425
380
426
assert_binary_param_dtype (ctx , left , right , res )
381
427
assert_binary_param_shape (ctx , left , right , res )
382
- m , M = dh .dtype_ranges [res .dtype ]
383
- scalar_type = dh .get_scalar_type (res .dtype )
384
428
if ctx .right_is_scalar :
385
- for idx in sh .ndindex (res .shape ):
386
- scalar_l = scalar_type (left [idx ])
387
- expected = scalar_l + right
388
- if not math .isfinite (expected ) or expected <= m or expected >= M :
389
- continue
390
- scalar_o = scalar_type (res [idx ])
391
- f_l = sh .fmt_idx (ctx .left_sym , idx )
392
- f_o = sh .fmt_idx (ctx .res_name , idx )
393
- assert isclose (scalar_o , expected ), (
394
- f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } + { right } )={ expected } "
395
- f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } "
396
- )
429
+ binary_assert_against_refimpl (
430
+ func_name = ctx .func_name ,
431
+ left_sym = ctx .left_sym ,
432
+ left = left ,
433
+ right_sym = ctx .right_sym ,
434
+ right = right ,
435
+ right_is_scalar = True ,
436
+ res_name = ctx .res_name ,
437
+ res = res ,
438
+ refimpl = operator .add ,
439
+ expr_template = "({} + {})={}" ,
440
+ ignorer = lambda s : not math .isfinite (s ),
441
+ )
397
442
else :
398
443
ph .assert_array (ctx .func_name , res , ctx .func (right , left )) # cumulative
399
- for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
400
- scalar_l = scalar_type (left [l_idx ])
401
- scalar_r = scalar_type (right [r_idx ])
402
- expected = scalar_l + scalar_r
403
- if not math .isfinite (expected ) or expected <= m or expected >= M :
404
- continue
405
- scalar_o = scalar_type (res [o_idx ])
406
- f_l = sh .fmt_idx (ctx .left_sym , l_idx )
407
- f_r = sh .fmt_idx (ctx .right_sym , r_idx )
408
- f_o = sh .fmt_idx (ctx .res_name , o_idx )
409
- assert isclose (scalar_o , expected ), (
410
- f"{ f_o } ={ scalar_o } , but should be roughly ({ f_l } + { f_r } )={ expected } "
411
- f"[{ ctx .func_name } ()]\n { f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
412
- )
444
+ binary_assert_against_refimpl (
445
+ func_name = ctx .func_name ,
446
+ left_sym = ctx .left_sym ,
447
+ left = left ,
448
+ right_sym = ctx .right_sym ,
449
+ right = right ,
450
+ res_name = ctx .res_name ,
451
+ res = res ,
452
+ refimpl = operator .add ,
453
+ expr_template = "({} + {})={}" ,
454
+ ignorer = lambda s : not math .isfinite (s ),
455
+ )
413
456
414
457
415
458
@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -531,11 +574,7 @@ def test_bitwise_and(ctx, data):
531
574
# for mypy
532
575
assert isinstance (scalar_l , int )
533
576
assert isinstance (right , int )
534
- expected = ah .mock_int_dtype (
535
- scalar_l & right ,
536
- dh .dtype_nbits [res .dtype ],
537
- dh .dtype_signed [res .dtype ],
538
- )
577
+ expected = mock_int_dtype (scalar_l & right , res .dtype )
539
578
scalar_o = scalar_type (res [idx ])
540
579
f_l = sh .fmt_idx (ctx .left_sym , idx )
541
580
f_o = sh .fmt_idx (ctx .res_name , idx )
@@ -553,11 +592,7 @@ def test_bitwise_and(ctx, data):
553
592
# for mypy
554
593
assert isinstance (scalar_l , int )
555
594
assert isinstance (scalar_r , int )
556
- expected = ah .mock_int_dtype (
557
- scalar_l & scalar_r ,
558
- dh .dtype_nbits [res .dtype ],
559
- dh .dtype_signed [res .dtype ],
560
- )
595
+ expected = mock_int_dtype (scalar_l & scalar_r , res .dtype )
561
596
scalar_o = scalar_type (res [o_idx ])
562
597
f_l = sh .fmt_idx (ctx .left_sym , l_idx )
563
598
f_r = sh .fmt_idx (ctx .right_sym , r_idx )
@@ -587,11 +622,10 @@ def test_bitwise_left_shift(ctx, data):
587
622
if ctx .right_is_scalar :
588
623
for idx in sh .ndindex (res .shape ):
589
624
scalar_l = int (left [idx ])
590
- expected = ah . mock_int_dtype (
625
+ expected = mock_int_dtype (
591
626
# We avoid shifting very large ints
592
627
scalar_l << right if right < dh .dtype_nbits [res .dtype ] else 0 ,
593
- dh .dtype_nbits [res .dtype ],
594
- dh .dtype_signed [res .dtype ],
628
+ res .dtype ,
595
629
)
596
630
scalar_o = int (res [idx ])
597
631
f_l = sh .fmt_idx (ctx .left_sym , idx )
@@ -604,11 +638,10 @@ def test_bitwise_left_shift(ctx, data):
604
638
for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
605
639
scalar_l = int (left [l_idx ])
606
640
scalar_r = int (right [r_idx ])
607
- expected = ah . mock_int_dtype (
641
+ expected = mock_int_dtype (
608
642
# We avoid shifting very large ints
609
643
scalar_l << scalar_r if scalar_r < dh .dtype_nbits [res .dtype ] else 0 ,
610
- dh .dtype_nbits [res .dtype ],
611
- dh .dtype_signed [res .dtype ],
644
+ res .dtype ,
612
645
)
613
646
scalar_o = int (res [o_idx ])
614
647
f_l = sh .fmt_idx (ctx .left_sym , l_idx )
@@ -636,9 +669,7 @@ def test_bitwise_invert(ctx, data):
636
669
refimpl = lambda s : not s
637
670
else :
638
671
refimpl = lambda s : mock_int_dtype (~ s , x .dtype )
639
- unary_assert_against_refimpl (
640
- ctx .func_name , dh .get_scalar_type (x .dtype ), x , out , refimpl , "~{}={}"
641
- )
672
+ unary_assert_against_refimpl (ctx .func_name , x , out , refimpl , "~{}={}" )
642
673
643
674
644
675
@pytest .mark .parametrize (
@@ -662,11 +693,7 @@ def test_bitwise_or(ctx, data):
662
693
else :
663
694
scalar_l = int (left [idx ])
664
695
scalar_o = int (res [idx ])
665
- expected = ah .mock_int_dtype (
666
- scalar_l | right ,
667
- dh .dtype_nbits [res .dtype ],
668
- dh .dtype_signed [res .dtype ],
669
- )
696
+ expected = mock_int_dtype (scalar_l | right , res .dtype )
670
697
f_l = sh .fmt_idx (ctx .left_sym , idx )
671
698
f_o = sh .fmt_idx (ctx .res_name , idx )
672
699
assert scalar_o == expected , (
@@ -684,11 +711,7 @@ def test_bitwise_or(ctx, data):
684
711
scalar_l = int (left [l_idx ])
685
712
scalar_r = int (right [r_idx ])
686
713
scalar_o = int (res [o_idx ])
687
- expected = ah .mock_int_dtype (
688
- scalar_l | scalar_r ,
689
- dh .dtype_nbits [res .dtype ],
690
- dh .dtype_signed [res .dtype ],
691
- )
714
+ expected = mock_int_dtype (scalar_l | scalar_r , res .dtype )
692
715
f_l = sh .fmt_idx (ctx .left_sym , l_idx )
693
716
f_r = sh .fmt_idx (ctx .right_sym , r_idx )
694
717
f_o = sh .fmt_idx (ctx .res_name , o_idx )
@@ -717,11 +740,7 @@ def test_bitwise_right_shift(ctx, data):
717
740
if ctx .right_is_scalar :
718
741
for idx in sh .ndindex (res .shape ):
719
742
scalar_l = int (left [idx ])
720
- expected = ah .mock_int_dtype (
721
- scalar_l >> right ,
722
- dh .dtype_nbits [res .dtype ],
723
- dh .dtype_signed [res .dtype ],
724
- )
743
+ expected = mock_int_dtype (scalar_l >> right , res .dtype )
725
744
scalar_o = int (res [idx ])
726
745
f_l = sh .fmt_idx (ctx .left_sym , idx )
727
746
f_o = sh .fmt_idx (ctx .res_name , idx )
@@ -733,11 +752,7 @@ def test_bitwise_right_shift(ctx, data):
733
752
for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
734
753
scalar_l = int (left [l_idx ])
735
754
scalar_r = int (right [r_idx ])
736
- expected = ah .mock_int_dtype (
737
- scalar_l >> scalar_r ,
738
- dh .dtype_nbits [res .dtype ],
739
- dh .dtype_signed [res .dtype ],
740
- )
755
+ expected = mock_int_dtype (scalar_l >> scalar_r , res .dtype )
741
756
scalar_o = int (res [o_idx ])
742
757
f_l = sh .fmt_idx (ctx .left_sym , l_idx )
743
758
f_r = sh .fmt_idx (ctx .right_sym , r_idx )
@@ -769,11 +784,7 @@ def test_bitwise_xor(ctx, data):
769
784
else :
770
785
scalar_l = int (left [idx ])
771
786
scalar_o = int (res [idx ])
772
- expected = ah .mock_int_dtype (
773
- scalar_l ^ right ,
774
- dh .dtype_nbits [res .dtype ],
775
- dh .dtype_signed [res .dtype ],
776
- )
787
+ expected = mock_int_dtype (scalar_l ^ right , res .dtype )
777
788
f_l = sh .fmt_idx (ctx .left_sym , idx )
778
789
f_o = sh .fmt_idx (ctx .res_name , idx )
779
790
assert scalar_o == expected , (
@@ -791,11 +802,7 @@ def test_bitwise_xor(ctx, data):
791
802
scalar_l = int (left [l_idx ])
792
803
scalar_r = int (right [r_idx ])
793
804
scalar_o = int (res [o_idx ])
794
- expected = ah .mock_int_dtype (
795
- scalar_l ^ scalar_r ,
796
- dh .dtype_nbits [res .dtype ],
797
- dh .dtype_signed [res .dtype ],
798
- )
805
+ expected = mock_int_dtype (scalar_l ^ scalar_r , res .dtype )
799
806
f_l = sh .fmt_idx (ctx .left_sym , l_idx )
800
807
f_r = sh .fmt_idx (ctx .right_sym , r_idx )
801
808
f_o = sh .fmt_idx (ctx .res_name , o_idx )
@@ -1309,13 +1316,7 @@ def test_logical_and(x1, x2):
1309
1316
ph .assert_dtype ("logical_and" , [x1 .dtype , x2 .dtype ], out .dtype )
1310
1317
ph .assert_result_shape ("logical_and" , [x1 .shape , x2 .shape ], out .shape )
1311
1318
binary_assert_against_refimpl (
1312
- "logical_and" ,
1313
- bool ,
1314
- x1 ,
1315
- x2 ,
1316
- out ,
1317
- lambda l , r : l and r ,
1318
- "({} and {})={}" ,
1319
+ "logical_and" , x1 , x2 , out , lambda l , r : l and r , "({} and {})={}"
1319
1320
)
1320
1321
1321
1322
@@ -1324,9 +1325,7 @@ def test_logical_not(x):
1324
1325
out = ah .logical_not (x )
1325
1326
ph .assert_dtype ("logical_not" , x .dtype , out .dtype )
1326
1327
ph .assert_shape ("logical_not" , out .shape , x .shape )
1327
- unary_assert_against_refimpl (
1328
- "logical_not" , bool , x , out , lambda i : not i , "(not {})={}"
1329
- )
1328
+ unary_assert_against_refimpl ("logical_not" , x , out , lambda i : not i , "(not {})={}" )
1330
1329
1331
1330
1332
1331
@given (* hh .two_mutual_arrays ([xp .bool ]))
@@ -1335,7 +1334,7 @@ def test_logical_or(x1, x2):
1335
1334
ph .assert_dtype ("logical_or" , [x1 .dtype , x2 .dtype ], out .dtype )
1336
1335
ph .assert_result_shape ("logical_or" , [x1 .shape , x2 .shape ], out .shape )
1337
1336
binary_assert_against_refimpl (
1338
- "logical_or" , bool , x1 , x2 , out , lambda l , r : l or r , "({} or {})={}"
1337
+ "logical_or" , x1 , x2 , out , lambda l , r : l or r , "({} or {})={}"
1339
1338
)
1340
1339
1341
1340
@@ -1345,7 +1344,7 @@ def test_logical_xor(x1, x2):
1345
1344
ph .assert_dtype ("logical_xor" , [x1 .dtype , x2 .dtype ], out .dtype )
1346
1345
ph .assert_result_shape ("logical_xor" , [x1 .shape , x2 .shape ], out .shape )
1347
1346
binary_assert_against_refimpl (
1348
- "logical_xor" , bool , x1 , x2 , out , lambda l , r : l ^ r , "({} ^ {})={}"
1347
+ "logical_xor" , x1 , x2 , out , lambda l , r : l ^ r , "({} ^ {})={}"
1349
1348
)
1350
1349
1351
1350
@@ -1377,9 +1376,7 @@ def test_negative(ctx, data):
1377
1376
1378
1377
ph .assert_dtype (ctx .func_name , x .dtype , out .dtype )
1379
1378
ph .assert_shape (ctx .func_name , out .shape , x .shape )
1380
- unary_assert_against_refimpl (
1381
- ctx .func_name , dh .get_scalar_type (x .dtype ), x , out , operator .neg , "-({})={}"
1382
- )
1379
+ unary_assert_against_refimpl (ctx .func_name , x , out , operator .neg , "-({})={}" )
1383
1380
1384
1381
1385
1382
@pytest .mark .parametrize ("ctx" , make_binary_params ("not_equal" , xps .scalar_dtypes ()))
0 commit comments