You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
More numerically stable jax.experimental.jet._integer_pow_taylor.
The previous version of `jax.experimental.jet._integer_pow_taylor` became numerically unstable when we evaluated the derivative jet of `x**n` around `x = 0` for integer `n > 2`. This was especially true in the higher order derivatives in `float32`.
Beforehand there was a complicated expression for approximating the jet of `x**n`. In this CL we replace integer powers with explicit multiplication, e.g. `x**2` becomes `x*x`. For higher powers of `x`, we use the exponentiation by squaring trick, which evaluates expressions such as `x**4` as `(x * x) * (x * x)`, which uses log(n) multiplies instead of `n-1`.
PiperOrigin-RevId: 617205359
0 commit comments