Skip to content

Commit 685d97f

Browse files
Jake VanderPlasjax authors
authored andcommitted
Remove support for complex jnp.floor_divide
It is not well-defined, and the implementation is currently untested. Both Python and NumPy raise an error when attempting complex-valued floor division. PiperOrigin-RevId: 621918604
1 parent 23b0fa8 commit 685d97f

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

jax/_src/numpy/ufuncs.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,16 +289,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
289289
# TODO(mattjj): investigate why subtracting a scalar was causing promotion
290290
return _where(select, quotient - 1, quotient)
291291
elif dtypes.issubdtype(dtype, np.complexfloating):
292-
x1r = lax.real(x1)
293-
x1i = lax.imag(x1)
294-
x2r = lax.real(x2)
295-
x2i = lax.imag(x2)
296-
which = lax.ge(lax.abs(x2r), lax.abs(x2i))
297-
rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
298-
rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
299-
out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
300-
lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
301-
return lax.convert_element_type(out, dtype)
292+
raise TypeError("floor_divide does not support complex-valued inputs")
302293
else:
303294
return _float_divmod(x1, x2)[0]
304295

0 commit comments

Comments
 (0)