Skip to content

Commit 1273270

Browse files
committed
Move shape-related helpers into shape_helpers.py
1 parent 6194b08 commit 1273270

12 files changed

+162
-171
lines changed

array_api_tests/array_helpers.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import itertools
2-
31
from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
42
logical_or, isfinite, greater, less, less_equal,
53
zeros, ones, full, bool, int8, int16, int32,
@@ -23,7 +21,7 @@
2321
'assert_isinf', 'positive_mathematical_sign',
2422
'assert_positive_mathematical_sign', 'negative_mathematical_sign',
2523
'assert_negative_mathematical_sign', 'same_sign',
26-
'assert_same_sign', 'ndindex', 'float64',
24+
'assert_same_sign', 'float64',
2725
'asarray', 'full', 'true', 'false', 'isnan']
2826

2927
def zero(shape, dtype):
@@ -319,13 +317,3 @@ def int_to_dtype(x, n, signed):
319317
if x & highest_bit:
320318
x = -((~x & mask) + 1)
321319
return x
322-
323-
def ndindex(shape):
324-
"""
325-
Iterator of n-D indices to an array
326-
327-
Yields tuples of integers to index every element of an array of shape
328-
`shape`. Same as np.ndindex().
329-
330-
"""
331-
return itertools.product(*[range(i) for i in shape])

array_api_tests/hypothesis_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import reduce
33
from math import sqrt
44
from operator import mul
5-
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence, Union
5+
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
66

77
from hypothesis import assume
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
@@ -11,15 +11,15 @@
1111

1212
from . import _array_module as xp
1313
from . import dtype_helpers as dh
14+
from . import shape_helpers as sh
1415
from . import xps
1516
from ._array_module import _UndefinedStub
1617
from ._array_module import bool as bool_dtype
1718
from ._array_module import broadcast_to, eye, float32, float64, full
18-
from .array_helpers import ndindex
19+
from .algos import broadcast_shapes
1920
from .function_stubs import elementwise_functions
2021
from .pytest_helpers import nargs
2122
from .typing import Array, DataType, Shape
22-
from .algos import broadcast_shapes
2323

2424
# Set this to True to not fail tests just because a dtype isn't implemented.
2525
# If no compatible dtype is implemented for a given test, the test will fail
@@ -208,7 +208,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
208208
assume(xp.all(xp.abs(d) > 0.5))
209209

210210
a = xp.zeros(shape)
211-
for j, (idx, i) in enumerate(itertools.product(ndindex(stack_shape), range(n))):
211+
for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))):
212212
a[idx + (i, i)] = d[j]
213213
return a
214214

array_api_tests/meta/test_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pytest
22

3-
from .. import array_helpers as ah
3+
from .. import shape_helpers as sh
44
from ..test_creation_functions import frange
5-
from ..test_manipulation_functions import axis_ndindex
65
from ..test_signatures import extension_module
7-
from ..test_statistical_functions import axes_ndindex
86

97

108
def test_extension_module_is_extension():
@@ -34,7 +32,7 @@ def test_frange(r, size, elements):
3432
[((), [()])],
3533
)
3634
def test_ndindex(shape, expected):
37-
assert list(ah.ndindex(shape)) == expected
35+
assert list(sh.ndindex(shape)) == expected
3836

3937

4038
@pytest.mark.parametrize(
@@ -50,7 +48,7 @@ def test_ndindex(shape, expected):
5048
],
5149
)
5250
def test_axis_ndindex(shape, axis, expected):
53-
assert list(axis_ndindex(shape, axis)) == expected
51+
assert list(sh.axis_ndindex(shape, axis)) == expected
5452

5553

5654
@pytest.mark.parametrize(
@@ -69,4 +67,4 @@ def test_axis_ndindex(shape, axis, expected):
6967
],
7068
)
7169
def test_axes_ndindex(shape, axes, expected):
72-
assert list(axes_ndindex(shape, axes)) == expected
70+
assert list(sh.axes_ndindex(shape, axes)) == expected

array_api_tests/shape_helpers.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from itertools import product
2+
from typing import Iterator, List, Optional, Tuple, Union
3+
4+
from .typing import Shape
5+
6+
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"]
7+
8+
9+
def normalise_axis(
10+
axis: Optional[Union[int, Tuple[int, ...]]], ndim: int
11+
) -> Tuple[int, ...]:
12+
if axis is None:
13+
return tuple(range(ndim))
14+
axes = axis if isinstance(axis, tuple) else (axis,)
15+
axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
16+
return axes
17+
18+
19+
def ndindex(shape):
20+
"""Iterator of n-D indices to an array
21+
22+
Yields tuples of integers to index every element of an array of shape
23+
`shape`. Same as np.ndindex().
24+
"""
25+
return product(*[range(i) for i in shape])
26+
27+
28+
def axis_ndindex(
29+
shape: Shape, axis: int
30+
) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
31+
"""Generate indices that index all elements in dimensions beyond `axis`"""
32+
assert axis >= 0 # sanity check
33+
axis_indices = [range(side) for side in shape[:axis]]
34+
for _ in range(axis, len(shape)):
35+
axis_indices.append([slice(None, None)])
36+
yield from product(*axis_indices)
37+
38+
39+
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
40+
"""Generate indices that index all elements except in `axes` dimensions"""
41+
base_indices = []
42+
axes_indices = []
43+
for axis, side in enumerate(shape):
44+
if axis in axes:
45+
base_indices.append([None])
46+
axes_indices.append(range(side))
47+
else:
48+
base_indices.append(range(side))
49+
axes_indices.append([None])
50+
for base_idx in product(*base_indices):
51+
indices = []
52+
for idx in product(*axes_indices):
53+
idx = list(idx)
54+
for axis, side in enumerate(idx):
55+
if axis not in axes:
56+
idx[axis] = base_idx[axis]
57+
idx = tuple(idx)
58+
indices.append(idx)
59+
yield list(indices)

array_api_tests/test_elementwise_functions.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from . import dtype_helpers as dh
2424
from . import hypothesis_helpers as hh
2525
from . import pytest_helpers as ph
26+
from . import shape_helpers as sh
2627
from . import xps
2728
from .algos import broadcast_shapes
2829
from .typing import Array, DataType, Param, Scalar
@@ -377,13 +378,13 @@ def test_bitwise_and(
377378

378379
# Compare against the Python & operator.
379380
if res.dtype == xp.bool:
380-
for idx in ah.ndindex(res.shape):
381+
for idx in sh.ndindex(res.shape):
381382
s_left = bool(_left[idx])
382383
s_right = bool(_right[idx])
383384
s_res = bool(res[idx])
384385
assert (s_left and s_right) == s_res
385386
else:
386-
for idx in ah.ndindex(res.shape):
387+
for idx in sh.ndindex(res.shape):
387388
s_left = int(_left[idx])
388389
s_right = int(_right[idx])
389390
s_res = int(res[idx])
@@ -427,7 +428,7 @@ def test_bitwise_left_shift(
427428
_right = xp.broadcast_to(right, shape)
428429

429430
# Compare against the Python << operator.
430-
for idx in ah.ndindex(res.shape):
431+
for idx in sh.ndindex(res.shape):
431432
s_left = int(_left[idx])
432433
s_right = int(_right[idx])
433434
s_res = int(res[idx])
@@ -452,12 +453,12 @@ def test_bitwise_invert(func_name, func, strat, data):
452453
ph.assert_shape(func_name, out.shape, x.shape)
453454
# Compare against the Python ~ operator.
454455
if out.dtype == xp.bool:
455-
for idx in ah.ndindex(out.shape):
456+
for idx in sh.ndindex(out.shape):
456457
s_x = bool(x[idx])
457458
s_out = bool(out[idx])
458459
assert (not s_x) == s_out
459460
else:
460-
for idx in ah.ndindex(out.shape):
461+
for idx in sh.ndindex(out.shape):
461462
s_x = int(x[idx])
462463
s_out = int(out[idx])
463464
s_invert = ah.int_to_dtype(
@@ -495,13 +496,13 @@ def test_bitwise_or(
495496

496497
# Compare against the Python | operator.
497498
if res.dtype == xp.bool:
498-
for idx in ah.ndindex(res.shape):
499+
for idx in sh.ndindex(res.shape):
499500
s_left = bool(_left[idx])
500501
s_right = bool(_right[idx])
501502
s_res = bool(res[idx])
502503
assert (s_left or s_right) == s_res
503504
else:
504-
for idx in ah.ndindex(res.shape):
505+
for idx in sh.ndindex(res.shape):
505506
s_left = int(_left[idx])
506507
s_right = int(_right[idx])
507508
s_res = int(res[idx])
@@ -547,7 +548,7 @@ def test_bitwise_right_shift(
547548
_right = xp.broadcast_to(right, shape)
548549

549550
# Compare against the Python >> operator.
550-
for idx in ah.ndindex(res.shape):
551+
for idx in sh.ndindex(res.shape):
551552
s_left = int(_left[idx])
552553
s_right = int(_right[idx])
553554
s_res = int(res[idx])
@@ -586,13 +587,13 @@ def test_bitwise_xor(
586587

587588
# Compare against the Python ^ operator.
588589
if res.dtype == xp.bool:
589-
for idx in ah.ndindex(res.shape):
590+
for idx in sh.ndindex(res.shape):
590591
s_left = bool(_left[idx])
591592
s_right = bool(_right[idx])
592593
s_res = bool(res[idx])
593594
assert (s_left ^ s_right) == s_res
594595
else:
595-
for idx in ah.ndindex(res.shape):
596+
for idx in sh.ndindex(res.shape):
596597
s_left = int(_left[idx])
597598
s_right = int(_right[idx])
598599
s_res = int(res[idx])
@@ -721,7 +722,7 @@ def test_equal(
721722
_right = ah.asarray(_right, dtype=promoted_dtype)
722723

723724
scalar_type = dh.get_scalar_type(promoted_dtype)
724-
for idx in ah.ndindex(shape):
725+
for idx in sh.ndindex(shape):
725726
x1_idx = _left[idx]
726727
x2_idx = _right[idx]
727728
out_idx = out[idx]
@@ -846,7 +847,7 @@ def test_greater(
846847
_right = ah.asarray(_right, dtype=promoted_dtype)
847848

848849
scalar_type = dh.get_scalar_type(promoted_dtype)
849-
for idx in ah.ndindex(shape):
850+
for idx in sh.ndindex(shape):
850851
out_idx = out[idx]
851852
x1_idx = _left[idx]
852853
x2_idx = _right[idx]
@@ -887,7 +888,7 @@ def test_greater_equal(
887888
_right = ah.asarray(_right, dtype=promoted_dtype)
888889

889890
scalar_type = dh.get_scalar_type(promoted_dtype)
890-
for idx in ah.ndindex(shape):
891+
for idx in sh.ndindex(shape):
891892
out_idx = out[idx]
892893
x1_idx = _left[idx]
893894
x2_idx = _right[idx]
@@ -907,7 +908,7 @@ def test_isfinite(x):
907908

908909
# Test the exact value by comparing to the math version
909910
if dh.is_float_dtype(x.dtype):
910-
for idx in ah.ndindex(x.shape):
911+
for idx in sh.ndindex(x.shape):
911912
s = float(x[idx])
912913
assert bool(res[idx]) == math.isfinite(s)
913914

@@ -925,7 +926,7 @@ def test_isinf(x):
925926

926927
# Test the exact value by comparing to the math version
927928
if dh.is_float_dtype(x.dtype):
928-
for idx in ah.ndindex(x.shape):
929+
for idx in sh.ndindex(x.shape):
929930
s = float(x[idx])
930931
assert bool(res[idx]) == math.isinf(s)
931932

@@ -943,7 +944,7 @@ def test_isnan(x):
943944

944945
# Test the exact value by comparing to the math version
945946
if dh.is_float_dtype(x.dtype):
946-
for idx in ah.ndindex(x.shape):
947+
for idx in sh.ndindex(x.shape):
947948
s = float(x[idx])
948949
assert bool(res[idx]) == math.isnan(s)
949950

@@ -979,7 +980,7 @@ def test_less(
979980
_right = ah.asarray(_right, dtype=promoted_dtype)
980981

981982
scalar_type = dh.get_scalar_type(promoted_dtype)
982-
for idx in ah.ndindex(shape):
983+
for idx in sh.ndindex(shape):
983984
x1_idx = _left[idx]
984985
x2_idx = _right[idx]
985986
out_idx = out[idx]
@@ -1020,7 +1021,7 @@ def test_less_equal(
10201021
_right = ah.asarray(_right, dtype=promoted_dtype)
10211022

10221023
scalar_type = dh.get_scalar_type(promoted_dtype)
1023-
for idx in ah.ndindex(shape):
1024+
for idx in sh.ndindex(shape):
10241025
x1_idx = _left[idx]
10251026
x2_idx = _right[idx]
10261027
out_idx = out[idx]
@@ -1100,15 +1101,15 @@ def test_logical_and(x1, x2):
11001101
_x1 = xp.broadcast_to(x1, shape)
11011102
_x2 = xp.broadcast_to(x2, shape)
11021103

1103-
for idx in ah.ndindex(shape):
1104+
for idx in sh.ndindex(shape):
11041105
assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx]))
11051106

11061107

11071108
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes()))
11081109
def test_logical_not(x):
11091110
out = ah.logical_not(x)
11101111
ph.assert_shape("logical_not", out.shape, x.shape)
1111-
for idx in ah.ndindex(x.shape):
1112+
for idx in sh.ndindex(x.shape):
11121113
assert out[idx] == (not bool(x[idx]))
11131114

11141115

@@ -1122,7 +1123,7 @@ def test_logical_or(x1, x2):
11221123
_x1 = xp.broadcast_to(x1, shape)
11231124
_x2 = xp.broadcast_to(x2, shape)
11241125

1125-
for idx in ah.ndindex(shape):
1126+
for idx in sh.ndindex(shape):
11261127
assert out[idx] == (bool(_x1[idx]) or bool(_x2[idx]))
11271128

11281129

@@ -1136,7 +1137,7 @@ def test_logical_xor(x1, x2):
11361137
_x1 = xp.broadcast_to(x1, shape)
11371138
_x2 = xp.broadcast_to(x2, shape)
11381139

1139-
for idx in ah.ndindex(shape):
1140+
for idx in sh.ndindex(shape):
11401141
assert out[idx] == (bool(_x1[idx]) ^ bool(_x2[idx]))
11411142

11421143

@@ -1225,7 +1226,7 @@ def test_not_equal(
12251226
_right = ah.asarray(_right, dtype=promoted_dtype)
12261227

12271228
scalar_type = dh.get_scalar_type(promoted_dtype)
1228-
for idx in ah.ndindex(shape):
1229+
for idx in sh.ndindex(shape):
12291230
out_idx = out[idx]
12301231
x1_idx = _left[idx]
12311232
x2_idx = _right[idx]

0 commit comments

Comments
 (0)