I compared results for some computations with JAX on both TPU and CPU, with the default matmul precision set to HIGHEST:

```python
import jax
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
import einops as op
import jax.numpy as jnp
import jax.random as jrand
batch_size, input_size, len_, hidden_size = 10, 512, 20, 1024
key = jrand.key(42)
key, subkey = jrand.split(key)
x = jrand.normal(subkey, (batch_size, input_size, len_))
key, subkey = jrand.split(key)
w = jrand.normal(subkey, (batch_size, input_size, hidden_size))
out_einsum = op.einsum(x, w, 'b i l, b i h-> b l h')
out_tanh = jnp.tanh(x)
out_sigmoid = jax.nn.sigmoid(x)
out_elemal = x * x
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    out_einsum_cpu = op.einsum(x, w, 'b i l, b i h-> b l h')
    out_tanh_cpu = jnp.tanh(x)
    out_sigmoid_cpu = jax.nn.sigmoid(x)
    out_elemal_cpu = x * x
print(jnp.allclose(out_einsum, out_einsum_cpu, atol=1e-5))   # False
print(jnp.allclose(out_tanh, out_tanh_cpu, atol=1e-5))       # False
print(jnp.allclose(out_sigmoid, out_sigmoid_cpu, atol=1e-5)) # True
print(jnp.allclose(out_elemal, out_elemal_cpu, atol=1e-5))   # True
```

Given the README description of precision on TPU, I think both TPU and CPU use 32-bit values, and I expected identical results. Why are there differences in the outcomes of einsum and tanh?
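To make the size of the discrepancies concrete, here is a minimal sketch (assuming the snippet above has already run) that prints the largest absolute TPU/CPU difference for the two failing cases, and how many float32 machine epsilons that corresponds to:

```python
# Illustrative follow-up: how large are the TPU/CPU discrepancies?
import numpy as np

for name, a, b in [('einsum', out_einsum, out_einsum_cpu),
                   ('tanh', out_tanh, out_tanh_cpu)]:
    diff = np.abs(np.asarray(a) - np.asarray(b))
    print(name, diff.max(), diff.max() / np.finfo(np.float32).eps)
```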
Hi - thanks for the question!
In general, you'll find that operations on TPU will be less accurate than operations on CPU; the reason for this comes down to the backend-dependent implementations of various ops. In broad strokes, TPU operations tend to trade accuracy for speed, and so things like `tanh` will not be computed to full 32-bit precision. The reason for this is that TPUs are purpose-built for running bfloat16 neural networks, so in most cases it's wasteful to spend cycles computing activation functions to full precision when those extra decimals will be truncated in the next matmul.
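As a rough sanity check of this for the tanh case (a sketch, assuming the question's snippet has already run, using NumPy's float64 tanh as the reference):

```python
# Sketch: compare both backends' float32 tanh against a float64 NumPy reference.
import numpy as np

ref = np.tanh(np.asarray(x, dtype=np.float64))
err_tpu = np.max(np.abs(np.asarray(out_tanh, dtype=np.float64) - ref))
err_cpu = np.max(np.abs(np.asarray(out_tanh_cpu, dtype=np.float64) - ref))
print(err_tpu, err_cpu)  # if the TPU approximation is coarser, err_tpu should be larger
```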
Regarding `jax_default_matmul_precision`, keep in mind that this will only affect matmul-like operations, and even at the highest setting it does not guarantee bit-identical results across backends. We should probably update the docs you refer to in order to make these features more clear. Does that answer your question?
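For completeness, a small sketch (reusing x and w from the question) of the per-call and scoped ways to request float32 matmul precision; these apply only to dot-like operations such as einsum and matmul, not to element-wise functions like tanh:

```python
# Sketch: requesting full float32 matmul precision per call or per scope.
import jax
import jax.numpy as jnp

# Per-call: the precision argument applies to this contraction only.
y1 = jnp.einsum('bil,bih->blh', x, w, precision=jax.lax.Precision.HIGHEST)

# Scoped: 'float32' corresponds to Precision.HIGHEST for matmul-like ops
# traced inside this block.
with jax.default_matmul_precision('float32'):
    y2 = jnp.einsum('bil,bih->blh', x, w)
```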