Skip to content

Commit 3830b17

Browse files
author
jax authors
committed
More numerically stable jax.experimental.jet._integer_pow_taylor.
The previous version of `jax.experimental.jet._integer_pow_taylor` became numerically unstable when we evaluated the derivative jet of `x**n` around `x = 0` for integer `n > 2`. This was especially true in the higher order derivatives in `float32`. Beforehand there was a complicated expression for approximating the jet of `x**n`. In this CL we replace integer powers with explicit multiplication, e.g. `x**2` becomes `x*x`. For higher powers of `x`, we use the exponentiation by squaring trick, which evaluates expressions such as `x**4` as `(x * x) * (x * x)`, which uses log(n) multiplies instead of `n-1`. PiperOrigin-RevId: 617205359
1 parent d866288 commit 3830b17

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

jax/experimental/jet.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -490,24 +490,22 @@ def _pow_taylor(primals_in, series_in):
490490
return primal_out, series_out
491491
jet_rules[lax.pow_p] = _pow_taylor
492492

493+
def _pow_by_squaring(x, n):
494+
if n < 0:
495+
return _pow_by_squaring(1 / x, -n)
496+
elif n == 0:
497+
return 1
498+
elif n % 2 == 0:
499+
return _pow_by_squaring(x * x, n / 2)
500+
elif n % 2 == 1:
501+
return x * _pow_by_squaring(x * x, (n - 1) / 2)
502+
493503
def _integer_pow_taylor(primals_in, series_in, *, y):
494504
if y == 0:
495505
return jet(jnp.ones_like, primals_in, series_in)
496-
elif y == 1:
497-
return jet(lambda x: x, primals_in, series_in)
498-
elif y == 2:
499-
return jet(lambda x: x * x, primals_in, series_in)
500-
x, = primals_in
501-
series, = series_in
502-
u = [x] + series
503-
v = [lax.integer_pow(x, y)] + [None] * len(series)
504-
for k in range(1, len(v)):
505-
vu = sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k + 1))
506-
uv = sum(_scale(k, j) * u[k-j] * v[j] for j in range(1, k))
507-
v[k] = jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x)
508-
primal_out, *series_out = v
506+
else:
507+
return jet(lambda x: _pow_by_squaring(x, y), primals_in, series_in)
509508

510-
return primal_out, series_out
511509
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
512510

513511

0 commit comments

Comments
 (0)