Skip to content

Commit 493f669

Browse files
committed
Generic type hint for refimpl args
1 parent 6e8cda6 commit 493f669

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import operator
33
from enum import Enum, auto
4-
from typing import Callable, List, NamedTuple, Optional, Union
4+
from typing import Callable, List, NamedTuple, Optional, TypeVar, Union
55

66
import pytest
77
from hypothesis import assume, given
@@ -85,11 +85,14 @@ def default_filter(s: Scalar) -> bool:
8585
return math.isfinite(s) and s is not -0.0 and s is not +0.0
8686

8787

88+
T = TypeVar("T")
89+
90+
8891
def unary_assert_against_refimpl(
8992
func_name: str,
9093
in_: Array,
9194
res: Array,
92-
refimpl: Callable[[Scalar], Scalar],
95+
refimpl: Callable[[T], T],
9396
expr_template: Optional[str] = None,
9497
res_stype: Optional[ScalarType] = None,
9598
filter_: Callable[[Scalar], bool] = default_filter,
@@ -136,7 +139,7 @@ def binary_assert_against_refimpl(
136139
left: Array,
137140
right: Array,
138141
res: Array,
139-
refimpl: Callable[[Scalar, Scalar], Scalar],
142+
refimpl: Callable[[T, T], T],
140143
expr_template: Optional[str] = None,
141144
res_stype: Optional[ScalarType] = None,
142145
left_sym: str = "x1",
@@ -382,7 +385,7 @@ def binary_param_assert_against_refimpl(
382385
right: Union[Array, Scalar],
383386
res: Array,
384387
op_sym: str,
385-
refimpl: Callable[[Scalar, Scalar], Scalar],
388+
refimpl: Callable[[T, T], T],
386389
res_stype: Optional[ScalarType] = None,
387390
filter_: Callable[[Scalar], bool] = default_filter,
388391
strict_check: Optional[bool] = None,
@@ -456,7 +459,7 @@ def test_abs(ctx, data):
456459
ctx.func_name,
457460
x,
458461
out,
459-
abs,
462+
abs, # type: ignore
460463
expr_template="abs({})={}",
461464
filter_=lambda s: (
462465
s == float("infinity") or (math.isfinite(s) and s is not -0.0)
@@ -1013,7 +1016,7 @@ def test_negative(ctx, data):
10131016
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
10141017
ph.assert_shape(ctx.func_name, out.shape, x.shape)
10151018
unary_assert_against_refimpl(
1016-
ctx.func_name, x, out, operator.neg, expr_template="-({})={}"
1019+
ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore
10171020
)
10181021

10191022

0 commit comments

Comments
 (0)