To avoid materialising the output during staging, I wrote a version of `triu_indices`:

```python
from functools import partial

import numpy as np
from jax import jit, lax
import jax.numpy as jnp

@partial(jit, static_argnames=('n', 'k'))
def triu_indices(n: int, k: int = 0):
    assert n >= 0 and k >= 0
    N = max(n - k, 0)
    l = N * (N + 1) // 2
    iota = lax.iota(np.int32, N)
    idx = lax.cumsum(1 + iota, reverse=True)
    I = lax.cumsum(jnp.zeros(l, dtype=np.int32).at[idx].set(1))
    J = lax.iota(np.int32, l) - lax.cumsum(jnp.zeros(l, dtype=np.int32).at[idx].set(iota))
    return I, J + k
```
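The same cumsum construction can be sketched in plain NumPy (a self-contained illustration of the idea, not the JAX code itself), which makes it easy to check against `np.triu_indices`. Note that the JAX version relies on out-of-bounds scatter indices being dropped silently; in NumPy that has to be done with an explicit mask:

```python
import numpy as np

def triu_indices_cumsum(n: int, k: int = 0):
    """Row/col indices of the upper triangle via cumulative sums, no 2D temporaries."""
    N = max(n - k, 0)
    l = N * (N + 1) // 2               # number of upper-triangular entries
    iota = np.arange(N, dtype=np.int32)
    # Suffix sums of the row lengths mark where each new row begins.
    idx = np.cumsum((1 + iota)[::-1])[::-1]
    ones = np.zeros(l, dtype=np.int32)
    vals = np.zeros(l, dtype=np.int32)
    # Keep only in-bounds scatter targets (JAX drops out-of-bounds indices implicitly).
    mask = idx < l
    ones[idx[mask]] = 1
    vals[idx[mask]] = iota[mask]
    I = np.cumsum(ones)                               # row index bumps at each row start
    J = np.arange(l, dtype=np.int32) - np.cumsum(vals)  # column index resets per row
    return I, J + k
```

For example, `triu_indices_cumsum(3)` returns `([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])`, matching `np.triu_indices(3)`.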
Answered by **jakevdp** (Feb 7, 2023):
When I try your code on a Colab CPU runtime, I find that your JAX version is pretty fast. Can you say more about how you benchmarked this?

```python
_ = jax.block_until_ready(triu_indices(10))
%timeit jax.block_until_ready(triu_indices(10))
# 8.37 µs ± 79.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.triu_indices(10)
# 23.3 µs ± 4.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
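The same warm-up-then-time pattern can be reproduced outside IPython with the stdlib `timeit` module (a sketch of the methodology, not the original measurement; for the JAX version you would first call the function once to trigger compilation and wrap each timed call in `jax.block_until_ready`):

```python
import timeit

import numpy as np

# Time the NumPy baseline: average over many calls, as %timeit does.
n_calls = 10_000
total = timeit.timeit(lambda: np.triu_indices(10), number=n_calls)
print(f"np.triu_indices(10): {total / n_calls * 1e6:.2f} µs per call")
```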
Answer selected by soraros.