23
23
from . import dtype_helpers as dh
24
24
from . import hypothesis_helpers as hh
25
25
from . import pytest_helpers as ph
26
+ from . import shape_helpers as sh
26
27
from . import xps
27
28
from .algos import broadcast_shapes
28
29
from .typing import Array , DataType , Param , Scalar
@@ -377,13 +378,13 @@ def test_bitwise_and(
377
378
378
379
# Compare against the Python & operator.
379
380
if res .dtype == xp .bool :
380
- for idx in ah .ndindex (res .shape ):
381
+ for idx in sh .ndindex (res .shape ):
381
382
s_left = bool (_left [idx ])
382
383
s_right = bool (_right [idx ])
383
384
s_res = bool (res [idx ])
384
385
assert (s_left and s_right ) == s_res
385
386
else :
386
- for idx in ah .ndindex (res .shape ):
387
+ for idx in sh .ndindex (res .shape ):
387
388
s_left = int (_left [idx ])
388
389
s_right = int (_right [idx ])
389
390
s_res = int (res [idx ])
@@ -427,7 +428,7 @@ def test_bitwise_left_shift(
427
428
_right = xp .broadcast_to (right , shape )
428
429
429
430
# Compare against the Python << operator.
430
- for idx in ah .ndindex (res .shape ):
431
+ for idx in sh .ndindex (res .shape ):
431
432
s_left = int (_left [idx ])
432
433
s_right = int (_right [idx ])
433
434
s_res = int (res [idx ])
@@ -452,12 +453,12 @@ def test_bitwise_invert(func_name, func, strat, data):
452
453
ph .assert_shape (func_name , out .shape , x .shape )
453
454
# Compare against the Python ~ operator.
454
455
if out .dtype == xp .bool :
455
- for idx in ah .ndindex (out .shape ):
456
+ for idx in sh .ndindex (out .shape ):
456
457
s_x = bool (x [idx ])
457
458
s_out = bool (out [idx ])
458
459
assert (not s_x ) == s_out
459
460
else :
460
- for idx in ah .ndindex (out .shape ):
461
+ for idx in sh .ndindex (out .shape ):
461
462
s_x = int (x [idx ])
462
463
s_out = int (out [idx ])
463
464
s_invert = ah .int_to_dtype (
@@ -495,13 +496,13 @@ def test_bitwise_or(
495
496
496
497
# Compare against the Python | operator.
497
498
if res .dtype == xp .bool :
498
- for idx in ah .ndindex (res .shape ):
499
+ for idx in sh .ndindex (res .shape ):
499
500
s_left = bool (_left [idx ])
500
501
s_right = bool (_right [idx ])
501
502
s_res = bool (res [idx ])
502
503
assert (s_left or s_right ) == s_res
503
504
else :
504
- for idx in ah .ndindex (res .shape ):
505
+ for idx in sh .ndindex (res .shape ):
505
506
s_left = int (_left [idx ])
506
507
s_right = int (_right [idx ])
507
508
s_res = int (res [idx ])
@@ -547,7 +548,7 @@ def test_bitwise_right_shift(
547
548
_right = xp .broadcast_to (right , shape )
548
549
549
550
# Compare against the Python >> operator.
550
- for idx in ah .ndindex (res .shape ):
551
+ for idx in sh .ndindex (res .shape ):
551
552
s_left = int (_left [idx ])
552
553
s_right = int (_right [idx ])
553
554
s_res = int (res [idx ])
@@ -586,13 +587,13 @@ def test_bitwise_xor(
586
587
587
588
# Compare against the Python ^ operator.
588
589
if res .dtype == xp .bool :
589
- for idx in ah .ndindex (res .shape ):
590
+ for idx in sh .ndindex (res .shape ):
590
591
s_left = bool (_left [idx ])
591
592
s_right = bool (_right [idx ])
592
593
s_res = bool (res [idx ])
593
594
assert (s_left ^ s_right ) == s_res
594
595
else :
595
- for idx in ah .ndindex (res .shape ):
596
+ for idx in sh .ndindex (res .shape ):
596
597
s_left = int (_left [idx ])
597
598
s_right = int (_right [idx ])
598
599
s_res = int (res [idx ])
@@ -721,7 +722,7 @@ def test_equal(
721
722
_right = ah .asarray (_right , dtype = promoted_dtype )
722
723
723
724
scalar_type = dh .get_scalar_type (promoted_dtype )
724
- for idx in ah .ndindex (shape ):
725
+ for idx in sh .ndindex (shape ):
725
726
x1_idx = _left [idx ]
726
727
x2_idx = _right [idx ]
727
728
out_idx = out [idx ]
@@ -846,7 +847,7 @@ def test_greater(
846
847
_right = ah .asarray (_right , dtype = promoted_dtype )
847
848
848
849
scalar_type = dh .get_scalar_type (promoted_dtype )
849
- for idx in ah .ndindex (shape ):
850
+ for idx in sh .ndindex (shape ):
850
851
out_idx = out [idx ]
851
852
x1_idx = _left [idx ]
852
853
x2_idx = _right [idx ]
@@ -887,7 +888,7 @@ def test_greater_equal(
887
888
_right = ah .asarray (_right , dtype = promoted_dtype )
888
889
889
890
scalar_type = dh .get_scalar_type (promoted_dtype )
890
- for idx in ah .ndindex (shape ):
891
+ for idx in sh .ndindex (shape ):
891
892
out_idx = out [idx ]
892
893
x1_idx = _left [idx ]
893
894
x2_idx = _right [idx ]
@@ -907,7 +908,7 @@ def test_isfinite(x):
907
908
908
909
# Test the exact value by comparing to the math version
909
910
if dh .is_float_dtype (x .dtype ):
910
- for idx in ah .ndindex (x .shape ):
911
+ for idx in sh .ndindex (x .shape ):
911
912
s = float (x [idx ])
912
913
assert bool (res [idx ]) == math .isfinite (s )
913
914
@@ -925,7 +926,7 @@ def test_isinf(x):
925
926
926
927
# Test the exact value by comparing to the math version
927
928
if dh .is_float_dtype (x .dtype ):
928
- for idx in ah .ndindex (x .shape ):
929
+ for idx in sh .ndindex (x .shape ):
929
930
s = float (x [idx ])
930
931
assert bool (res [idx ]) == math .isinf (s )
931
932
@@ -943,7 +944,7 @@ def test_isnan(x):
943
944
944
945
# Test the exact value by comparing to the math version
945
946
if dh .is_float_dtype (x .dtype ):
946
- for idx in ah .ndindex (x .shape ):
947
+ for idx in sh .ndindex (x .shape ):
947
948
s = float (x [idx ])
948
949
assert bool (res [idx ]) == math .isnan (s )
949
950
@@ -979,7 +980,7 @@ def test_less(
979
980
_right = ah .asarray (_right , dtype = promoted_dtype )
980
981
981
982
scalar_type = dh .get_scalar_type (promoted_dtype )
982
- for idx in ah .ndindex (shape ):
983
+ for idx in sh .ndindex (shape ):
983
984
x1_idx = _left [idx ]
984
985
x2_idx = _right [idx ]
985
986
out_idx = out [idx ]
@@ -1020,7 +1021,7 @@ def test_less_equal(
1020
1021
_right = ah .asarray (_right , dtype = promoted_dtype )
1021
1022
1022
1023
scalar_type = dh .get_scalar_type (promoted_dtype )
1023
- for idx in ah .ndindex (shape ):
1024
+ for idx in sh .ndindex (shape ):
1024
1025
x1_idx = _left [idx ]
1025
1026
x2_idx = _right [idx ]
1026
1027
out_idx = out [idx ]
@@ -1100,15 +1101,15 @@ def test_logical_and(x1, x2):
1100
1101
_x1 = xp .broadcast_to (x1 , shape )
1101
1102
_x2 = xp .broadcast_to (x2 , shape )
1102
1103
1103
- for idx in ah .ndindex (shape ):
1104
+ for idx in sh .ndindex (shape ):
1104
1105
assert out [idx ] == (bool (_x1 [idx ]) and bool (_x2 [idx ]))
1105
1106
1106
1107
1107
1108
@given (xps .arrays (dtype = xp .bool , shape = hh .shapes ()))
1108
1109
def test_logical_not (x ):
1109
1110
out = ah .logical_not (x )
1110
1111
ph .assert_shape ("logical_not" , out .shape , x .shape )
1111
- for idx in ah .ndindex (x .shape ):
1112
+ for idx in sh .ndindex (x .shape ):
1112
1113
assert out [idx ] == (not bool (x [idx ]))
1113
1114
1114
1115
@@ -1122,7 +1123,7 @@ def test_logical_or(x1, x2):
1122
1123
_x1 = xp .broadcast_to (x1 , shape )
1123
1124
_x2 = xp .broadcast_to (x2 , shape )
1124
1125
1125
- for idx in ah .ndindex (shape ):
1126
+ for idx in sh .ndindex (shape ):
1126
1127
assert out [idx ] == (bool (_x1 [idx ]) or bool (_x2 [idx ]))
1127
1128
1128
1129
@@ -1136,7 +1137,7 @@ def test_logical_xor(x1, x2):
1136
1137
_x1 = xp .broadcast_to (x1 , shape )
1137
1138
_x2 = xp .broadcast_to (x2 , shape )
1138
1139
1139
- for idx in ah .ndindex (shape ):
1140
+ for idx in sh .ndindex (shape ):
1140
1141
assert out [idx ] == (bool (_x1 [idx ]) ^ bool (_x2 [idx ]))
1141
1142
1142
1143
@@ -1225,7 +1226,7 @@ def test_not_equal(
1225
1226
_right = ah .asarray (_right , dtype = promoted_dtype )
1226
1227
1227
1228
scalar_type = dh .get_scalar_type (promoted_dtype )
1228
- for idx in ah .ndindex (shape ):
1229
+ for idx in sh .ndindex (shape ):
1229
1230
out_idx = out [idx ]
1230
1231
x1_idx = _left [idx ]
1231
1232
x2_idx = _right [idx ]
0 commit comments