What you're describing is inconsistent with how `jax.lax.cond` works. Here's an example of microbenchmarks showing that the runtime of the jitted `cond` does depend on which branch is selected:

```python
import jax.numpy as jnp
from jax import nn
import jax

@jax.jit
def f1(x):
    return x / x.shape[2]

@jax.jit
def f2(x):
    temp = nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

@jax.jit
def choose_attention(alpha, x):
    return jax.lax.cond(alpha[0, 0, 0, 0], lambda _: f2(x), lambda _: f1(x), operand=None)

x = jnp.zeros((10, 10, 1000))

_ = f1(x)  # warm up the compilation cache before timing
%timeit f1(x).block_until_ready()
# 25.1 µs ± 784 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

_ = f2(x)
%timeit f2(x).block_until_ready()
# 119 µs ± 19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

alpha = jnp.zeros((1, 1, 1, 1))  # zero predicate -> false branch (f1)
_ = choose_attention(alpha, x)
%timeit choose_attention(alpha, x).block_until_ready()
# 26.9 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

alpha = jnp.ones((1, 1, 1, 1))   # nonzero predicate -> true branch (f2)
_ = choose_attention(alpha, x)
%timeit choose_attention(alpha, x).block_until_ready()
# 121 µs ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
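If your JAX version complains about the `operand=None` keyword, the same check can be written with the positional-operand form of `jax.lax.cond`; this is a minimal sketch added for reference, not code from the original reply:

```python
import jax
import jax.numpy as jnp
from jax import nn

def f1(x):
    return x / x.shape[2]

def f2(x):
    temp = nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

@jax.jit
def choose_attention(alpha, x):
    # Both branches are traced and compiled, but only the selected one is
    # executed at run time, which is why the measured cost above follows
    # whichever branch the predicate picks.
    return jax.lax.cond(alpha[0, 0, 0, 0] != 0, f2, f1, x)

x = jnp.zeros((10, 10, 1000))
print(choose_attention(jnp.zeros((1, 1, 1, 1)), x).shape)  # f1 branch
print(choose_attention(jnp.ones((1, 1, 1, 1)), x).shape)   # f2 branch
```

Either form should show per-call runtimes that track the chosen branch, matching the timings above.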
-
I am looping through each attention head and applying either function f1 or f2, depending on the value of the parameter self.alpha. f1 is slower than f2, but my implementation always gives the same runtime when I run it with different values of alpha.
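For reference, here is a hedged sketch of the pattern described above, looping over heads and selecting f1 or f2 per head with `jax.lax.cond`; the head dimension, shapes, and the way alpha is indexed are assumptions, not code from the question:

```python
import jax
import jax.numpy as jnp
from jax import nn

def f1(x):
    return x / x.shape[2]

def f2(x):
    temp = nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

@jax.jit
def per_head(alpha, x):
    # Assumed shapes: x is (num_heads, 10, 10, 1000); alpha is (num_heads,),
    # one gate per head. The Python loop is unrolled during tracing; each cond
    # still executes only its selected branch at run time, so the measured
    # runtime should vary with alpha.
    outputs = []
    for h in range(x.shape[0]):
        outputs.append(jax.lax.cond(alpha[h] != 0, f2, f1, x[h]))
    return jnp.stack(outputs)

x = jnp.zeros((4, 10, 10, 1000))
print(per_head(jnp.array([0.0, 1.0, 0.0, 1.0]), x).shape)  # (4, 10, 10, 1000)
```

With this structure, the per-call cost should track how many heads take the f2 branch.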