@@ -89,7 +89,7 @@ def assert_dtype(
89
89
repr_name : str = "out.dtype" ,
90
90
):
91
91
"""
92
- Tests the output dtype is as expected.
92
+ Assert the output dtype is as expected.
93
93
94
94
We infer the expected dtype from in_dtype and to test out_dtype, e.g.
95
95
@@ -128,7 +128,7 @@ def assert_dtype(
128
128
129
129
def assert_kw_dtype (func_name : str , kw_dtype : DataType , out_dtype : DataType ):
130
130
"""
131
- Test the output dtype is the passed keyword dtype, e.g.
131
+ Assert the output dtype is the passed keyword dtype, e.g.
132
132
133
133
>>> kw = {'dtype': xp.uint8}
134
134
>>> out = xp.ones(5, **kw)
@@ -144,33 +144,54 @@ def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType):
144
144
assert out_dtype == kw_dtype , msg
145
145
146
146
147
- def assert_default_float (func_name : str , dtype : DataType ):
148
- f_dtype = dh .dtype_to_name [dtype ]
147
+ def assert_default_float (func_name : str , out_dtype : DataType ):
148
+ """
149
+ Assert the output dtype is the default float, e.g.
150
+
151
+ >>> out = xp.ones(5)
152
+ >>> assert_default_float('ones', out.dtype)
153
+
154
+ """
155
+ f_dtype = dh .dtype_to_name [out_dtype ]
149
156
f_default = dh .dtype_to_name [dh .default_float ]
150
157
msg = (
151
158
f"out.dtype={ f_dtype } , should be default "
152
159
f"floating-point dtype { f_default } [{ func_name } ()]"
153
160
)
154
- assert dtype == dh .default_float , msg
161
+ assert out_dtype == dh .default_float , msg
162
+
163
+
164
+ def assert_default_int (func_name : str , out_dtype : DataType ):
165
+ """
166
+ Assert the output dtype is the default int, e.g.
155
167
168
+ >>> out = xp.full(5, 42)
169
+ >>> assert_default_int('full', out.dtype)
156
170
157
- def assert_default_int ( func_name : str , dtype : DataType ):
158
- f_dtype = dh .dtype_to_name [dtype ]
171
+ """
172
+ f_dtype = dh .dtype_to_name [out_dtype ]
159
173
f_default = dh .dtype_to_name [dh .default_int ]
160
174
msg = (
161
175
f"out.dtype={ f_dtype } , should be default "
162
176
f"integer dtype { f_default } [{ func_name } ()]"
163
177
)
164
- assert dtype == dh .default_int , msg
178
+ assert out_dtype == dh .default_int , msg
179
+
165
180
181
+ def assert_default_index (func_name : str , out_dtype : DataType , repr_name = "out.dtype" ):
182
+ """
183
+ Assert the output dtype is the default index dtype, e.g.
184
+
185
+ >>> out = xp.argmax(<array>)
186
+ >>> assert_default_int('argmax', out.dtype)
166
187
167
- def assert_default_index ( func_name : str , dtype : DataType , repr_name = "out.dtype" ):
168
- f_dtype = dh .dtype_to_name [dtype ]
188
+ """
189
+ f_dtype = dh .dtype_to_name [out_dtype ]
169
190
msg = (
170
191
f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
171
192
f"which is either int32 or int64 [{ func_name } ()]"
172
193
)
173
- assert dtype in (xp .int32 , xp .int64 ), msg
194
+ assert out_dtype in (xp .int32 , xp .int64 ), msg
174
195
175
196
176
197
def assert_shape (
0 commit comments