vmapping pad with variable argument but static output size #13371
-
Hi there, I have a tensor of images, say, `BxNxN`, and a translation tensor `Bx2`. I would like to vmap the following function:

```python
import jax

def mypad(m, translation):
    N = m.shape[0]  # images are N x N
    out_shape = (1000, 1000)  # intended (static) output shape
    return jax.numpy.pad(m, ((translation[0], 1000 - N - translation[0]),
                             (translation[1], 1000 - N - translation[1])))

# images has shape [B, N, N] and translations has shape [B, 2]
out = jax.vmap(mypad)(images, translations)
```

Now, I think I understand that this does not work because the output of the `pad` operation is not static. However, this is a special case: the padding is variable, but the output shape of `mypad` is always 1000 x 1000, because the before and after pad widths along each axis always sum to `1000 - N`. As a workaround I've been using

Any help?
-
Can you use `dynamic_update_slice` instead of `pad`? Something like this:

If this works for you in JAX, then the next step is to see if this can be converted with jax2tf + tflite.
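A minimal sketch of the `dynamic_update_slice` idea, assuming integer translations and the 1000 x 1000 output from the question (`OUT`, `place`, and `batched_place` are hypothetical names chosen for illustration). Instead of computing traced pad widths, it writes each `N x N` image into a zero canvas of static shape at a traced start offset, which `jax.lax.dynamic_update_slice` accepts under `vmap`:

```python
import jax
import jax.numpy as jnp

OUT = 1000  # assumed static output side length, taken from the question

def place(m, translation):
    # Static-shape canvas of zeros, matching jnp.pad's default constant (zero) fill.
    canvas = jnp.zeros((OUT, OUT), dtype=m.dtype)
    # Write the N x N image with its top-left corner at (translation[0], translation[1]).
    # The start indices may be traced values, so this works under vmap,
    # and the output shape is always (OUT, OUT).
    return jax.lax.dynamic_update_slice(canvas, m, (translation[0], translation[1]))

batched_place = jax.vmap(place)

# Usage: B images of size N x N, one integer (row, col) offset per image.
B, N = 4, 100
images = (jnp.arange(B * N * N, dtype=jnp.float32) + 1).reshape(B, N, N)
translations = jnp.array([[0, 0], [10, 20], [450, 300], [900, 900]], dtype=jnp.int32)
out = batched_place(images, translations)
print(out.shape)  # (4, 1000, 1000)
```

One behavioral difference from the `pad` version: `dynamic_update_slice` clamps start indices so the update always fits, so a translation larger than `OUT - N` is silently clamped to `OUT - N` rather than producing an error from a negative pad width.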