Skip to content

Commit dcc4adf

Browse files
committed
Move axes() strategy to hypothesis_helpers.py
1 parent a435e63 commit dcc4adf

File tree

5 files changed

+29
-30
lines changed

5 files changed

+29
-30
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import reduce
33
from math import sqrt
44
from operator import mul
5-
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence
5+
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence, Union
66

77
from hypothesis import assume
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
@@ -399,3 +399,12 @@ def specified_kwargs(draw, *keys_values_defaults: KVD):
399399
if value is not default or draw(booleans()):
400400
kw[keyword] = value
401401
return kw
402+
403+
404+
def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
405+
"""Generate valid arguments for some axis keywords"""
406+
axes_strats = [none()]
407+
if ndim != 0:
408+
axes_strats.append(integers(-ndim, ndim - 1))
409+
axes_strats.append(xps.valid_tuple_axes(ndim))
410+
return one_of(axes_strats)

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from array_api_tests.algos import broadcast_shapes
1+
from .algos import broadcast_shapes
22
import math
33
from inspect import getfullargspec
44
from typing import Any, Dict, Optional, Tuple, Union

array_api_tests/test_searching_functions.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
from hypothesis import given
22
from hypothesis import strategies as st
33

4-
from array_api_tests.algos import broadcast_shapes
5-
from array_api_tests.test_manipulation_functions import assert_equals as assert_equals_
6-
from array_api_tests.test_statistical_functions import (
7-
assert_equals,
8-
assert_keepdimable_shape,
9-
axes_ndindex,
10-
normalise_axis,
11-
)
12-
from array_api_tests.typing import DataType
13-
144
from . import _array_module as xp
155
from . import array_helpers as ah
166
from . import dtype_helpers as dh
177
from . import hypothesis_helpers as hh
188
from . import pytest_helpers as ph
199
from . import xps
10+
from .algos import broadcast_shapes
11+
from .test_manipulation_functions import assert_equals as assert_equals_
12+
from .test_statistical_functions import (
13+
assert_equals,
14+
assert_keepdimable_shape,
15+
axes_ndindex,
16+
normalise_axis,
17+
)
18+
from .typing import DataType
2019

2120

2221
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):

array_api_tests/test_statistical_functions.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,6 @@
1515
from .typing import DataType, Scalar, ScalarType, Shape
1616

1717

18-
def axes(ndim: int) -> st.SearchStrategy[Optional[Union[int, Shape]]]:
19-
axes_strats = [st.none()]
20-
if ndim != 0:
21-
axes_strats.append(st.integers(-ndim, ndim - 1))
22-
axes_strats.append(xps.valid_tuple_axes(ndim))
23-
return st.one_of(axes_strats)
24-
25-
2618
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2719
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
2820
return st.none() | st.sampled_from(dtypes)
@@ -108,7 +100,7 @@ def assert_equals(
108100
data=st.data(),
109101
)
110102
def test_max(x, data):
111-
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
103+
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
112104

113105
out = xp.max(x, **kw)
114106

@@ -137,7 +129,7 @@ def test_max(x, data):
137129
data=st.data(),
138130
)
139131
def test_mean(x, data):
140-
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
132+
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
141133

142134
out = xp.mean(x, **kw)
143135

@@ -166,7 +158,7 @@ def test_mean(x, data):
166158
data=st.data(),
167159
)
168160
def test_min(x, data):
169-
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
161+
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
170162

171163
out = xp.min(x, **kw)
172164

@@ -197,7 +189,7 @@ def test_min(x, data):
197189
def test_prod(x, data):
198190
kw = data.draw(
199191
hh.kwargs(
200-
axis=axes(x.ndim),
192+
axis=hh.axes(x.ndim),
201193
dtype=kwarg_dtypes(x.dtype),
202194
keepdims=st.booleans(),
203195
),
@@ -258,7 +250,7 @@ def test_prod(x, data):
258250
data=st.data(),
259251
)
260252
def test_std(x, data):
261-
axis = data.draw(axes(x.ndim), label="axis")
253+
axis = data.draw(hh.axes(x.ndim), label="axis")
262254
_axes = normalise_axis(axis, x.ndim)
263255
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
264256
correction = data.draw(
@@ -295,7 +287,7 @@ def test_std(x, data):
295287
def test_sum(x, data):
296288
kw = data.draw(
297289
hh.kwargs(
298-
axis=axes(x.ndim),
290+
axis=hh.axes(x.ndim),
299291
dtype=kwarg_dtypes(x.dtype),
300292
keepdims=st.booleans(),
301293
),
@@ -356,7 +348,7 @@ def test_sum(x, data):
356348
data=st.data(),
357349
)
358350
def test_var(x, data):
359-
axis = data.draw(axes(x.ndim), label="axis")
351+
axis = data.draw(hh.axes(x.ndim), label="axis")
360352
_axes = normalise_axis(axis, x.ndim)
361353
N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
362354
correction = data.draw(

array_api_tests/test_utility_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .test_statistical_functions import (
1111
assert_equals,
1212
assert_keepdimable_shape,
13-
axes,
1413
axes_ndindex,
1514
normalise_axis,
1615
)
@@ -21,7 +20,7 @@
2120
data=st.data(),
2221
)
2322
def test_all(x, data):
24-
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
23+
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
2524

2625
out = xp.all(x, **kw)
2726

@@ -46,7 +45,7 @@ def test_all(x, data):
4645
data=st.data(),
4746
)
4847
def test_any(x, data):
49-
kw = data.draw(hh.kwargs(axis=axes(x.ndim), keepdims=st.booleans()), label="kw")
48+
kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
5049

5150
out = xp.any(x, **kw)
5251

0 commit comments

Comments
 (0)