jnp.fft vs. scipy.fft expected numerical difference #22058
-
Hi all, I'm observing a relatively large difference between the FFT implementations in JAX and SciPy, and the difference grows with array size. I am running the following on an Apple Mac with an M1 Pro, using the CPU version of JAX:

```python
import jax.numpy as jnp
from jax import random
import numpy as np
import scipy.fft

# Generate a random input array with double precision
key = random.PRNGKey(0)
x_jax = random.normal(key, (1024,))

# Compute FFT using JAX
fft_jax = jnp.fft.fft(x_jax)

# Compute FFT using SciPy
x_np = np.array(x_jax)  # Convert JAX array to NumPy array
fft_scipy = scipy.fft.fft(x_np)

# Compute the maximum absolute error between the results
max_error = np.max(np.abs(fft_jax - fft_scipy))
print("Maximum absolute error:", max_error)

# Check the dtypes to confirm precision
print("JAX FFT dtype:", fft_jax.dtype)
print("SciPy FFT dtype:", fft_scipy.dtype)
```

Given that the JAX output dtype is complex64, I would have expected less numerical difference, and the maximum absolute difference does seem to increase as I increase the array size. I would appreciate any help/explanation!
-
JAX defaults to 32-bit computation (complex64 = two 32-bit components), while NumPy/SciPy default to 64-bit, even when their inputs are 32-bit. The difference you're seeing is typical rounding error for 32-bit computations. If you enable 64-bit computation in JAX (see JAX Sharp Bits: 64-bit precision), you will instead see rounding errors typical of 64-bit float computation.
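As a minimal sketch of the fix: setting the `jax_enable_x64` flag before creating any arrays makes `random.normal` return float64, so `jnp.fft.fft` returns complex128 and matches SciPy's default precision. (The ~1e-4-scale error magnitudes in the comments are illustrative, not guaranteed values.)

```python
import jax
# Must run before any arrays are created, or they will already be 32-bit.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random
import numpy as np
import scipy.fft

key = random.PRNGKey(0)
x_jax = random.normal(key, (1024,))   # now float64, not float32
fft_jax = jnp.fft.fft(x_jax)          # now complex128

x_np = np.array(x_jax)
fft_scipy = scipy.fft.fft(x_np)

max_error = np.max(np.abs(np.array(fft_jax) - fft_scipy))
print("JAX FFT dtype:", fft_jax.dtype)   # complex128
# Error should now be many orders of magnitude smaller than the
# ~1e-4-scale discrepancy seen with 32-bit computation.
print("Maximum absolute error:", max_error)
```

Note that `jax_enable_x64` changes the default dtype globally, so flipping it mid-program leaves previously created 32-bit arrays unchanged; that is why the flag must be set first.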