Skip to content

Commit 7386615

Browse files
committed
Remove redundant in_stype arg in refimpl utils
1 parent 5a82a33 commit 7386615

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,12 @@ def unary_assert_against_refimpl(
6666
res: Array,
6767
refimpl: Callable[[Scalar], Scalar],
6868
expr_template: str,
69-
in_stype: Optional[ScalarType] = None,
7069
res_stype: Optional[ScalarType] = None,
7170
filter_: Callable[[Scalar], bool] = math.isfinite,
7271
):
7372
if in_.shape != res.shape:
7473
raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
75-
if in_stype is None:
76-
in_stype = dh.get_scalar_type(in_.dtype)
74+
in_stype = dh.get_scalar_type(in_.dtype)
7775
if res_stype is None:
7876
res_stype = in_stype
7977
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
@@ -109,15 +107,13 @@ def binary_assert_against_refimpl(
109107
res: Array,
110108
refimpl: Callable[[Scalar, Scalar], Scalar],
111109
expr_template: str,
112-
in_stype: Optional[ScalarType] = None,
113110
res_stype: Optional[ScalarType] = None,
114111
left_sym: str = "x1",
115112
right_sym: str = "x2",
116113
res_name: str = "out",
117114
filter_: Callable[[Scalar], bool] = math.isfinite,
118115
):
119-
if in_stype is None:
120-
in_stype = dh.get_scalar_type(left.dtype)
116+
in_stype = dh.get_scalar_type(left.dtype)
121117
if res_stype is None:
122118
res_stype = in_stype
123119
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
@@ -350,14 +346,12 @@ def binary_param_assert_against_refimpl(
350346
res: Array,
351347
refimpl: Callable[[Scalar, Scalar], Scalar],
352348
expr_template: str,
353-
in_stype: Optional[ScalarType] = None,
354349
res_stype: Optional[ScalarType] = None,
355350
filter_: Callable[[Scalar], bool] = math.isfinite,
356351
):
357352
if ctx.right_is_scalar:
358353
assert filter_(right) # sanity check
359-
if in_stype is None:
360-
in_stype = dh.get_scalar_type(left.dtype)
354+
in_stype = dh.get_scalar_type(left.dtype)
361355
if res_stype is None:
362356
res_stype = in_stype
363357
m, M = dh.dtype_ranges.get(left.dtype, (None, None))
@@ -389,7 +383,6 @@ def binary_param_assert_against_refimpl(
389383
else:
390384
binary_assert_against_refimpl(
391385
func_name=ctx.func_name,
392-
in_stype=in_stype,
393386
left_sym=ctx.left_sym,
394387
left=left,
395388
right_sym=ctx.right_sym,

0 commit comments

Comments
 (0)