Skip to content

Commit e99a305

Browse files
committed
jnp.floor_div: lower directly to div for unsigned int
1 parent 87b869d commit e99a305

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jax/_src/numpy/ufuncs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,9 @@ def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
281281
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
282282
x1, x2 = promote_args_numeric("floor_divide", x1, x2)
283283
dtype = dtypes.dtype(x1)
284-
if dtypes.issubdtype(dtype, np.integer):
284+
if dtypes.issubdtype(dtype, np.unsignedinteger):
285+
return lax.div(x1, x2)
286+
elif dtypes.issubdtype(dtype, np.integer):
285287
quotient = lax.div(x1, x2)
286288
select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
287289
# TODO(mattjj): investigate why subtracting a scalar was causing promotion

0 commit comments

Comments
 (0)