
Precision differences between TPU and CPU for same computations #20510

Answered by jakevdp
yixiaoer asked this question in General

Hi - thanks for the question!

In general, you'll find that operations on TPU will be less accurate than operations on CPU; the reason for this comes down to the backend-dependent implementations of various ops. In broad strokes, TPU operations tend to trade accuracy for speed, and so things like tanh will not be computed to full 32-bit precision. The reason for this is that TPUs are purpose-built for running bfloat16 neural networks, so in most cases it's wasteful to spend cycles computing activation functions to full precision when those extra decimals will be truncated in the next matmul.
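One way to see the effect is to compare a float32 elementwise op against a float64 reference. A minimal sketch (assuming you run it once on a CPU runtime and once on a TPU runtime and compare the printed errors):

```python
import jax
import jax.numpy as jnp
import numpy as np

x = np.linspace(-3.0, 3.0, 10_000, dtype=np.float32)

# float64 reference computed on the host with NumPy.
reference = np.tanh(x.astype(np.float64))

# The same float32 computation on JAX's current backend (CPU or TPU).
result = np.asarray(jnp.tanh(jnp.asarray(x)))

# Max absolute error vs the reference; typically larger on TPU than on CPU.
print(jax.default_backend(), np.abs(result.astype(np.float64) - reference).max())
```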

Regarding jax_default_matmul_precision, keep in mind that this will only affect matmul-like operations…
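For matmul-like operations specifically, precision can be raised either globally via the context manager or per call. A short sketch with float32 inputs (on CPU the two results will usually match; the gap shows up on TPU, where the default path may round inputs to bfloat16):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (512, 512), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (512, 512), dtype=jnp.float32)

# Backend-default matmul precision (fast, possibly reduced precision on TPU).
fast = a @ b

# Request full float32 precision for matmul-like ops within this scope...
with jax.default_matmul_precision("float32"):
    accurate = a @ b

# ...or per call via the precision argument.
accurate_too = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Elementwise ops such as tanh are unaffected by either setting.
print(jnp.abs(fast - accurate).max())
```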

Answer selected by yixiaoer