@@ -91,19 +91,28 @@ def assert_dtype(
91
91
"""
92
92
Assert the output dtype is as expected.
93
93
94
- We infer the expected dtype from in_dtype and to test out_dtype, e.g.
94
+ If expected=None, we infer the expected dtype as in_dtype, to test
95
+ out_dtype, e.g.
95
96
96
97
>>> x = xp.arange(5, dtype=xp.uint8)
97
98
>>> out = xp.abs(x)
98
99
>>> assert_dtype('abs', x.dtype, out.dtype)
99
100
101
+ is equivalent to
102
+
103
+ >>> assert out.dtype == xp.uint8
104
+
100
105
Or for multiple input dtypes, the expected dtype is inferred from their
101
106
resulting type promotion, e.g.
102
107
103
108
>>> x1 = xp.arange(5, dtype=xp.uint8)
104
109
>>> x2 = xp.arange(5, dtype=xp.uint16)
105
110
>>> out = xp.add(x1, x2)
106
- >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
111
+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype)
112
+
113
+ is equivalent to
114
+
115
+ >>> assert out.dtype == xp.uint16
107
116
108
117
We can also specify the expected dtype ourselves, e.g.
109
118
@@ -182,7 +191,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
182
191
"""
183
192
Assert the output dtype is the default index dtype, e.g.
184
193
185
- >>> out = xp.argmax(<array> )
194
+ >>> out = xp.argmax(xp.arange(5) )
186
195
>>> assert_default_int('argmax', out.dtype)
187
196
188
197
"""
@@ -202,6 +211,13 @@ def assert_shape(
202
211
repr_name = "out.shape" ,
203
212
** kw ,
204
213
):
214
+ """
215
+ Assert the output shape is as expected, e.g.
216
+
217
+ >>> out = xp.ones((3, 3, 3))
218
+ >>> assert_shape('ones', out.shape, (3, 3, 3))
219
+
220
+ """
205
221
if isinstance (out_shape , int ):
206
222
out_shape = (out_shape ,)
207
223
if isinstance (expected , int ):
@@ -222,6 +238,20 @@ def assert_result_shape(
222
238
repr_name = "out.shape" ,
223
239
** kw ,
224
240
):
241
+ """
242
+ Assert the output shape is as expected.
243
+
244
+ If expected=None, we infer the expected shape as the result of broadcasting
245
+ in_shapes, to test against out_shape, e.g.
246
+
247
+ >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
248
+ >>> assert_shape('add', [(3, 1), (1, 3)], out.shape)
249
+
250
+ is equivalent to
251
+
252
+ >>> assert out.shape == (3, 3)
253
+
254
+ """
225
255
if expected is None :
226
256
expected = sh .broadcast_shapes (* in_shapes )
227
257
f_in_shapes = " . " .join (str (s ) for s in in_shapes )
0 commit comments