@@ -66,14 +66,12 @@ def unary_assert_against_refimpl(
66
66
res : Array ,
67
67
refimpl : Callable [[Scalar ], Scalar ],
68
68
expr_template : str ,
69
- in_stype : Optional [ScalarType ] = None ,
70
69
res_stype : Optional [ScalarType ] = None ,
71
70
filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
72
71
):
73
72
if in_ .shape != res .shape :
74
73
raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
75
- if in_stype is None :
76
- in_stype = dh .get_scalar_type (in_ .dtype )
74
+ in_stype = dh .get_scalar_type (in_ .dtype )
77
75
if res_stype is None :
78
76
res_stype = in_stype
79
77
m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
@@ -109,15 +107,13 @@ def binary_assert_against_refimpl(
109
107
res : Array ,
110
108
refimpl : Callable [[Scalar , Scalar ], Scalar ],
111
109
expr_template : str ,
112
- in_stype : Optional [ScalarType ] = None ,
113
110
res_stype : Optional [ScalarType ] = None ,
114
111
left_sym : str = "x1" ,
115
112
right_sym : str = "x2" ,
116
113
res_name : str = "out" ,
117
114
filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
118
115
):
119
- if in_stype is None :
120
- in_stype = dh .get_scalar_type (left .dtype )
116
+ in_stype = dh .get_scalar_type (left .dtype )
121
117
if res_stype is None :
122
118
res_stype = in_stype
123
119
m , M = dh .dtype_ranges .get (res .dtype , (None , None ))
@@ -350,14 +346,12 @@ def binary_param_assert_against_refimpl(
350
346
res : Array ,
351
347
refimpl : Callable [[Scalar , Scalar ], Scalar ],
352
348
expr_template : str ,
353
- in_stype : Optional [ScalarType ] = None ,
354
349
res_stype : Optional [ScalarType ] = None ,
355
350
filter_ : Callable [[Scalar ], bool ] = math .isfinite ,
356
351
):
357
352
if ctx .right_is_scalar :
358
353
assert filter_ (right ) # sanity check
359
- if in_stype is None :
360
- in_stype = dh .get_scalar_type (left .dtype )
354
+ in_stype = dh .get_scalar_type (left .dtype )
361
355
if res_stype is None :
362
356
res_stype = in_stype
363
357
m , M = dh .dtype_ranges .get (left .dtype , (None , None ))
@@ -389,7 +383,6 @@ def binary_param_assert_against_refimpl(
389
383
else :
390
384
binary_assert_against_refimpl (
391
385
func_name = ctx .func_name ,
392
- in_stype = in_stype ,
393
386
left_sym = ctx .left_sym ,
394
387
left = left ,
395
388
right_sym = ctx .right_sym ,
0 commit comments