Skip to content

Commit d8c25e3

Browse files
committed
Docs for shape assertion utils
1 parent e01073d commit d8c25e3

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,28 @@ def assert_dtype(
9191
"""
9292
Assert the output dtype is as expected.
9393
94-
We infer the expected dtype from in_dtype and to test out_dtype, e.g.
94+
If expected=None, we infer the expected dtype as in_dtype, to test
95+
out_dtype, e.g.
9596
9697
>>> x = xp.arange(5, dtype=xp.uint8)
9798
>>> out = xp.abs(x)
9899
>>> assert_dtype('abs', x.dtype, out.dtype)
99100
101+
is equivalent to
102+
103+
>>> assert out.dtype == xp.uint8
104+
100105
Or for multiple input dtypes, the expected dtype is inferred from their
101106
resulting type promotion, e.g.
102107
103108
>>> x1 = xp.arange(5, dtype=xp.uint8)
104109
>>> x2 = xp.arange(5, dtype=xp.uint16)
105110
>>> out = xp.add(x1, x2)
106-
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
111+
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
112+
113+
is equivalent to
114+
115+
>>> assert out.dtype == xp.uint16
107116
108117
We can also specify the expected dtype ourselves, e.g.
109118
@@ -182,7 +191,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
182191
"""
183192
Assert the output dtype is the default index dtype, e.g.
184193
185-
>>> out = xp.argmax(<array>)
194+
>>> out = xp.argmax(xp.arange(5))
186195
>>> assert_default_int('argmax', out.dtype)
187196
188197
"""
@@ -202,6 +211,13 @@ def assert_shape(
202211
repr_name="out.shape",
203212
**kw,
204213
):
214+
"""
215+
Assert the output shape is as expected, e.g.
216+
217+
>>> out = xp.ones((3, 3, 3))
218+
>>> assert_shape('ones', out.shape, (3, 3, 3))
219+
220+
"""
205221
if isinstance(out_shape, int):
206222
out_shape = (out_shape,)
207223
if isinstance(expected, int):
@@ -222,6 +238,20 @@ def assert_result_shape(
222238
repr_name="out.shape",
223239
**kw,
224240
):
241+
"""
242+
Assert the output shape is as expected.
243+
244+
If expected=None, we infer the expected shape as the result of broadcasting
245+
in_shapes, to test against out_shape, e.g.
246+
247+
>>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
248+
>>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
249+
250+
is equivalent to
251+
252+
>>> assert out.shape == (3, 3)
253+
254+
"""
225255
if expected is None:
226256
expected = sh.broadcast_shapes(*in_shapes)
227257
f_in_shapes = " . ".join(str(s) for s in in_shapes)

0 commit comments

Comments
 (0)