Skip to content

Commit a5fd48f

Browse files
committed
Move assertion helpers to pytest_helpers.py
1 parent 1273270 commit a5fd48f

7 files changed

+111
-108
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from .algos import broadcast_shapes
21
import math
32
from inspect import getfullargspec
43
from typing import Any, Dict, Optional, Tuple, Union
54

5+
from . import _array_module as xp
66
from . import array_helpers as ah
77
from . import dtype_helpers as dh
88
from . import function_stubs
9-
from .typing import Array, DataType, Scalar, Shape
9+
from .algos import broadcast_shapes
10+
from .typing import Array, DataType, Scalar, ScalarType, Shape
1011

1112
__all__ = [
1213
"raises",
@@ -17,8 +18,10 @@
1718
"assert_kw_dtype",
1819
"assert_default_float",
1920
"assert_default_int",
21+
"assert_default_index",
2022
"assert_shape",
2123
"assert_result_shape",
24+
"assert_keepdimable_shape",
2225
"assert_fill",
2326
]
2427

@@ -117,6 +120,15 @@ def assert_default_int(func_name: str, dtype: DataType):
117120
assert dtype == dh.default_int, msg
118121

119122

123+
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
124+
f_dtype = dh.dtype_to_name[dtype]
125+
msg = (
126+
f"{repr_name}={f_dtype}, should be the default index dtype, "
127+
f"which is either int32 or int64 [{func_name}()]"
128+
)
129+
assert dtype in (xp.int32, xp.int64), msg
130+
131+
120132
def assert_shape(
121133
func_name: str,
122134
out_shape: Union[int, Shape],
@@ -155,6 +167,57 @@ def assert_result_shape(
155167
assert out_shape == expected, msg
156168

157169

170+
def assert_keepdimable_shape(
171+
func_name: str,
172+
out_shape: Shape,
173+
in_shape: Shape,
174+
axes: Tuple[int, ...],
175+
keepdims: bool,
176+
/,
177+
**kw,
178+
):
179+
if keepdims:
180+
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
181+
else:
182+
shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
183+
assert_shape(func_name, out_shape, shape, **kw)
184+
185+
186+
def assert_0d_equals(
187+
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
188+
):
189+
msg = (
190+
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
191+
f"[{func_name}({fmt_kw(kw)})]"
192+
)
193+
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):
194+
assert xp.isnan(x_val), msg
195+
else:
196+
assert x_val == out_val, msg
197+
198+
199+
def assert_scalar_equals(
200+
func_name: str,
201+
type_: ScalarType,
202+
idx: Shape,
203+
out: Scalar,
204+
expected: Scalar,
205+
/,
206+
**kw,
207+
):
208+
out_repr = "out" if idx == () else f"out[{idx}]"
209+
f_func = f"{func_name}({fmt_kw(kw)})"
210+
if type_ is bool or type_ is int:
211+
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
212+
assert out == expected, msg
213+
elif math.isnan(expected):
214+
msg = f"{out_repr}={out}, should be {expected} [{f_func}]"
215+
assert math.isnan(out), msg
216+
else:
217+
msg = f"{out_repr}={out}, should be roughly {expected} [{f_func}]"
218+
assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg
219+
220+
158221
def assert_fill(
159222
func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw
160223
):

array_api_tests/test_manipulation_functions.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,6 @@ def assert_array_ndindex(
4444
assert out[out_idx] == x[x_idx], msg
4545

4646

47-
def assert_equals(
48-
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
49-
):
50-
msg = (
51-
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
52-
f"[{func_name}({ph.fmt_kw(kw)})]"
53-
)
54-
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):
55-
assert xp.isnan(x_val), msg
56-
else:
57-
assert x_val == out_val, msg
58-
59-
6047
@st.composite
6148
def concat_shapes(draw, shape, axis):
6249
shape = list(shape)
@@ -104,7 +91,7 @@ def test_concat(dtypes, kw, data):
10491
for x_num, x in enumerate(arrays, 1):
10592
for x_idx in sh.ndindex(x.shape):
10693
out_i = next(out_indices)
107-
assert_equals(
94+
ph.assert_0d_equals(
10895
"concat",
10996
f"x{x_num}[{x_idx}]",
11097
x[x_idx],
@@ -120,7 +107,7 @@ def test_concat(dtypes, kw, data):
120107
indexed_x = x[idx]
121108
for x_idx in sh.ndindex(indexed_x.shape):
122109
out_idx = next(out_indices)
123-
assert_equals(
110+
ph.assert_0d_equals(
124111
"concat",
125112
f"x{x_num}[{f_idx}][{x_idx}]",
126113
indexed_x[x_idx],
@@ -360,7 +347,7 @@ def test_stack(shape, dtypes, kw, data):
360347
indexed_x = x[idx]
361348
for x_idx in sh.ndindex(indexed_x.shape):
362349
out_idx = next(out_indices)
363-
assert_equals(
350+
ph.assert_0d_equals(
364351
"stack",
365352
f"x{x_num}[{f_idx}][{x_idx}]",
366353
indexed_x[x_idx],

array_api_tests/test_searching_functions.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,6 @@
88
from . import shape_helpers as sh
99
from . import xps
1010
from .algos import broadcast_shapes
11-
from .test_manipulation_functions import assert_equals as assert_equals_
12-
from .test_statistical_functions import assert_equals, assert_keepdimable_shape
13-
from .typing import DataType
14-
15-
16-
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
17-
f_dtype = dh.dtype_to_name[dtype]
18-
msg = (
19-
f"{repr_name}={f_dtype}, should be the default index dtype, "
20-
f"which is either int32 or int64 [{func_name}()]"
21-
)
22-
assert dtype in (xp.int32, xp.int64), msg
2311

2412

2513
@given(
@@ -41,9 +29,9 @@ def test_argmax(x, data):
4129

4230
out = xp.argmax(x, **kw)
4331

44-
assert_default_index("argmax", out.dtype)
32+
ph.assert_default_index("argmax", out.dtype)
4533
axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
46-
assert_keepdimable_shape(
34+
ph.assert_keepdimable_shape(
4735
"argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
4836
)
4937
scalar_type = dh.get_scalar_type(x.dtype)
@@ -54,7 +42,7 @@ def test_argmax(x, data):
5442
s = scalar_type(x[idx])
5543
elements.append(s)
5644
expected = max(range(len(elements)), key=elements.__getitem__)
57-
assert_equals("argmax", int, out_idx, max_i, expected)
45+
ph.assert_scalar_equals("argmax", int, out_idx, max_i, expected)
5846

5947

6048
@given(
@@ -76,9 +64,9 @@ def test_argmin(x, data):
7664

7765
out = xp.argmin(x, **kw)
7866

79-
assert_default_index("argmin", out.dtype)
67+
ph.assert_default_index("argmin", out.dtype)
8068
axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
81-
assert_keepdimable_shape(
69+
ph.assert_keepdimable_shape(
8270
"argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
8371
)
8472
scalar_type = dh.get_scalar_type(x.dtype)
@@ -89,7 +77,7 @@ def test_argmin(x, data):
8977
s = scalar_type(x[idx])
9078
elements.append(s)
9179
expected = min(range(len(elements)), key=elements.__getitem__)
92-
assert_equals("argmin", int, out_idx, min_i, expected)
80+
ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected)
9381

9482

9583
# TODO: skip if opted out
@@ -106,7 +94,7 @@ def test_nonzero(x):
10694
assert (
10795
out[i].size == size
10896
), f"out[{i}].size={x.size}, but should be out[0].size={size}"
109-
assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
97+
ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
11098
indices = []
11199
if x.dtype == xp.bool:
112100
for idx in sh.ndindex(x.shape):
@@ -151,6 +139,10 @@ def test_where(shapes, dtypes, data):
151139
_x2 = xp.broadcast_to(x2, shape)
152140
for idx in sh.ndindex(shape):
153141
if _cond[idx]:
154-
assert_equals_("where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx])
142+
ph.assert_0d_equals(
143+
"where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx]
144+
)
155145
else:
156-
assert_equals_("where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx])
146+
ph.assert_0d_equals(
147+
"where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx]
148+
)

array_api_tests/test_set_functions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from . import pytest_helpers as ph
1111
from . import shape_helpers as sh
1212
from . import xps
13-
from .test_searching_functions import assert_default_index
1413

1514

1615
@given(
@@ -29,11 +28,15 @@ def test_unique_all(x):
2928
ph.assert_dtype(
3029
"unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype"
3130
)
32-
assert_default_index("unique_all", out.indices.dtype, repr_name="out.indices.dtype")
33-
assert_default_index(
31+
ph.assert_default_index(
32+
"unique_all", out.indices.dtype, repr_name="out.indices.dtype"
33+
)
34+
ph.assert_default_index(
3435
"unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype"
3536
)
36-
assert_default_index("unique_all", out.counts.dtype, repr_name="out.counts.dtype")
37+
ph.assert_default_index(
38+
"unique_all", out.counts.dtype, repr_name="out.counts.dtype"
39+
)
3740

3841
assert (
3942
out.indices.shape == out.values.shape
@@ -121,7 +124,7 @@ def test_unique_counts(x):
121124
ph.assert_dtype(
122125
"unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype"
123126
)
124-
assert_default_index(
127+
ph.assert_default_index(
125128
"unique_counts", out.counts.dtype, repr_name="out.counts.dtype"
126129
)
127130
assert (
@@ -168,7 +171,7 @@ def test_unique_inverse(x):
168171
ph.assert_dtype(
169172
"unique_inverse", x.dtype, out.values.dtype, repr_name="out.values.dtype"
170173
)
171-
assert_default_index(
174+
ph.assert_default_index(
172175
"unique_inverse",
173176
out.inverse_indices.dtype,
174177
repr_name="out.inverse_indices.dtype",

array_api_tests/test_sorting.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
from . import pytest_helpers as ph
99
from . import shape_helpers as sh
1010
from . import xps
11-
from .test_manipulation_functions import assert_equals as assert_equals_
12-
from .test_searching_functions import assert_default_index
13-
from .test_statistical_functions import assert_equals
1411

1512

1613
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -37,7 +34,7 @@ def test_argsort(x, data):
3734

3835
out = xp.argsort(x, **kw)
3936

40-
assert_default_index("sort", out.dtype)
37+
ph.assert_default_index("sort", out.dtype)
4138
ph.assert_shape("sort", out.shape, x.shape, **kw)
4239
axis = kw.get("axis", -1)
4340
axes = sh.normalise_axis(axis, x.ndim)
@@ -50,7 +47,7 @@ def test_argsort(x, data):
5047
# sorted(..., reverse=descending) doesn't always work
5148
indices_order = reversed(indices_order)
5249
for idx, o in zip(indices, indices_order):
53-
assert_equals("argsort", int, idx, int(out[idx]), o)
50+
ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o)
5451

5552

5653
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -90,7 +87,7 @@ def test_sort(x, data):
9087
)
9188
x_indices = [indices[o] for o in indices_order]
9289
for out_idx, x_idx in zip(indices, x_indices):
93-
assert_equals_(
90+
ph.assert_0d_equals(
9491
"sort",
9592
f"x[{x_idx}]",
9693
x[x_idx],

0 commit comments

Comments
 (0)