Why does jax.numpy have different trigonometric functions from numpy and math? #14227
Answered by jakevdp
hyunwoooh5 asked this question in General
While writing some code in JAX, I found that my C++ code and my Python code give different results. I also checked other libraries and Mathematica, and only `jax.numpy` gives different values. Is there any reason for this?
Answered by jakevdp on Jan 31, 2023
JAX defaults to 32-bit computation (see http://go/jax-sharp-bits#double-64bit-precision), while `math`, `numpy`, and others default to 64-bit computation. If you enable 64-bit computation in JAX, then you should see that the results match:

```python
import jax

# Enable 64-bit (double precision) mode. Do this at startup,
# before any JAX arrays are created.
jax.config.update('jax_enable_x64', True)

import math
import numpy as np
import jax.numpy as jnp

# With x64 enabled, all three lines should print the same value.
print(f"{jnp.cos(3.0):.15f}")
print(f"{math.cos(3.0):.15f}")
print(f"{np.cos(3.0):.15f}")
```

Answer selected by hyunwoooh5
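To get a feel for how large the 32-bit vs 64-bit gap is without installing JAX, here is a minimal standard-library sketch. It assumes that JAX's default float32 result for `cos(3.0)` is close to the float64 value rounded to the nearest float32 (JAX's actual float32 kernel may differ from that by a unit in the last place). The helper `to_float32` is hypothetical, written only for this illustration; it uses `struct` packing with the `'f'` format to round to single precision.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Hypothetical helper: round a Python float (float64) to the nearest
    float32 and convert it back, by packing it as a 4-byte 'f' value."""
    return struct.unpack("f", struct.pack("f", x))[0]


exact64 = math.cos(3.0)         # double-precision result, as math/numpy give
approx32 = to_float32(exact64)  # the closest value float32 can hold

print(f"float64: {exact64:.15f}")
print(f"float32: {approx32:.15f}")
# float32 keeps only about 7 significant decimal digits, so the two
# printed values share their leading digits and then diverge.
print(f"abs diff: {abs(exact64 - approx32):.3e}")
```

The two printed values agree only to roughly seven or eight significant digits, which is the scale of discrepancy you would expect between JAX's default 32-bit mode and 64-bit libraries.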