
New to JAX: how to optimize this extremely slow Gaussian splat render? #26120

Answered by jakevdp
shi-yan asked this question in Q&A

You identified the main issue here: any time you find yourself writing for loops over array entries, your code is likely to be slow (this is true in JAX as it is in NumPy). The solution generally is to re-express your computations in terms of vectorized operations. In the simplest cases, it means that instead of something like this:

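```python
import jax.numpy as jnp
from jax import Array

def f_loop(x: Array) -> Array:
  # One small JAX operation per element: slow, many separate dispatches.
  y = []
  for i in range(len(x)):
    y.append(jnp.exp(0.5 * x[i]))
  return jnp.array(y)
```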

you could write something like this (taking advantage of the fact that jax.numpy operations work element-wise):

```python
def f_vectorized(x: Array) -> Array:
  # A single element-wise operation over the whole array.
  return jnp.exp(0.5 * x)
```
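When a per-element function is awkward to rewrite by hand, `jax.vmap` can produce the batched version automatically, and `jax.jit` compiles it into a single computation. A minimal sketch, where `f_scalar` is a hypothetical name for the loop body of `f_loop` written for one element:

```python
import jax
import jax.numpy as jnp
from jax import Array


def f_scalar(xi: Array) -> Array:
    # Written for a single element, like the loop body in f_loop.
    return jnp.exp(0.5 * xi)


# jax.vmap maps f_scalar over the leading axis as one batched operation;
# jax.jit then compiles that batched operation.
f_vmapped = jax.jit(jax.vmap(f_scalar))

x = jnp.linspace(-1.0, 1.0, 5)
y = f_vmapped(x)  # same values as f_vectorized(x)
```

This is useful when the per-element logic is long; for simple element-wise math like this, writing `jnp.exp(0.5 * x)` directly is just as fast and easier to read.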

Now, vectorized operations don't admit the kind of early-exit con…
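Since a vectorized computation cannot skip work for individual elements, one common substitute (sketched here; not necessarily what the truncated answer goes on to describe) is to compute every contribution and mask out the unwanted ones with `jnp.where`. Applied to the question's Gaussian-splatting setting, a minimal sketch could look like the following. It assumes isotropic 2D Gaussians already projected to pixel space and sorted front to back, a 3-sigma cutoff, and front-to-back alpha compositing; `render_splats` and its parameters are hypothetical names, not taken from the original post.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("height", "width"))
def render_splats(means, sigmas, colors, opacities, height, width):
    """Hypothetical vectorized renderer for N isotropic 2D Gaussians.

    means:     (N, 2) pixel-space centers (x, y), sorted front to back
    sigmas:    (N,)   standard deviations in pixels
    colors:    (N, 3) RGB colors
    opacities: (N,)   per-Gaussian opacity in [0, 1]
    Returns an (height, width, 3) image indexed as image[y, x].
    """
    xs, ys = jnp.meshgrid(jnp.arange(width, dtype=jnp.float32),
                          jnp.arange(height, dtype=jnp.float32))  # (H, W) each
    pix = jnp.stack([xs, ys], axis=-1)                            # (H, W, 2)

    # Broadcast every pixel against every Gaussian: (H, W, N, 2).
    d = pix[:, :, None, :] - means[None, None, :, :]
    r2 = jnp.sum(d * d, axis=-1)                                  # (H, W, N)
    weight = jnp.exp(-0.5 * r2 / sigmas**2)                       # (H, W, N)

    # Instead of an early `continue` for pixels beyond 3 sigma,
    # compute everything and zero out those contributions.
    inside = r2 < (3.0 * sigmas) ** 2
    alpha = jnp.where(inside, opacities * weight, 0.0)            # (H, W, N)

    # Transmittance before Gaussian i is prod_{j < i} (1 - alpha_j):
    # an exclusive cumulative product along the Gaussian axis.
    trans = jnp.cumprod(1.0 - alpha, axis=-1)
    trans = jnp.concatenate([jnp.ones_like(trans[..., :1]), trans[..., :-1]], axis=-1)

    contrib = (alpha * trans)[..., None] * colors[None, None, :, :]  # (H, W, N, 3)
    return jnp.sum(contrib, axis=2)                                  # (H, W, 3)


image = render_splats(
    means=jnp.array([[8.0, 8.0], [8.0, 8.0]]),
    sigmas=jnp.array([1.0, 1.0]),
    colors=jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    opacities=jnp.array([1.0, 1.0]),
    height=16,
    width=16,
)
```

The design trades memory for speed: the intermediate arrays are `H × W × N`, so for large images or many Gaussians you would likely need to process pixels or Gaussians in chunks (for example, with a tile-based binning step) rather than broadcasting everything at once.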

Replies: 1 comment

Answer selected by shi-yan
Category: Q&A