JAX instability when factoring out matmul #21739
Replies: 1 comment
This looks like it's behaving as expected: you're executing your code in float32 precision. float32 has 23 bits of mantissa, so different ways of computing the "same" floating-point result would be expected to differ by 1 part in $2^{23}$, which is roughly 1E-7.

Note that this is not a JAX issue, but an issue with any library that does floating point computations. For example, here's NumPy doing float32 vector multiplication two ways: front-to-back and back-to-front:

In [1]: import numpy as np
In [2]: np.random.seed(0)
In [3]: x = np.random.randn(1000).astype('float32')
In [4]: result1 = x @ x
In [5]: result2 = x[::-1] @ x[::-1]
In [6]: np.testing.assert_array_equal(result1, result2)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<...>
Arrays are not equal
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.00048828
Max relative difference: 5.001431e-07
x: array(976.28265, dtype=float32)
y: array(976.28314, dtype=float32)

You see that NumPy shows similar differences if you compute the "same" result two different ways. There is no fix for this: it's fundamental to floating-point computations. The real solution here is to write your code such that it's robust to this kind of expected floating-point rounding error.
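For what it's worth, here is the same effect with the matmul factoring mentioned in the title, sketched in JAX (the shapes, seed, and tolerance values below are illustrative assumptions, not taken from the original code): computing (A @ B) @ v directly versus the factored A @ (B @ v) gives results that agree only up to float32 rounding, so a tolerance-based comparison like np.testing.assert_allclose is the robust way to check them:

In [1]: import jax
In [2]: import jax.numpy as jnp
In [3]: import numpy as np
In [4]: k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
In [5]: A = jax.random.normal(k1, (16, 16), dtype=jnp.float32)
In [6]: B = jax.random.normal(k2, (16, 16), dtype=jnp.float32)
In [7]: v = jax.random.normal(k3, (16,), dtype=jnp.float32)
In [8]: direct = (A @ B) @ v    # multiply the matrices first, then apply to v
In [9]: factored = A @ (B @ v)  # "factor out" the matrix-vector product
In [10]: np.testing.assert_allclose(direct, factored, rtol=1e-5, atol=1e-4)  # passes; exact equality would generally fail

The appropriate tolerance depends on problem size and backend: on TPU, for example, JAX's default matmul precision is lower than full float32, so a larger tolerance may be needed.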