Skip to content

Commit 095be2f

Browse files
committed
Require the repeats array to have an integer dtype
NumPy allows it to be bool (casting it to int).
1 parent 9938059 commit 095be2f

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

array_api_strict/_manipulation_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ._array_object import Array
44
from ._creation_functions import asarray
55
from ._data_type_functions import result_type
6+
from ._dtypes import _integer_dtypes
67
from ._flags import requires_api_version, get_array_api_strict_flags
78

89
from typing import TYPE_CHECKING
@@ -86,6 +87,8 @@ def repeat(
8687
data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes']
8788
if not data_dependent_shapes:
8889
raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
90+
if repeats.dtype not in _integer_dtypes:
91+
raise TypeError("The repeats array must have an integer dtype")
8992
elif isinstance(repeats, int):
9093
repeats = asarray(repeats)
9194
else:

0 commit comments

Comments
 (0)