Skip to content

Commit ce2e2c0

Browse files
committed
Hard-code array scalar casting input dtypes for dh.func_in_dtypes
1 parent 6071f44 commit ce2e2c0

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,13 @@ def result_type(*dtypes: DataType):
406406
func_returns_bool[iop] = func_returns_bool[op]
407407

408408

409+
func_in_dtypes["__bool__"] = (xp.bool,)
410+
func_in_dtypes["__int__"] = all_int_dtypes
411+
func_in_dtypes["__index__"] = all_int_dtypes
412+
func_in_dtypes["__float__"] = float_dtypes
413+
func_in_dtypes["__dlpack__"] = numeric_dtypes
414+
415+
409416
@lru_cache
410417
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
411418
f_types = []

0 commit comments

Comments
 (0)