Hello, I am quite new to JAX. I am trying to implement AlphaZero by myself, so I referred to the "search" part of the MCTX code from DeepMind. They use … I wonder why they don't use … In short, for …
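In the simplest cases, slicing will lower to a `slice` HLO, while multiple indexing will lower to a `gather` HLO. You can see this by using `jax.make_jaxpr` to view the jaxpr representation of the operation:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((10,))
i = jnp.arange(5)

def f1(x):
  return x[:5]

print(jax.make_jaxpr(f1)(x))

def f2(x):
  return x[jnp.arange(5)]

print(jax.make_jaxpr(f2)(x))
```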
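`jax.make_jaxpr` shows JAX's own primitives rather than the HLO itself. To look one level lower, here is a minimal sketch, assuming a JAX version recent enough that `jax.jit(f).lower(...).as_text()` is available (it returns the lowered StableHLO module as text): lower both versions and check which op each one emits.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((10,))

def f1(x):
  # Static slice: a single contiguous window of the input.
  return x[:5]

def f2(x):
  # Integer-array ("multiple") indexing.
  return x[jnp.arange(5)]

# Lower (trace and convert, but do not compile) each function,
# then get the resulting StableHLO module as a string.
hlo_slice = jax.jit(f1).lower(x).as_text()
hlo_gather = jax.jit(f2).lower(x).as_text()

print("gather" in hlo_slice)
print("gather" in hlo_gather)
```

With these two functions you would expect a `gather` op only in the second module.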
In general, I'd expect slicing to be faster than multiple indexing, because it gives the compiler stronger guarantees about the operation: a slice is always a contiguous chunk of the array, while an arbitrary sequence of indices in general requires a random-access pattern.
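To check this on your own hardware, one option is to time both forms under `jax.jit`. This is a rough sketch rather than a rigorous benchmark: the names `take_slice` and `take_gather`, the size `n`, and the repeat count are arbitrary illustrative choices, and the size of the gap (if any) will depend on the backend and the array size.

```python
import timeit

import jax
import jax.numpy as jnp

n = 1 << 20  # illustrative size, not tuned
x = jnp.arange(2 * n, dtype=jnp.float32)
idx = jnp.arange(n)  # contiguous values, but passed in as a runtime array

@jax.jit
def take_slice(x):
  # Static slice: the compiler knows the result is one contiguous window.
  return x[:n]

@jax.jit
def take_gather(x, idx):
  # Array indexing: the compiler only sees "some indices", so it emits a gather.
  return x[idx]

# Call each once first so compilation time is excluded from the timing.
y_slice = take_slice(x).block_until_ready()
y_gather = take_gather(x, idx).block_until_ready()

reps = 100
t_slice = timeit.timeit(lambda: take_slice(x).block_until_ready(), number=reps)
t_gather = timeit.timeit(lambda: take_gather(x, idx).block_until_ready(), number=reps)
print(f"slice:  {t_slice / reps * 1e6:.1f} us per call")
print(f"gather: {t_gather / reps * 1e6:.1f} us per call")
```

Passing `idx` as a runtime argument, instead of building `jnp.arange(n)` inside the function, keeps XLA from treating the indices as constants that it might be able to simplify, so the timing reflects an actual gather.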