JAX jit optimization seems too weak #9911
Replies: 4 comments 9 replies
-
After I use |
Beta Was this translation helpful? Give feedback.
-
Which backend is this on? CPU, GPU, and TPU are quite different. (CPU is the least optimized by far.) |
Beta Was this translation helpful? Give feedback.
-
# @mattjj runnable code:
# version 1
import jax
import jax.numpy as jnp
from jax import lax


def signbit(x):
    """Return the IEEE-754 sign bit of ``x`` as an int32: 1 if set (negative
    values, including -0.0), else 0.

    The value is bitcast to int32 and logically shifted right by 31, so a
    32-bit float input is assumed (JAX's default when x64 is disabled).
    NOTE(review): a float64 input would not bitcast to a single int32 --
    confirm before enabling x64.
    """
    return lax.shift_right_logical(lax.bitcast_convert_type(x, jnp.int32), 31)


@jax.jit
@jax.vmap
def eigvalsrs_tridiagonal_bisection(
    a: jnp.ndarray,  # (n,) diag
    b: jnp.ndarray,  # (n - 1,) sub-diag
):
    """Eigenvalues of real symmetric tridiagonal matrices via Sturm-count bisection.

    Because of ``jax.vmap`` the caller passes a leading batch axis: ``a`` is
    (batch, n) and ``b`` is (batch, n - 1). Returns a (batch, n) array whose
    column k holds the approximate k-th smallest eigenvalue, so every row is
    in ascending order.

    Version 1: the pivot recurrence divides, ``b2i / q``.
    """
    b2 = b ** 2

    def count(x):
        """Count the negative pivots of the LDL^T factorization of T - x*I.

        By Sylvester's law of inertia this is the number of eigenvalues of T
        strictly below ``x`` (a pivot of exactly +0.0 is not counted).
        """
        # Thanks to IEEE-754, we don't need pivmin! https://epubs.siam.org/doi/epdf/10.1137/050641624
        # also see Faster numerical algorithms via exception handling
        # A zero pivot turns the next b2i / q into +-inf, which the recurrence
        # absorbs. NOTE(review): if an off-diagonal entry is exactly 0 while the
        # preceding pivot is exactly 0, b2i / q is 0/0 = NaN and the NaN reaches
        # every later pivot (the sign bit it contributes is platform-dependent);
        # consider splitting the matrix at zero off-diagonals beforehand.
        def scan_f(carry, data):
            q, c = carry
            ai, b2i = data
            q = ai - x - b2i / q
            return (q, c + signbit(q)), None

        # b2 is left-padded with a 0, so the first pivot is just a[0] - x and the
        # initial carry q = 1.0 is only ever divided into 0.
        return lax.scan(scan_f, (1.0, 0), (a, jnp.pad(b2, (1, 0))), unroll=48)[0][1]

    # Gershgorin discs: every eigenvalue lies in [min(a - r), max(a + r)] with
    # r[i] = |b[i-1]| + |b[i]| (a missing neighbour counts as 0).
    b_abs = lax.abs(b)
    r = jnp.pad(b_abs, (1, 0)) + jnp.pad(b_abs, (0, 1))
    emax = jnp.max(a + r)
    emin = jnp.min(a - r)
    norm = lax.max(lax.abs(emax), lax.abs(emin))
    n = a.size
    # Widen the bracket slightly; presumably a safety margin for float32
    # rounding inside count() (3e-7 is a few float32 epsilons) -- TODO confirm.
    upper0 = emax + norm * 3e-7 * n
    lower0 = emin - norm * 3e-7 * n

    @jax.vmap
    def bisection(cnt):
        """Bisect for the cnt-th smallest (0-based) eigenvalue."""

        def step(carry, _):
            lower, upper = carry
            mid = (lower + upper) / 2
            # At most cnt eigenvalues below mid => the cnt-th one is >= mid.
            pred = count(mid) <= cnt
            lower = lax.select(pred, mid, lower)
            upper = lax.select(pred, upper, mid)
            return (lower, upper), None

        # 24 halvings, presumably chosen to match float32's 24-bit significand.
        lower, upper = lax.scan(step, (lower0, upper0), None, length=24, unroll=3)[0]
        return (lower + upper) / 2

    return bisection(jnp.arange(n))


def test():
    """Benchmark the bisection solver against jax.scipy.linalg.eigh_tridiagonal.

    Builds a batch of 4096 random symmetric tridiagonal matrices of size 2048.
    Prints the mean wall-clock seconds per call (3 timed calls after one
    warm-up) for the reference solver and then for the bisection solver.
    Finally prints the elementwise difference of their eigenvalues.
    """
    # Only two keys are used, but the 3-way split is kept so key1/key2 (and
    # therefore the benchmark data) stay the same.
    key1, key2, _ = jax.random.split(jax.random.PRNGKey(0), 3)
    bsz = 4096
    n = 2048
    a = jax.random.normal(key1, (bsz, n))
    # Squared off-diagonals drawn from Gamma(n - 1), ..., Gamma(1). This looks
    # like the Dumitriu-Edelman tridiagonal model of the beta=2 Hermite
    # ensemble -- TODO confirm that was the intent.
    b2 = jax.random.gamma(key2, jnp.arange(n - 1, 0, -1), (bsz, n - 1))
    b = jnp.sqrt(b2)

    def timer(f):
        """Print the mean wall-clock seconds of 3 calls to ``f`` after one warm-up call."""
        from time import perf_counter  # monotonic, high-resolution benchmark clock

        f()  # warmup: the first call of a jitted function includes compilation
        t = perf_counter()
        for _ in range(3):
            f()
        print((perf_counter() - t) / 3)

    from jax.scipy.linalg import eigh_tridiagonal

    @jax.jit
    @jax.vmap
    def eigvalsrs_tridiagonal_jax(a, b):
        """Reference eigenvalues from jax.scipy.linalg.eigh_tridiagonal."""
        return eigh_tridiagonal(a, b, eigvals_only=True)

    # block_until_ready makes the timing cover the actual device work, since
    # JAX dispatches computations asynchronously.
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_jax(a, b)))
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_bisection(a, b)))
    print(eigvalsrs_tridiagonal_bisection(a, b) - eigvalsrs_tridiagonal_jax(a, b))


# version 2
import jax
import jax.numpy as jnp
from jax import lax


def signbit(x):
    """Return the IEEE-754 sign bit of ``x`` as an int32: 1 if set (negative
    values, including -0.0), else 0.

    The value is bitcast to int32 and logically shifted right by 31, so a
    32-bit float input is assumed (JAX's default when x64 is disabled).
    NOTE(review): a float64 input would not bitcast to a single int32 --
    confirm before enabling x64.
    """
    return lax.shift_right_logical(lax.bitcast_convert_type(x, jnp.int32), 31)


@jax.jit
@jax.vmap
def eigvalsrs_tridiagonal_bisection(
    a: jnp.ndarray,  # (n,) diag
    b: jnp.ndarray,  # (n - 1,) sub-diag
):
    """Eigenvalues of real symmetric tridiagonal matrices via Sturm-count bisection.

    Because of ``jax.vmap`` the caller passes a leading batch axis: ``a`` is
    (batch, n) and ``b`` is (batch, n - 1). Returns a (batch, n) array whose
    column k holds the approximate k-th smallest eigenvalue, so every row is
    in ascending order.

    Version 2: the pivot recurrence multiplies by ``lax.reciprocal(q)``
    instead of dividing by ``q``.
    """
    b2 = b ** 2

    def count(x):
        """Count the negative pivots of the LDL^T factorization of T - x*I.

        By Sylvester's law of inertia this is the number of eigenvalues of T
        strictly below ``x`` (a pivot of exactly +0.0 is not counted).
        """
        # Thanks to IEEE-754, we don't need pivmin! https://epubs.siam.org/doi/epdf/10.1137/050641624
        # also see Faster numerical algorithms via exception handling
        # A zero pivot makes reciprocal(q) = +-inf, which the recurrence absorbs.
        # NOTE(review): if an off-diagonal entry is exactly 0 while the preceding
        # pivot is exactly 0, b2i * reciprocal(q) is 0 * inf = NaN and the NaN
        # reaches every later pivot (the sign bit it contributes is
        # platform-dependent); consider splitting the matrix at zero
        # off-diagonals beforehand.
        def scan_f(carry, data):
            q, c = carry
            ai, b2i = data
            q = ai - x - b2i * lax.reciprocal(q)
            return (q, c + signbit(q)), None

        # b2 is left-padded with a 0, so the first pivot is just a[0] - x
        # (the initial carry q = 1.0 only contributes 0 * 1.0).
        return lax.scan(scan_f, (1.0, 0), (a, jnp.pad(b2, (1, 0))), unroll=48)[0][1]

    # Gershgorin discs: every eigenvalue lies in [min(a - r), max(a + r)] with
    # r[i] = |b[i-1]| + |b[i]| (a missing neighbour counts as 0).
    b_abs = lax.abs(b)
    r = jnp.pad(b_abs, (1, 0)) + jnp.pad(b_abs, (0, 1))
    emax = jnp.max(a + r)
    emin = jnp.min(a - r)
    norm = lax.max(lax.abs(emax), lax.abs(emin))
    n = a.size
    # Widen the bracket slightly; presumably a safety margin for float32
    # rounding inside count() (3e-7 is a few float32 epsilons) -- TODO confirm.
    upper0 = emax + norm * 3e-7 * n
    lower0 = emin - norm * 3e-7 * n

    @jax.vmap
    def bisection(cnt):
        """Bisect for the cnt-th smallest (0-based) eigenvalue."""

        def step(carry, _):
            lower, upper = carry
            mid = (lower + upper) / 2
            # At most cnt eigenvalues below mid => the cnt-th one is >= mid.
            pred = count(mid) <= cnt
            lower = lax.select(pred, mid, lower)
            upper = lax.select(pred, upper, mid)
            return (lower, upper), None

        # 24 halvings, presumably chosen to match float32's 24-bit significand.
        lower, upper = lax.scan(step, (lower0, upper0), None, length=24, unroll=3)[0]
        return (lower + upper) / 2

    return bisection(jnp.arange(n))


def test():
    """Benchmark the bisection solver against jax.scipy.linalg.eigh_tridiagonal.

    Builds a batch of 4096 random symmetric tridiagonal matrices of size 2048.
    Prints the mean wall-clock seconds per call (3 timed calls after one
    warm-up) for the reference solver and then for the bisection solver.
    Finally prints the elementwise difference of their eigenvalues.
    """
    # Only two keys are used, but the 3-way split is kept so key1/key2 (and
    # therefore the benchmark data) stay the same.
    key1, key2, _ = jax.random.split(jax.random.PRNGKey(0), 3)
    bsz = 4096
    n = 2048
    a = jax.random.normal(key1, (bsz, n))
    # Squared off-diagonals drawn from Gamma(n - 1), ..., Gamma(1). This looks
    # like the Dumitriu-Edelman tridiagonal model of the beta=2 Hermite
    # ensemble -- TODO confirm that was the intent.
    b2 = jax.random.gamma(key2, jnp.arange(n - 1, 0, -1), (bsz, n - 1))
    b = jnp.sqrt(b2)

    def timer(f):
        """Print the mean wall-clock seconds of 3 calls to ``f`` after one warm-up call."""
        from time import perf_counter  # monotonic, high-resolution benchmark clock

        f()  # warmup: the first call of a jitted function includes compilation
        t = perf_counter()
        for _ in range(3):
            f()
        print((perf_counter() - t) / 3)

    from jax.scipy.linalg import eigh_tridiagonal

    @jax.jit
    @jax.vmap
    def eigvalsrs_tridiagonal_jax(a, b):
        """Reference eigenvalues from jax.scipy.linalg.eigh_tridiagonal."""
        return eigh_tridiagonal(a, b, eigvals_only=True)

    # block_until_ready makes the timing cover the actual device work, since
    # JAX dispatches computations asynchronously.
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_jax(a, b)))
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_bisection(a, b)))
    print(eigvalsrs_tridiagonal_bisection(a, b) - eigvalsrs_tridiagonal_jax(a, b))


# version 3
import jax
import jax.numpy as jnp
from jax import lax


def signbit(x):
    """Return the IEEE-754 sign bit of ``x`` as an int32: 1 if set (negative
    values, including -0.0), else 0.

    The value is bitcast to int32 and logically shifted right by 31, so a
    32-bit float input is assumed (JAX's default when x64 is disabled).
    NOTE(review): a float64 input would not bitcast to a single int32 --
    confirm before enabling x64.
    """
    return lax.shift_right_logical(lax.bitcast_convert_type(x, jnp.int32), 31)


@jax.jit
@jax.vmap
def eigvalsrs_tridiagonal_bisection(
    a: jnp.ndarray,  # (n,) diag
    b: jnp.ndarray,  # (n - 1,) sub-diag
):
    """Eigenvalues of real symmetric tridiagonal matrices via Sturm-count bisection.

    Because of ``jax.vmap`` the caller passes a leading batch axis: ``a`` is
    (batch, n) and ``b`` is (batch, n - 1). Returns a (batch, n) array whose
    column k holds the approximate k-th smallest eigenvalue, so every row is
    in ascending order.

    Version 3: the scan carries the reciprocal of the previous pivot, so the
    recurrence itself only multiplies.
    """
    b2 = b ** 2

    def count(x):
        """Count the negative pivots of the LDL^T factorization of T - x*I.

        By Sylvester's law of inertia this is the number of eigenvalues of T
        strictly below ``x`` (a pivot of exactly +0.0 is not counted).
        """
        # Thanks to IEEE-754, we don't need pivmin! https://epubs.siam.org/doi/epdf/10.1137/050641624
        # also see Faster numerical algorithms via exception handling
        # A zero pivot makes the carried reciprocal +-inf, which the recurrence
        # absorbs. NOTE(review): if an off-diagonal entry is exactly 0 while the
        # preceding pivot is exactly 0, b2i * rec_q is 0 * inf = NaN and the NaN
        # reaches every later pivot (the sign bit it contributes is
        # platform-dependent); consider splitting the matrix at zero
        # off-diagonals beforehand.
        def scan_f(carry, data):
            rec_q, c = carry  # rec_q = 1 / (previous pivot)
            ai, b2i = data
            q = ai - x - b2i * rec_q
            return (lax.reciprocal(q), c + signbit(q)), None

        # b2 is left-padded with a 0, so the first pivot is just a[0] - x
        # (the initial carry rec_q = 1.0 only contributes 0 * 1.0).
        return lax.scan(scan_f, (1.0, 0), (a, jnp.pad(b2, (1, 0))), unroll=48)[0][1]

    # Gershgorin discs: every eigenvalue lies in [min(a - r), max(a + r)] with
    # r[i] = |b[i-1]| + |b[i]| (a missing neighbour counts as 0).
    b_abs = lax.abs(b)
    r = jnp.pad(b_abs, (1, 0)) + jnp.pad(b_abs, (0, 1))
    emax = jnp.max(a + r)
    emin = jnp.min(a - r)
    norm = lax.max(lax.abs(emax), lax.abs(emin))
    n = a.size
    # Widen the bracket slightly; presumably a safety margin for float32
    # rounding inside count() (3e-7 is a few float32 epsilons) -- TODO confirm.
    upper0 = emax + norm * 3e-7 * n
    lower0 = emin - norm * 3e-7 * n

    @jax.vmap
    def bisection(cnt):
        """Bisect for the cnt-th smallest (0-based) eigenvalue."""

        def step(carry, _):
            lower, upper = carry
            mid = (lower + upper) / 2
            # At most cnt eigenvalues below mid => the cnt-th one is >= mid.
            pred = count(mid) <= cnt
            lower = lax.select(pred, mid, lower)
            upper = lax.select(pred, upper, mid)
            return (lower, upper), None

        # 24 halvings, presumably chosen to match float32's 24-bit significand.
        lower, upper = lax.scan(step, (lower0, upper0), None, length=24, unroll=3)[0]
        return (lower + upper) / 2

    return bisection(jnp.arange(n))


def test():
    """Benchmark the bisection solver against jax.scipy.linalg.eigh_tridiagonal.

    Builds a batch of 4096 random symmetric tridiagonal matrices of size 2048.
    Prints the mean wall-clock seconds per call (3 timed calls after one
    warm-up) for the reference solver and then for the bisection solver.
    Finally prints the elementwise difference of their eigenvalues.
    """
    # Only two keys are used, but the 3-way split is kept so key1/key2 (and
    # therefore the benchmark data) stay the same.
    key1, key2, _ = jax.random.split(jax.random.PRNGKey(0), 3)
    bsz = 4096
    n = 2048
    a = jax.random.normal(key1, (bsz, n))
    # Squared off-diagonals drawn from Gamma(n - 1), ..., Gamma(1). This looks
    # like the Dumitriu-Edelman tridiagonal model of the beta=2 Hermite
    # ensemble -- TODO confirm that was the intent.
    b2 = jax.random.gamma(key2, jnp.arange(n - 1, 0, -1), (bsz, n - 1))
    b = jnp.sqrt(b2)

    def timer(f):
        """Print the mean wall-clock seconds of 3 calls to ``f`` after one warm-up call."""
        from time import perf_counter  # monotonic, high-resolution benchmark clock

        f()  # warmup: the first call of a jitted function includes compilation
        t = perf_counter()
        for _ in range(3):
            f()
        print((perf_counter() - t) / 3)

    from jax.scipy.linalg import eigh_tridiagonal

    @jax.jit
    @jax.vmap
    def eigvalsrs_tridiagonal_jax(a, b):
        """Reference eigenvalues from jax.scipy.linalg.eigh_tridiagonal."""
        return eigh_tridiagonal(a, b, eigvals_only=True)

    # block_until_ready makes the timing cover the actual device work, since
    # JAX dispatches computations asynchronously.
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_jax(a, b)))
    timer(lambda: jax.block_until_ready(eigvalsrs_tridiagonal_bisection(a, b)))
    print(eigvalsrs_tridiagonal_bisection(a, b) - eigvalsrs_tridiagonal_jax(a, b))
Beta Was this translation helpful? Give feedback.
-
Surprisingly, `q = -b2i / q + ai - x` is 10% faster than `q = ai - x - b2i / q`. |
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
Uh oh!
There was an error while loading. Please reload this page.
-
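My use case (GPU backend):
Data: [details not preserved in this copy]
Originally I wrote [code snippet not preserved; presumably the version 1 update `q = ai - x - b2i / q`] — time cost: 0.7506 s.
Then I changed it to [code snippet not preserved; presumably version 2, `q = ai - x - b2i * lax.reciprocal(q)`] — time cost: 0.7276 s.
Finally I used [code snippet not preserved; presumably version 3, which carries the reciprocal of the pivot] — time cost: 0.6984 s (for comparison, `jax.scipy.linalg.eigh_tridiagonal` takes 1.040 s).
I originally thought that `jit` could do such a simple optimization for me, but it actually does not. With `unroll=32`, version 2 is faster than version 1, which in turn is faster than version 3. Really strange!
Beta Was this translation helpful? Give feedback.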
All reactions