
JAX instability when factoring out matmul #21739

Answered by jakevdp
max-vladymyrov asked this question in General

This looks like it's behaving as expected: your code runs in float32 precision. float32 stores 23 explicit mantissa bits, so different ways of computing the "same" floating-point result can be expected to differ by about 1 part in $2^{23}$ in relative terms, which is roughly $1.2 \times 10^{-7}$. The differences can grow further when many rounding errors accumulate.
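To make the precision argument concrete, here is a small NumPy sketch (not part of the original answer). It checks that float32 machine epsilon is $2^{-23}$ and shows that regrouping a float32 sum changes the rounded result. The values of `a`, `b`, and `c` are arbitrary choices for illustration.

```python
import numpy as np

# Machine epsilon for float32: the gap between 1.0 and the next larger float32.
eps32 = np.finfo(np.float32).eps
print(eps32 == 2.0**-23)  # True: 23 stored mantissa bits

# Floating-point addition is not associative: regrouping changes the rounding.
a = np.float32(1.0)
b = np.float32(-1.0)
c = np.float32(1e-8)  # well below half the float32 spacing just below 1.0

left = (a + b) + c   # 1 - 1 is exactly 0, so c survives
right = a + (b + c)  # -1 + 1e-8 rounds back to -1, so c is lost

print(left, right)
print(left == right)  # False
```

A matmul or dot product is a long chain of such additions, so any change in how the terms are grouped or ordered can shift the last few bits of the result.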

Note that this is not a JAX issue; it affects any library that does floating-point computations. For example, here's NumPy computing a float32 dot product two ways, front-to-back and back-to-front:

In [1]: import numpy as np
In [2]: np.random.seed(0)
In [3]: x = np.random.randn(1000).astype('float32')
In [4]: result1 = x @ x
In [5]: result2 = x[::-1] @ x[::-1]
In [6]: np.testing.assert_array_equal(result1, result2)
-------
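The usual remedy is to compare such results with a tolerance rather than requiring exact equality. Below is a minimal sketch reusing the session above. The float64 reference and the `rtol=1e-5` value are illustrative choices, not part of the original answer.

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(1000).astype('float32')

# Same float32 dot product, accumulated in two different orders.
result1 = x @ x
result2 = x[::-1] @ x[::-1]

# float64 reference computed from the same float32 inputs, for scale.
x64 = x.astype(np.float64)
reference = x64 @ x64

# Relative deviation of each float32 result from the float64 reference.
for r in (result1, result2):
    print(abs(float(r) - reference) / reference)

# Compare with a relative tolerance instead of exact equality.
# rtol=1e-5 is an illustrative choice for a 1000-term float32 sum, not a rule.
np.testing.assert_allclose(result1, result2, rtol=1e-5)
print("equal within rtol=1e-5")
```

A suitable tolerance depends on the dtype and on how many operations are accumulated. It should be a comfortable multiple of float32 epsilon rather than zero.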

Answer selected by max-vladymyrov