diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b39bd86..6454764 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -300,14 +300,35 @@ def clip( if min is not None and max is not None and np.any(min > max): raise ValueError("min must be less than or equal to max") - result = np.clip(x._array, min, max) # Note: NumPy applies type promotion, but the standard specifies the - # return dtype should be the same as x - if result.dtype != x.dtype._np_dtype: - # TODO: I'm not completely sure this always gives the correct thing - # for integer dtypes. See https://github.com/numpy/numpy/issues/24976 - result = result.astype(x.dtype._np_dtype) - return Array._new(result) + # return dtype should be the same as x We do this instead of just + # downcasting the result of np.clip() to handle some corner cases better + # (e.g., avoiding uint64 -> float64 promotion). + + # Note: cases where min or max overflow (integer) or round (float) in the + # wrong direction when downcasting to x.dtype are unspecified. This code + # just does whatever NumPy does when it downcasts in the assignment, but + # other behavior could be preferred, especially for integers. For example, + # this code produces: + + # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None) + # -128 + + # but an answer of 0 might be preferred. See + # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. + min_shape = () if min is None else min.shape + max_shape = () if max is None else max.shape + result_shape = np.broadcast_shapes(x.shape, min_shape, max_shape) + out = asarray(np.broadcast_to(x._array, result_shape), copy=True)._array + if min is not None: + a = np.broadcast_to(min, result_shape) + ia = (out < a) | np.isnan(a) + out[ia] = a[ia] + if max is not None: + b = np.broadcast_to(max, result_shape) + ib = (out > b) | np.isnan(b) + out[ib] = b[ib] + return Array._new(out) def conj(x: Array, /) -> Array: """