
Multiple indexing vs Slicing #30118

Answered by jakevdp
kmg3821 asked this question in Q&A

In the simplest cases, slicing lowers to a `slice` HLO op, while indexing with an array of integers ("multiple indexing") lowers to a `gather` HLO op. You can see this by using `jax.make_jaxpr` to view the jaxpr representation of each operation:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((10,))

def f1(x):
  # Basic slice with static bounds: traces to a `slice` primitive.
  return x[:5]
print(jax.make_jaxpr(f1)(x))
```

```
{ lambda ; a:f32[10]. let
    b:f32[5] = slice[limit_indices=(5,) start_indices=(0,) strides=None] a
  in (b,) }
```
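The same check can be done programmatically instead of by reading the printed jaxpr. This is a minimal sketch, assuming that `jax.make_jaxpr` returns a `ClosedJaxpr` whose `.jaxpr.eqns` entries expose `primitive.name`; the helper `primitive_names` is hypothetical and only inspects top-level equations.

```python
import jax
import jax.numpy as jnp

def primitive_names(fn, *args):
    """Hypothetical helper: names of the top-level primitives in fn's jaxpr."""
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]

x = jnp.ones((10,))

def f1(x):
    # Static slice: expected to trace to a `slice` primitive, not a gather.
    return x[:5]

names = primitive_names(f1, x)
print(names)
```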
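```python
def f2(x):
  # Indexing with an integer array: traces to a `gather` primitive.
  return x[jnp.arange(5)]
print(jax.make_jaxpr(f2)(x))
```

```
{ lambda ; a:f32[10]. let
    b:i32[5] = iota[dimension=0 dtype=int32 shape=(5,) sharding=None]
    c:bool[5] = lt b 0
    d:i32[5] = add b 10
    e:i32[5] = select_n c b d
    f:i32[5…
```

The `lt`/`add`/`select_n` lines normalize negative indices (any index below 0 has the axis length, 10, added to it) before the indices reach the gather in the truncated remainder of the output.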
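To look one level below the jaxpr, a function can be lowered to StableHLO text with `jax.jit(f).lower(x).as_text()`. The sketch below redefines `f1` and `f2` so it runs on its own and only tests for the substring `gather`, since the exact op spelling in the text (e.g. `stablehlo.gather`) may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((10,))

def f1(x):
    return x[:5]             # basic slice

def f2(x):
    return x[jnp.arange(5)]  # integer-array ("multiple") indexing

# Lower each function to StableHLO text, before XLA's optimization passes run.
hlo1 = jax.jit(f1).lower(x).as_text()
hlo2 = jax.jit(f2).lower(x).as_text()

print("gather" in hlo1, "gather" in hlo2)

# Both select the first five elements, so the values agree.
print(jnp.array_equal(f1(x), f2(x)))
```

Whether XLA later rewrites such a gather into something cheaper happens after this stage and is not shown by `as_text()` on the lowered (unoptimized) module.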

Answer selected by kmg3821