Skip to content

Commit bfa3961

Browse files
committed
Docstring for ph.assert_dtype()
1 parent 37be0aa commit bfa3961

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,31 @@ def assert_dtype(
8888
*,
8989
repr_name: str = "out.dtype",
9090
):
91+
"""
92+
Tests the output dtype is as expected.
93+
94+
We infer the expected dtype from in_dtype and to test out_dtype, e.g.
95+
96+
>>> x = xp.arange(5, dtype=xp.uint8)
97+
>>> out = xp.abs(x)
98+
>>> assert_dtype('abs', x.dtype, out.dtype)
99+
100+
Or for multiple input dtypes, the expected dtype is inferred from their
101+
resulting type promotion, e.g.
102+
103+
>>> x1 = xp.arange(5, dtype=xp.uint8)
104+
>>> x2 = xp.arange(5, dtype=xp.uint16)
105+
>>> out = xp.add(x1, x2)
106+
>>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
107+
108+
We can also specify the expected dtype ourselves, e.g.
109+
110+
>>> x = xp.arange(5, dtype=xp.int8)
111+
>>> out = xp.sum(x)
112+
>>> default_int = xp.asarray(0).dtype
113+
>>> assert_dtype('sum', x, out.dtype, default_int)
114+
115+
"""
91116
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
92117
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
93118
f_out_dtype = dh.dtype_to_name[out_dtype]

0 commit comments

Comments
 (0)