Skip to content

Commit b10629e

Browse files
committed
BUG: work around numpy<2 ceil etc for int inputs
1 parent f54d6f3 commit b10629e

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

array_api_compat/numpy/_aliases.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,26 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
142142
return np.take_along_axis(x, indices, axis=axis)
143143

144144

145+
# ceil, floor, and trunc return integers for integer inputs in NumPy < 2
146+
147+
def ceil(x: Array, /) -> Array:
148+
if np.issubdtype(x.dtype, np.integer):
149+
return x.copy()
150+
return np.ceil(x)
151+
152+
153+
def floor(x: Array, /) -> Array:
154+
if np.issubdtype(x.dtype, np.integer):
155+
return x.copy()
156+
return np.floor(x)
157+
158+
159+
def round(x: Array, /) -> Array:
160+
if np.issubdtype(x.dtype, np.integer):
161+
return x.copy()
162+
return np.round(x)
163+
164+
145165
# These functions are completely new here. If the library already has them
146166
# (i.e., numpy 2.0), use the library version instead of our wrapper.
147167
if hasattr(np, "vecdot"):
@@ -170,6 +190,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
170190
"atan",
171191
"atan2",
172192
"atanh",
193+
"ceil",
194+
"floor",
195+
"round",
173196
"bitwise_left_shift",
174197
"bitwise_invert",
175198
"bitwise_right_shift",

0 commit comments

Comments
 (0)