I'm implementing Givens rotations, which require a lot of matrix multiplications of same-shape matrices. Here is the code:

```python
from functools import reduce
import time

import jax
import jax.numpy as jnp


def gr() -> None:
    """Run Givens rotations and time two ways of multiplying them out."""
    size = 30  # change size here: 30, 40, 50, 90
    # batch of 4; one angle per upper-triangular pair,
    # e.g. size 90 gives 90 * 89 // 2 = 4005 angles
    thetas = jax.random.uniform(jax.random.key(42),
                                (4, size * (size - 1) // 2))

    # batched Givens rotations
    ix, iy = jnp.triu_indices(size, 1)
    g = jnp.eye(size) * jnp.ones_like(thetas)[..., None, None]
    c, s = jnp.cos(thetas), jnp.sin(thetas)
    cs = jnp.stack([c, -s], axis=1)
    sc = jnp.stack([s, c], axis=1)
    g = (g.at[..., jnp.arange(len(ix)), ix, [ix, iy]].set(cs)
          .at[..., jnp.arange(len(ix)), iy, [ix, iy]].set(sc))
    # now g is (4, 4005, 90, 90) for size 90

    # multiply out one batch of rotations with multi_dot
    t1 = time.time()
    multi_dot = jax.jit(jnp.linalg.multi_dot)
    ng = multi_dot(g[0]).block_until_ready()
    t2 = time.time()
    print("multi dot 1st run:", t2 - t1)

    t1 = time.time()
    ng = multi_dot(g[1]).block_until_ready()
    t2 = time.time()
    print("multi dot:", t2 - t1)

    # same product as a left fold of jnp.dot
    t1 = time.time()

    @jax.jit
    def rdot(xx: jax.Array) -> jax.Array:
        return reduce(jnp.dot, xx)

    ng = rdot(g[0]).block_until_ready()
    t2 = time.time()
    print("reduce dot 1st run:", t2 - t1)

    t1 = time.time()
    ng = rdot(g[1]).block_until_ready()
    t2 = time.time()
    print("reduce dot:", t2 - t1)
```

I have run the function with several sizes; timings are in seconds:

# 30
multi dot 1st run: 6.373845100402832
multi dot: 0.0009920597076416016
reduce dot 1st run: 1.6028120517730713
reduce dot: 0.0012731552124023438
# 40
multi dot 1st run: 31.00104808807373
multi dot: 0.002560138702392578
reduce dot 1st run: 3.654721975326538
reduce dot: 0.0026671886444091797
# 50
multi dot 1st run: 117.09066390991211
multi dot: 0.0068149566650390625
reduce dot 1st run: 8.86375379562378
reduce dot: 0.0063228607177734375
# 90 (multi dot skipped; its compilation took too long)
reduce dot 1st run: 93.46390914916992
reduce dot: 0.10265707969665527
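As a side note, the batched construction above can be sanity-checked at a small size: every Givens rotation is orthogonal, so each constructed slice must satisfy GᵀG = I. A minimal unbatched sketch (illustrative only; `size = 4` and the PRNG key are arbitrary choices, not part of the timing question):

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: build one Givens rotation per upper-triangular
# index pair (no batch dimension), then check each matrix is orthogonal.
size = 4
ix, iy = jnp.triu_indices(size, 1)          # pairs (i, j) with i < j
n = len(ix)                                 # size * (size - 1) // 2 = 6
thetas = jax.random.uniform(jax.random.key(0), (n,))
g = jnp.eye(size) * jnp.ones((n, 1, 1))     # (n, size, size) stack of identities
c, s = jnp.cos(thetas), jnp.sin(thetas)
g = g.at[jnp.arange(n), ix, [ix, iy]].set(jnp.stack([c, -s]))  # row i: [c, -s]
g = g.at[jnp.arange(n), iy, [ix, iy]].set(jnp.stack([s, c]))   # row j: [s,  c]
# every Givens rotation is orthogonal: G^T @ G == I
all_orthogonal = all(
    jnp.allclose(g[k].T @ g[k], jnp.eye(size), atol=1e-5) for k in range(n)
)
print(all_orthogonal)  # prints True
```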
Replies: 1 comment 1 reply
The slow compilation times here come from the fact that in either case your operation relies on Python control flow to dispatch a large number of JAX operations: Python loops are flattened, leading to large programs, and large programs have slow compilation (compilation time grows roughly as the number of operations squared). In situations like this, you can often do better by re-expressing your loop in terms of a control flow primitive like `lax.scan`. For example:

```python
from jax import lax

@jax.jit
def scan_dot(xx: jax.Array) -> jax.Array:
    # carry the running product; the traced program contains a single dot
    return lax.scan(lambda y, x: (jnp.dot(y, x), None),
                    jnp.eye(xx.shape[-1], dtype=xx.dtype), xx)[0]
```
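A quick way to convince yourself that the scanned product matches the unrolled fold (a small standalone sketch; the shapes and PRNG key are arbitrary choices):

```python
from functools import reduce

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def scan_dot(xx: jax.Array) -> jax.Array:
    # One jnp.dot in the traced program, repeated by lax.scan at run time,
    # so compilation cost no longer grows with the number of matrices.
    return lax.scan(lambda y, x: (jnp.dot(y, x), None),
                    jnp.eye(xx.shape[-1], dtype=xx.dtype), xx)[0]


# 10 random 8x8 matrices, scaled down to keep the product well-conditioned
mats = jax.random.normal(jax.random.key(0), (10, 8, 8)) / 8.0
unrolled = reduce(jnp.dot, mats)   # left fold, same order as the scan
scanned = scan_dot(mats)
agree = bool(jnp.allclose(scanned, unrolled, atol=1e-6))
print(agree)
```

Both compute the same left-to-right product, so they agree to floating-point precision; only the compiled program size differs.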
`multi_dot` is not much help here – its main purpose is to find an optimal ordering of a sequence of dot products given the sizes of the inputs. In your case, all inputs are the same size, so the order of operations doesn't matter.
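For contrast, a small sketch of the case `multi_dot` is designed for, where the input shapes differ and the association order changes the cost (the shapes below are arbitrary illustrative choices):

```python
import jax
import jax.numpy as jnp

ka, kb, kc = jax.random.split(jax.random.key(0), 3)
A = jax.random.normal(ka, (10, 100))
B = jax.random.normal(kb, (100, 5))
C = jax.random.normal(kc, (5, 50))
# (A @ B) @ C costs 10*100*5 + 10*5*50   =  7,500 scalar multiplies
# A @ (B @ C) costs 100*5*50 + 10*100*50 = 75,000 scalar multiplies
# multi_dot picks the cheaper association; the value is the same either way.
out = jnp.linalg.multi_dot([A, B, C])
same = bool(jnp.allclose(out, (A @ B) @ C, atol=1e-3))
print(same)
```

With equal square shapes, as in your Givens product, every association costs the same, so this search buys nothing and only adds trace-time work.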