Skip to content

Commit 31c5a89

Browse files
committed
Add clip()
It is only enabled for when the api version is 2023.12. I have only tested that it works manually. There is no test suite support for clip() yet.
1 parent c39fdbf commit 31c5a89

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

array_api_strict/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
bitwise_right_shift,
135135
bitwise_xor,
136136
ceil,
137+
clip,
137138
conj,
138139
cos,
139140
cosh,
@@ -196,6 +197,7 @@
196197
"bitwise_right_shift",
197198
"bitwise_xor",
198199
"ceil",
200+
"clip",
199201
"cos",
200202
"cosh",
201203
"divide",

array_api_strict/_elementwise_functions.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
_result_type,
1313
)
1414
from ._array_object import Array
15+
from ._flags import requires_api_version
16+
from ._creation_functions import asarray
17+
from ._utility_functions import any as xp_any
18+
19+
from typing import Optional, Union
1520

1621
import numpy as np
1722

@@ -240,6 +245,68 @@ def ceil(x: Array, /) -> Array:
240245
return x
241246
return Array._new(np.ceil(x._array))
242247

248+
# WARNING: This function is not yet tested by the array-api-tests test suite.
249+
250+
# Note: min and max argument names are different and not optional in numpy.
251+
@requires_api_version('2023.12')
252+
def clip(
253+
x: Array,
254+
/,
255+
min: Optional[Union[int, float, Array]] = None,
256+
max: Optional[Union[int, float, Array]] = None,
257+
) -> Array:
258+
"""
259+
Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`.
260+
261+
See its docstring for more information.
262+
"""
263+
if (x.dtype not in _real_numeric_dtypes
264+
or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes
265+
or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes):
266+
raise TypeError("Only real numeric dtypes are allowed in clip")
267+
if not isinstance(min, (int, float, Array, type(None))):
268+
raise TypeError("min must be an None, int, float, or an array")
269+
if not isinstance(max, (int, float, Array, type(None))):
270+
raise TypeError("max must be an None, int, float, or an array")
271+
272+
# Mixed dtype kinds is implementation defined
273+
if (x.dtype in _integer_dtypes
274+
and (isinstance(min, float) or
275+
isinstance(min, Array) and min.dtype in _real_floating_dtypes)):
276+
raise TypeError("min must be integral when x is integral")
277+
if (x.dtype in _integer_dtypes
278+
and (isinstance(max, float) or
279+
isinstance(max, Array) and max.dtype in _real_floating_dtypes)):
280+
raise TypeError("max must be integral when x is integral")
281+
if (x.dtype in _real_floating_dtypes
282+
and (isinstance(min, int) or
283+
isinstance(min, Array) and min.dtype in _integer_dtypes)):
284+
raise TypeError("min must be floating-point when x is floating-point")
285+
if (x.dtype in _real_floating_dtypes
286+
and (isinstance(max, int) or
287+
isinstance(max, Array) and max.dtype in _integer_dtypes)):
288+
raise TypeError("max must be floating-point when x is floating-point")
289+
290+
if min is max is None:
291+
# Note: NumPy disallows min = max = None
292+
return x
293+
294+
# Normalize to make the below logic simpler
295+
if min is not None:
296+
min = asarray(min)._array
297+
if max is not None:
298+
max = asarray(max)._array
299+
300+
# min > max is implementation defined
301+
if min is not None and max is not None and np.any(min > max):
302+
raise ValueError("min must be less than or equal to max")
303+
304+
result = np.clip(x._array, min, max)
305+
# Note: NumPy applies type promotion, but the standard specifies the
306+
# return dtype should be the same as x
307+
if result.dtype != x.dtype._np_dtype:
308+
result = result.astype(x.dtype._np_dtype)
309+
return Array._new(result)
243310

244311
def conj(x: Array, /) -> Array:
245312
"""

0 commit comments

Comments
 (0)