Skip to content

Commit 9938059

Browse files
committed
Add repeat()
1 parent eb063e2 commit 9938059

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,14 @@
277277
flip,
278278
moveaxis,
279279
permute_dims,
280+
repeat,
280281
reshape,
281282
roll,
282283
squeeze,
283284
stack,
284285
)
285286

286-
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "reshape", "roll", "squeeze", "stack"]
287+
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack"]
287288

288289
from ._searching_functions import argmax, argmin, nonzero, where
289290

array_api_strict/_flags.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def set_array_api_strict_flags(
8181
The functions that make use of data-dependent shapes, and are therefore
8282
disabled by setting this flag to False are
8383
84-
- `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`.
85-
- `nonzero`
84+
- `unique_all()`, `unique_counts()`, `unique_inverse()`, and `unique_values()`.
85+
- `nonzero()`
8686
- Boolean array indexing
87-
- `repeat` when the `repeats` argument is an array (requires 2023.12
87+
- `repeat()` when the `repeats` argument is an array (requires 2023.12
8888
version of the standard)
8989
9090
See

array_api_strict/_manipulation_functions.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from ._array_object import Array
4+
from ._creation_functions import asarray
45
from ._data_type_functions import result_type
5-
from ._flags import requires_api_version
6+
from ._flags import requires_api_version, get_array_api_strict_flags
67

78
from typing import TYPE_CHECKING
89

@@ -58,7 +59,6 @@ def moveaxis(
5859
"""
5960
return Array._new(np.moveaxis(x._array, source, destination))
6061

61-
6262
# Note: The function name is different here (see also matrix_transpose).
6363
# Unlike transpose(), the axes argument is required.
6464
def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
@@ -69,6 +69,29 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
6969
"""
7070
return Array._new(np.transpose(x._array, axes))
7171

72+
@requires_api_version('2023.12')
73+
def repeat(
74+
x: Array,
75+
repeats: Union[int, Array],
76+
/,
77+
*,
78+
axis: Optional[int] = None,
79+
) -> Array:
80+
"""
81+
Array API compatible wrapper for :py:func:`np.repeat <numpy.repeat>`.
82+
83+
See its docstring for more information.
84+
"""
85+
if isinstance(repeats, Array):
86+
data_dependent_shapes = get_array_api_strict_flags()['data_dependent_shapes']
87+
if not data_dependent_shapes:
88+
raise RuntimeError("repeat() with repeats as an array requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")
89+
elif isinstance(repeats, int):
90+
repeats = asarray(repeats)
91+
else:
92+
raise TypeError("repeats must be an int or array")
93+
94+
return Array._new(np.repeat(x._array, repeats, axis=axis))
7295

7396
# Note: the optional argument is called 'shape', not 'newshape'
7497
def reshape(x: Array,

array_api_strict/tests/test_flags.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
reset_array_api_strict_flags)
33

44
from .. import (asarray, unique_all, unique_counts, unique_inverse,
5-
unique_values, nonzero)
5+
unique_values, nonzero, repeat)
66

77
import array_api_strict as xp
88

@@ -102,8 +102,12 @@ def test_api_version():
102102
assert xp.__array_api_version__ == '2021.12'
103103

104104
def test_data_dependent_shapes():
105+
with pytest.warns(UserWarning):
106+
set_array_api_strict_flags(api_version='2023.12') # to enable repeat()
107+
105108
a = asarray([0, 0, 1, 2, 2])
106109
mask = asarray([True, False, True, False, True])
110+
repeats = asarray([1, 1, 2, 2, 2])
107111

108112
# Should not error
109113
unique_all(a)
@@ -112,7 +116,8 @@ def test_data_dependent_shapes():
112116
unique_values(a)
113117
nonzero(a)
114118
a[mask]
115-
# TODO: add repeat when it is implemented
119+
repeat(a, repeats)
120+
repeat(a, 2)
116121

117122
set_array_api_strict_flags(data_dependent_shapes=False)
118123

@@ -122,6 +127,8 @@ def test_data_dependent_shapes():
122127
pytest.raises(RuntimeError, lambda: unique_values(a))
123128
pytest.raises(RuntimeError, lambda: nonzero(a))
124129
pytest.raises(RuntimeError, lambda: a[mask])
130+
pytest.raises(RuntimeError, lambda: repeat(a, repeats))
131+
repeat(a, 2) # Should never error
125132

126133
linalg_examples = {
127134
'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)),

0 commit comments

Comments
 (0)