Skip to content

Commit 5a82a33

Browse files
committed
Fix typing issues with refimpl utils
1 parent 4d849f1 commit 5a82a33

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ def unary_assert_against_refimpl(
7676
in_stype = dh.get_scalar_type(in_.dtype)
7777
if res_stype is None:
7878
res_stype = in_stype
79-
if res.dtype != xp.bool:
80-
m, M = dh.dtype_ranges[res.dtype]
79+
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
8180
for idx in sh.ndindex(in_.shape):
8281
scalar_i = in_stype(in_[idx])
8382
if not filter_(scalar_i):
8483
continue
8584
expected = refimpl(scalar_i)
8685
if res.dtype != xp.bool:
86+
assert m is not None and M is not None # for mypy
8787
if expected <= m or expected >= M:
8888
continue
8989
scalar_o = res_stype(res[idx])
@@ -105,7 +105,7 @@ def unary_assert_against_refimpl(
105105
def binary_assert_against_refimpl(
106106
func_name: str,
107107
left: Array,
108-
right: Union[Scalar, Array],
108+
right: Array,
109109
res: Array,
110110
refimpl: Callable[[Scalar, Scalar], Scalar],
111111
expr_template: str,
@@ -120,24 +120,23 @@ def binary_assert_against_refimpl(
120120
in_stype = dh.get_scalar_type(left.dtype)
121121
if res_stype is None:
122122
res_stype = in_stype
123-
result_dtype = dh.result_type(left.dtype, right.dtype)
124-
if result_dtype != xp.bool:
125-
m, M = dh.dtype_ranges[result_dtype]
123+
m, M = dh.dtype_ranges.get(res.dtype, (None, None))
126124
for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
127125
scalar_l = in_stype(left[l_idx])
128126
scalar_r = in_stype(right[r_idx])
129127
if not (filter_(scalar_l) and filter_(scalar_r)):
130128
continue
131129
expected = refimpl(scalar_l, scalar_r)
132-
if result_dtype != xp.bool:
130+
if res.dtype != xp.bool:
131+
assert m is not None and M is not None # for mypy
133132
if expected <= m or expected >= M:
134133
continue
135134
scalar_o = res_stype(res[o_idx])
136135
f_l = sh.fmt_idx(left_sym, l_idx)
137136
f_r = sh.fmt_idx(right_sym, r_idx)
138137
f_o = sh.fmt_idx(res_name, o_idx)
139138
expr = expr_template.format(f_l, f_r, expected)
140-
if dh.is_float_dtype(result_dtype):
139+
if dh.is_float_dtype(res.dtype):
141140
assert isclose(scalar_o, expected), (
142141
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
143142
f"{f_l}={scalar_l}, {f_r}={scalar_r}"
@@ -357,18 +356,18 @@ def binary_param_assert_against_refimpl(
357356
):
358357
if ctx.right_is_scalar:
359358
assert filter_(right) # sanity check
360-
if left.dtype != xp.bool:
361-
m, M = dh.dtype_ranges[left.dtype]
362359
if in_stype is None:
363360
in_stype = dh.get_scalar_type(left.dtype)
364361
if res_stype is None:
365362
res_stype = in_stype
363+
m, M = dh.dtype_ranges.get(left.dtype, (None, None))
366364
for idx in sh.ndindex(res.shape):
367365
scalar_l = in_stype(left[idx])
368366
if not filter_(scalar_l):
369367
continue
370368
expected = refimpl(scalar_l, right)
371369
if left.dtype != xp.bool:
370+
assert m is not None and M is not None # for mypy
372371
if expected <= m or expected >= M:
373372
continue
374373
scalar_o = res_stype(res[idx])

0 commit comments

Comments
 (0)