|
1 | 1 | import math
|
2 | 2 | import operator
|
3 | 3 | from enum import Enum, auto
|
4 |
| -from typing import Callable, List, NamedTuple, Optional, Union |
| 4 | +from typing import Callable, List, NamedTuple, Optional, TypeVar, Union |
5 | 5 |
|
6 | 6 | import pytest
|
7 | 7 | from hypothesis import assume, given
|
@@ -85,11 +85,14 @@ def default_filter(s: Scalar) -> bool:
|
85 | 85 | return math.isfinite(s) and s is not -0.0 and s is not +0.0
|
86 | 86 |
|
87 | 87 |
|
| 88 | +T = TypeVar("T") |
| 89 | + |
| 90 | + |
88 | 91 | def unary_assert_against_refimpl(
|
89 | 92 | func_name: str,
|
90 | 93 | in_: Array,
|
91 | 94 | res: Array,
|
92 |
| - refimpl: Callable[[Scalar], Scalar], |
| 95 | + refimpl: Callable[[T], T], |
93 | 96 | expr_template: Optional[str] = None,
|
94 | 97 | res_stype: Optional[ScalarType] = None,
|
95 | 98 | filter_: Callable[[Scalar], bool] = default_filter,
|
@@ -136,7 +139,7 @@ def binary_assert_against_refimpl(
|
136 | 139 | left: Array,
|
137 | 140 | right: Array,
|
138 | 141 | res: Array,
|
139 |
| - refimpl: Callable[[Scalar, Scalar], Scalar], |
| 142 | + refimpl: Callable[[T, T], T], |
140 | 143 | expr_template: Optional[str] = None,
|
141 | 144 | res_stype: Optional[ScalarType] = None,
|
142 | 145 | left_sym: str = "x1",
|
@@ -382,7 +385,7 @@ def binary_param_assert_against_refimpl(
|
382 | 385 | right: Union[Array, Scalar],
|
383 | 386 | res: Array,
|
384 | 387 | op_sym: str,
|
385 |
| - refimpl: Callable[[Scalar, Scalar], Scalar], |
| 388 | + refimpl: Callable[[T, T], T], |
386 | 389 | res_stype: Optional[ScalarType] = None,
|
387 | 390 | filter_: Callable[[Scalar], bool] = default_filter,
|
388 | 391 | strict_check: Optional[bool] = None,
|
@@ -456,7 +459,7 @@ def test_abs(ctx, data):
|
456 | 459 | ctx.func_name,
|
457 | 460 | x,
|
458 | 461 | out,
|
459 |
| - abs, |
| 462 | + abs, # type: ignore |
460 | 463 | expr_template="abs({})={}",
|
461 | 464 | filter_=lambda s: (
|
462 | 465 | s == float("infinity") or (math.isfinite(s) and s is not -0.0)
|
@@ -1013,7 +1016,7 @@ def test_negative(ctx, data):
|
1013 | 1016 | ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
|
1014 | 1017 | ph.assert_shape(ctx.func_name, out.shape, x.shape)
|
1015 | 1018 | unary_assert_against_refimpl(
|
1016 |
| - ctx.func_name, x, out, operator.neg, expr_template="-({})={}" |
| 1019 | + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore |
1017 | 1020 | )
|
1018 | 1021 |
|
1019 | 1022 |
|
|
0 commit comments