vmapping pad with variable argument but static output size #13371

Answered by gnecula
stefano-1981 asked this question in General

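Can you use lax.dynamic_update_slice instead of pad? Its start indices may be traced values while the output keeps the static shape of the base array, so it works under vmap. Something like this: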

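import jax
import jax.numpy as jnp
from jax import lax

def my_pad(m, translation):
    # Static 1000x1000 float32 canvas (m must also be float32); m is copied in at the given offset.
    base = jnp.zeros((1000, 1000), dtype=jnp.float32)
    return lax.dynamic_update_slice(base, m, (translation[0], translation[1]))

out = jax.vmap(my_pad)(images, translation)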
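A self-contained, scaled-down version of the suggestion above, so the result is easy to inspect. The 6x6 canvas, the three 2x2 images of ones, and the offsets are made-up stand-ins (not from the question) for the real 1000x1000 canvas and image batch. As an assumption, the canvas dtype follows m.dtype instead of being hard-coded to float32, since dynamic_update_slice needs both arrays to share a dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical small canvas standing in for the (1000, 1000) one above.
CANVAS = (6, 6)

def my_pad(m, translation):
    # The output shape is fixed by CANVAS, whatever the offset is.
    # Using m.dtype keeps the canvas and image dtypes equal, as
    # dynamic_update_slice requires.
    base = jnp.zeros(CANVAS, dtype=m.dtype)
    # Copy m into base with its top-left corner at the (row, col) offset.
    # The start indices may be traced values, so this works under vmap.
    return lax.dynamic_update_slice(base, m, (translation[0], translation[1]))

# Hypothetical batch: three 2x2 images of ones, each with its own offset.
images = jnp.ones((3, 2, 2), dtype=jnp.float32)
translation = jnp.array([[0, 0], [1, 3], [5, 5]], dtype=jnp.int32)

out = jax.vmap(my_pad)(images, translation)
print(out.shape)  # (3, 6, 6)

# Every shape is static, so the batched function also compiles under jit.
out_jit = jax.jit(jax.vmap(my_pad))(images, translation)
```

Note that lax.dynamic_update_slice clamps its start indices so the update always fits inside the operand: the third image, requested at (5, 5), lands at (4, 4) rather than being cropped at the border. If cropping is the behavior you want, that case would need separate handling. The fact that the function runs under jax.jit with static shapes is a reasonable sign before attempting a jax2tf conversion.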

If this works for you in JAX, then the next step is to see if this can be converted with jax2tf + tflite.

Answer selected by stefano-1981