Dear contributors,

I'm trying to implement the complex-step finite difference (CSFD) method in JAX, which seems efficient for high-order derivatives. However, I've run into a problem that seems intrinsic to JAX. Let's use CSFD to calculate the derivative of `cos` with both `numpy` and `jax.numpy`:

```python
import numpy as np
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

def np_cos(x):
    return np.cos(x)

x = 1.5
h = 1e-20

-np.sin(x)
# right answer: -0.9974949866040544
```

CSFD works fine with NumPy, even with an extremely small `h` (even `h=1e-300` works):

```python
np_cos(x + 1j*h).imag / h
# output: -0.9974949866040546
```

However, `jax.numpy` seems to fail:

```python
def jax_cos(x):
    return jnp.cos(x)

jax_cos(x + 1j*h).imag / h
# output: 0.
```

I think the main reason is that `cos` of a complex argument gives different output in `numpy` and `jax.numpy`. Is there some way to fix it?
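For reference, here is the standard Taylor-series argument for why the complex-step approximation tolerates such a tiny `h` in exact arithmetic (assuming `f` is real-analytic and `x` is real):

$$
f(x + ih) = f(x) + ih\,f'(x) - \frac{h^2}{2} f''(x) - \frac{i h^3}{6} f'''(x) + O(h^4)
$$

$$
\frac{\operatorname{Im} f(x + ih)}{h} = f'(x) - \frac{h^2}{6} f'''(x) + O(h^4)
$$

Unlike the ordinary difference quotient `(f(x+h) - f(x)) / h`, no two nearly equal numbers are subtracted, so there is no cancellation as `h` shrinks. For `f = cos` the imaginary part is a single product:

$$
\cos(x + ih) = \cos x \cosh h - i \sin x \sinh h,
\qquad
\frac{\operatorname{Im} \cos(x + ih)}{h} = -\sin x \, \frac{\sinh h}{h} \approx -\sin x .
$$

In floating point this only works if the library actually forms the imaginary part as that product; `h = 1e-20` and `sinh(h)` are both representable in double precision, so nothing is lost along that route.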
Hi, thanks for the question. I don't think this is a bug; rather, it's a floating-point roundoff issue that results from different choices in how complex trigonometric functions are implemented. For a detailed explanation of what is happening here, see the similar question in #9358, and in particular the answer at #9358 (comment).
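To make the roundoff concrete, here is a small NumPy sketch. `cos_via_exp` is a hypothetical helper that uses the mathematically equivalent form cos(z) = (exp(iz) + exp(-iz)) / 2; I'm not claiming this is exactly how JAX/XLA implements complex `cos`, only that it is the kind of implementation choice that loses the CSFD signal. With z = x + ih, the imaginary part becomes sin(x) (e^-h - e^h) / 2, and for h = 1e-20 both e^-h and e^h round to 1.0, so the tiny term vanishes:

```python
import numpy as np

x = 1.5
h = 1e-20
z = x + 1j * h

# NumPy's complex cos: the imaginary part is effectively -sin(x) * sinh(h),
# and sinh(1e-20) is representable, so the tiny term survives.
direct = np.cos(z)

# Hypothetical exponential-form cosine: identical in exact arithmetic,
# but exp(-h) and exp(h) both round to 1.0 for h = 1e-20, so the two
# sin(x) contributions cancel, leaving zero or rounding noise instead of
# -sin(x) * h.
def cos_via_exp(z):
    return (np.exp(1j * z) + np.exp(-1j * z)) / 2

naive = cos_via_exp(z)

print(direct.imag / h)  # close to -sin(1.5)
print(naive.imag / h)   # the derivative information is lost
```

Both functions return essentially the same real part; only the tiny imaginary part, which is all CSFD uses, differs.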
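If you want to keep the complex-step approach in JAX, one possible workaround (a sketch, not an official JAX API) is to avoid calling `jnp.cos` on a complex argument and build the complex cosine from real-valued pieces with cos(a + ib) = cos(a) cosh(b) - i sin(a) sinh(b). `cos_complex` and `csfd_derivative` are hypothetical helpers introduced here, and the approach assumes `jnp.sinh` stays accurate for tiny arguments, which is worth checking on your JAX version. For comparison, the block also computes the derivatives with `jax.grad`; that is JAX's automatic differentiation, a different technique from CSFD, but it needs no step size and nests for higher orders.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def cos_complex(z):
    """Hypothetical helper: cos of a complex argument via
    cos(a + ib) = cos(a) cosh(b) - i sin(a) sinh(b).
    The imaginary part is a product, so the tiny sinh(b) factor
    is not lost to cancellation."""
    a = jnp.real(z)
    b = jnp.imag(z)
    return jax.lax.complex(jnp.cos(a) * jnp.cosh(b), -jnp.sin(a) * jnp.sinh(b))

def csfd_derivative(f, x, h=1e-20):
    """Complex-step estimate f'(x) ~= Im f(x + ih) / h for real x."""
    return jnp.imag(f(x + 1j * h)) / h

x = 1.5
csfd = csfd_derivative(cos_complex, x)

# Alternative: JAX's own automatic differentiation (no step size needed).
ad_first = jax.grad(jnp.cos)(x)
ad_second = jax.grad(jax.grad(jnp.cos))(x)  # second derivative, -cos(x)

print(csfd, ad_first, -jnp.sin(x))
print(ad_second, -jnp.cos(x))
```

The explicit identity only fixes `cos`; every other function you push through CSFD would need the same care, which is one reason `jax.grad` (or `jax.jvp`) is usually the more practical route in JAX.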