|
12 | 12 | _result_type,
|
13 | 13 | )
|
14 | 14 | from ._array_object import Array
|
| 15 | +from ._flags import requires_api_version |
| 16 | +from ._creation_functions import asarray |
| 17 | +from ._utility_functions import any as xp_any |
| 18 | + |
| 19 | +from typing import Optional, Union |
15 | 20 |
|
16 | 21 | import numpy as np
|
17 | 22 |
|
@@ -240,6 +245,68 @@ def ceil(x: Array, /) -> Array:
|
240 | 245 | return x
|
241 | 246 | return Array._new(np.ceil(x._array))
|
242 | 247 |
|
| 248 | +# WARNING: This function is not yet tested by the array-api-tests test suite. |
| 249 | + |
| 250 | +# Note: min and max argument names are different and not optional in numpy. |
| 251 | +@requires_api_version('2023.12') |
| 252 | +def clip( |
| 253 | + x: Array, |
| 254 | + /, |
| 255 | + min: Optional[Union[int, float, Array]] = None, |
| 256 | + max: Optional[Union[int, float, Array]] = None, |
| 257 | +) -> Array: |
| 258 | + """ |
| 259 | + Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`. |
| 260 | +
|
| 261 | + See its docstring for more information. |
| 262 | + """ |
| 263 | + if (x.dtype not in _real_numeric_dtypes |
| 264 | + or isinstance(min, Array) and min.dtype not in _real_numeric_dtypes |
| 265 | + or isinstance(max, Array) and max.dtype not in _real_numeric_dtypes): |
| 266 | + raise TypeError("Only real numeric dtypes are allowed in clip") |
| 267 | + if not isinstance(min, (int, float, Array, type(None))): |
| 268 | + raise TypeError("min must be an None, int, float, or an array") |
| 269 | + if not isinstance(max, (int, float, Array, type(None))): |
| 270 | + raise TypeError("max must be an None, int, float, or an array") |
| 271 | + |
| 272 | + # Mixed dtype kinds is implementation defined |
| 273 | + if (x.dtype in _integer_dtypes |
| 274 | + and (isinstance(min, float) or |
| 275 | + isinstance(min, Array) and min.dtype in _real_floating_dtypes)): |
| 276 | + raise TypeError("min must be integral when x is integral") |
| 277 | + if (x.dtype in _integer_dtypes |
| 278 | + and (isinstance(max, float) or |
| 279 | + isinstance(max, Array) and max.dtype in _real_floating_dtypes)): |
| 280 | + raise TypeError("max must be integral when x is integral") |
| 281 | + if (x.dtype in _real_floating_dtypes |
| 282 | + and (isinstance(min, int) or |
| 283 | + isinstance(min, Array) and min.dtype in _integer_dtypes)): |
| 284 | + raise TypeError("min must be floating-point when x is floating-point") |
| 285 | + if (x.dtype in _real_floating_dtypes |
| 286 | + and (isinstance(max, int) or |
| 287 | + isinstance(max, Array) and max.dtype in _integer_dtypes)): |
| 288 | + raise TypeError("max must be floating-point when x is floating-point") |
| 289 | + |
| 290 | + if min is max is None: |
| 291 | + # Note: NumPy disallows min = max = None |
| 292 | + return x |
| 293 | + |
| 294 | + # Normalize to make the below logic simpler |
| 295 | + if min is not None: |
| 296 | + min = asarray(min)._array |
| 297 | + if max is not None: |
| 298 | + max = asarray(max)._array |
| 299 | + |
| 300 | + # min > max is implementation defined |
| 301 | + if min is not None and max is not None and np.any(min > max): |
| 302 | + raise ValueError("min must be less than or equal to max") |
| 303 | + |
| 304 | + result = np.clip(x._array, min, max) |
| 305 | + # Note: NumPy applies type promotion, but the standard specifies the |
| 306 | + # return dtype should be the same as x |
| 307 | + if result.dtype != x.dtype._np_dtype: |
| 308 | + result = result.astype(x.dtype._np_dtype) |
| 309 | + return Array._new(result) |
243 | 310 |
|
244 | 311 | def conj(x: Array, /) -> Array:
|
245 | 312 | """
|
|
0 commit comments