
Why does jax.numpy have different trigonometric functions from numpy and math? #14227

Answered by jakevdp
hyunwoooh5 asked this question in General

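JAX defaults to 32-bit (single-precision) floating point (see http://go/jax-sharp-bits#double-64bit-precision), while math, numpy, and others default to 64-bit (double precision). A float32 result carries only about 7 significant digits, so when you print it to 15 decimal places it will differ from the float64 result after roughly the 7th digit. If you enable 64-bit computation in JAX, the results match: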

```python
import jax
# Enable double precision; set this at startup, before creating any JAX arrays.
jax.config.update('jax_enable_x64', True)

import math
import numpy as np
import jax.numpy as jnp

print(f"{jnp.cos(3.0):.15f}")
print(f"{math.cos(3.0):.15f}")
print(f"{np.cos(3.0):.15f}")
```

Output:

```
-0.989992496600445
-0.989992496600445
-0.989992496600445
```
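To see where the mismatch comes from without the flag: with x64 disabled, `jnp.cos(3.0)` returns a float32 value. The sketch below uses NumPy's `float32` as a stand-in for JAX's default precision. That is an assumption: XLA's float32 `cos` kernel is not guaranteed to round identically to NumPy's, so the last digit may differ. The check is that the two results agree only to float32 precision (bounded by `np.finfo(np.float32).eps`), not to 15 decimal places.

```python
import math
import numpy as np

# Stand-in for JAX's default dtype: evaluate cos in single precision.
x32 = np.float32(3.0)
cos32 = np.cos(x32)                 # np.float32 result, ~7 significant digits

# What math and numpy compute by default: double precision.
cos64 = np.cos(np.float64(3.0))     # np.float64 result, ~16 significant digits

print(f"float32: {float(cos32):.15f}")
print(f"float64: {float(cos64):.15f}")
print(f"math:    {math.cos(3.0):.15f}")

# Machine epsilon: spacing between 1.0 and the next representable value.
eps32 = np.finfo(np.float32).eps    # about 1.19e-07
eps64 = np.finfo(np.float64).eps    # about 2.22e-16

rel_diff = abs(float(cos32) - float(cos64)) / abs(float(cos64))
print(f"relative difference: {rel_diff:.2e} (float32 eps = {eps32:.2e})")

# The float32 result is close to float64 only within a few float32 ulps...
assert rel_diff <= 4 * eps32
# ...while numpy's float64 result agrees with math.cos to double precision.
assert math.isclose(float(cos64), math.cos(3.0), rel_tol=4 * float(eps64))
```

The tolerance of a few `eps32` leaves room for float32 `cos` implementations that are not correctly rounded; the exact float32 digits printed can vary by platform.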

Answer selected by hyunwoooh5