@@ -88,6 +88,31 @@ def assert_dtype(
88
88
* ,
89
89
repr_name : str = "out.dtype" ,
90
90
):
91
+ """
92
+ Tests the output dtype is as expected.
93
+
94
+ We infer the expected dtype from in_dtype and to test out_dtype, e.g.
95
+
96
+ >>> x = xp.arange(5, dtype=xp.uint8)
97
+ >>> out = xp.abs(x)
98
+ >>> assert_dtype('abs', x.dtype, out.dtype)
99
+
100
+ Or for multiple input dtypes, the expected dtype is inferred from their
101
+ resulting type promotion, e.g.
102
+
103
+ >>> x1 = xp.arange(5, dtype=xp.uint8)
104
+ >>> x2 = xp.arange(5, dtype=xp.uint16)
105
+ >>> out = xp.add(x1, x2)
106
+ >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) # expected=xp.uint16
107
+
108
+ We can also specify the expected dtype ourselves, e.g.
109
+
110
+ >>> x = xp.arange(5, dtype=xp.int8)
111
+ >>> out = xp.sum(x)
112
+ >>> default_int = xp.asarray(0).dtype
113
+ >>> assert_dtype('sum', x, out.dtype, default_int)
114
+
115
+ """
91
116
in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) else [in_dtype ]
92
117
f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
93
118
f_out_dtype = dh .dtype_to_name [out_dtype ]
0 commit comments