@@ -76,14 +76,14 @@ def unary_assert_against_refimpl(
76
76
in_stype = dh .get_scalar_type (in_ .dtype )
77
77
if res_stype is None :
78
78
res_stype = in_stype
79
- if res .dtype != xp .bool :
80
- m , M = dh .dtype_ranges [res .dtype ]
79
+ m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
81
80
for idx in sh .ndindex (in_ .shape ):
82
81
scalar_i = in_stype (in_ [idx ])
83
82
if not filter_ (scalar_i ):
84
83
continue
85
84
expected = refimpl (scalar_i )
86
85
if res .dtype != xp .bool :
86
+ assert m is not None and M is not None # for mypy
87
87
if expected <= m or expected >= M :
88
88
continue
89
89
scalar_o = res_stype (res [idx ])
@@ -105,7 +105,7 @@ def unary_assert_against_refimpl(
105
105
def binary_assert_against_refimpl (
106
106
func_name : str ,
107
107
left : Array ,
108
- right : Union [ Scalar , Array ] ,
108
+ right : Array ,
109
109
res : Array ,
110
110
refimpl : Callable [[Scalar , Scalar ], Scalar ],
111
111
expr_template : str ,
@@ -120,24 +120,23 @@ def binary_assert_against_refimpl(
120
120
in_stype = dh .get_scalar_type (left .dtype )
121
121
if res_stype is None :
122
122
res_stype = in_stype
123
- result_dtype = dh .result_type (left .dtype , right .dtype )
124
- if result_dtype != xp .bool :
125
- m , M = dh .dtype_ranges [result_dtype ]
123
+ m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
126
124
for l_idx , r_idx , o_idx in sh .iter_indices (left .shape , right .shape , res .shape ):
127
125
scalar_l = in_stype (left [l_idx ])
128
126
scalar_r = in_stype (right [r_idx ])
129
127
if not (filter_ (scalar_l ) and filter_ (scalar_r )):
130
128
continue
131
129
expected = refimpl (scalar_l , scalar_r )
132
- if result_dtype != xp .bool :
130
+ if res .dtype != xp .bool :
131
+ assert m is not None and M is not None # for mypy
133
132
if expected <= m or expected >= M :
134
133
continue
135
134
scalar_o = res_stype (res [o_idx ])
136
135
f_l = sh .fmt_idx (left_sym , l_idx )
137
136
f_r = sh .fmt_idx (right_sym , r_idx )
138
137
f_o = sh .fmt_idx (res_name , o_idx )
139
138
expr = expr_template .format (f_l , f_r , expected )
140
- if dh .is_float_dtype (result_dtype ):
139
+ if dh .is_float_dtype (res . dtype ):
141
140
assert isclose (scalar_o , expected ), (
142
141
f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
143
142
f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -357,18 +356,18 @@ def binary_param_assert_against_refimpl(
357
356
):
358
357
if ctx .right_is_scalar :
359
358
assert filter_ (right ) # sanity check
360
- if left .dtype != xp .bool :
361
- m , M = dh .dtype_ranges [left .dtype ]
362
359
if in_stype is None :
363
360
in_stype = dh .get_scalar_type (left .dtype )
364
361
if res_stype is None :
365
362
res_stype = in_stype
363
+ m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
366
364
for idx in sh .ndindex (res .shape ):
367
365
scalar_l = in_stype (left [idx ])
368
366
if not filter_ (scalar_l ):
369
367
continue
370
368
expected = refimpl (scalar_l , right )
371
369
if left .dtype != xp .bool :
370
+ assert m is not None and M is not None # for mypy
372
371
if expected <= m or expected >= M :
373
372
continue
374
373
scalar_o = res_stype (res [idx ])
0 commit comments