Skip to content

Commit e01073d

Browse files
committed
Docstrings for default dtype utils
1 parent c6d59e6 commit e01073d

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def assert_dtype(
8989
repr_name: str = "out.dtype",
9090
):
9191
"""
92-
Tests the output dtype is as expected.
92+
Assert the output dtype is as expected.
9393
9494
We infer the expected dtype from in_dtype and to test out_dtype, e.g.
9595
@@ -128,7 +128,7 @@ def assert_dtype(
128128

129129
def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
130130
"""
131-
Test the output dtype is the passed keyword dtype, e.g.
131+
Assert the output dtype is the passed keyword dtype, e.g.
132132
133133
>>> kw = {'dtype': xp.uint8}
134134
>>> out = xp.ones(5, **kw)
@@ -144,33 +144,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
144144
assert out_dtype == kw_dtype, msg
145145

146146

147-
def assert_default_float(func_name: str, dtype: DataType):
148-
f_dtype = dh.dtype_to_name[dtype]
147+
def assert_default_float(func_name: str, out_dtype: DataType):
148+
"""
149+
Assert the output dtype is the default float, e.g.
150+
151+
>>> out = xp.ones(5)
152+
>>> assert_default_float('ones', out.dtype)
153+
154+
"""
155+
f_dtype = dh.dtype_to_name[out_dtype]
149156
f_default = dh.dtype_to_name[dh.default_float]
150157
msg = (
151158
f"out.dtype={f_dtype}, should be default "
152159
f"floating-point dtype {f_default} [{func_name}()]"
153160
)
154-
assert dtype == dh.default_float, msg
161+
assert out_dtype == dh.default_float, msg
162+
163+
164+
def assert_default_int(func_name: str, out_dtype: DataType):
165+
"""
166+
Assert the output dtype is the default int, e.g.
155167
168+
>>> out = xp.full(5, 42)
169+
>>> assert_default_int('full', out.dtype)
156170
157-
def assert_default_int(func_name: str, dtype: DataType):
158-
f_dtype = dh.dtype_to_name[dtype]
171+
"""
172+
f_dtype = dh.dtype_to_name[out_dtype]
159173
f_default = dh.dtype_to_name[dh.default_int]
160174
msg = (
161175
f"out.dtype={f_dtype}, should be default "
162176
f"integer dtype {f_default} [{func_name}()]"
163177
)
164-
assert dtype == dh.default_int, msg
178+
assert out_dtype == dh.default_int, msg
179+
165180

181+
def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dtype"):
182+
"""
183+
Assert the output dtype is the default index dtype, e.g.
184+
185+
>>> out = xp.argmax(<array>)
186+
>>> assert_default_int('argmax', out.dtype)
166187
167-
def assert_default_index(func_name: str, dtype: DataType, repr_name="out.dtype"):
168-
f_dtype = dh.dtype_to_name[dtype]
188+
"""
189+
f_dtype = dh.dtype_to_name[out_dtype]
169190
msg = (
170191
f"{repr_name}={f_dtype}, should be the default index dtype, "
171192
f"which is either int32 or int64 [{func_name}()]"
172193
)
173-
assert dtype in (xp.int32, xp.int64), msg
194+
assert out_dtype in (xp.int32, xp.int64), msg
174195

175196

176197
def assert_shape(

0 commit comments

Comments
 (0)