vmap seems incompatible with CNN (haiku) #17894
Replies: 2 comments
-
That sounds like an XLA bug: it had trouble compiling a particular convolution. If this reproduces with up-to-date jax and jaxlib, please file a bug.
-
It turns out this warning message is actually caused by running out of memory (OOM).
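Because the warning turned out to be an out-of-memory symptom, one possible mitigation is to vmap over smaller chunks and run those chunks sequentially with `jax.lax.map`, so fewer batched instances are in flight at once. Nobody in this thread proposed this; it is a minimal sketch under stated assumptions. `chunked_vmap` is a hypothetical helper, and it assumes the mapped axis length is divisible by `chunk_size` and that `f` returns a single array.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, xs, chunk_size):
    """Hypothetical helper: vmap `f` over axis 0 of `xs`, `chunk_size` rows at a time.

    jax.lax.map runs the chunks one after another, so typically only one
    chunk's batched intermediates need to be live at once.
    Assumes xs.shape[0] is divisible by chunk_size and f returns one array.
    """
    n = xs.shape[0]
    if n % chunk_size != 0:
        raise ValueError("xs.shape[0] must be divisible by chunk_size")
    # (n, ...) -> (n // chunk_size, chunk_size, ...)
    chunks = xs.reshape(n // chunk_size, chunk_size, *xs.shape[1:])
    # Sequential loop over chunks, vectorized within each chunk.
    out = jax.lax.map(jax.vmap(f), chunks)  # (n // chunk_size, chunk_size, ...)
    # Merge the chunk axes back into a single leading axis of length n.
    return out.reshape(n, *out.shape[2:])


xs = jnp.arange(12.0).reshape(6, 2)
f = lambda row: jnp.sin(row).sum()
print(jnp.allclose(chunked_vmap(f, xs, chunk_size=2), jax.vmap(f)(xs)))
```

Smaller chunks lower peak memory at the cost of less parallelism per step; for the CNN case, `f` would be the per-instance model function.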
-
I am trying to vmap a 2D CNN defined with Haiku. Every time, JAX raises a warning that says
followed by
When I remove the vmap, the warning disappears. I know I could fold every other axis into the batch dimension to avoid using vmap; however, it would be much more convenient for me to keep those axes separate. I wonder whether this is a bug in JAX or Haiku.
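The two options described above can be sketched in plain JAX, with `jax.lax.conv_general_dilated` standing in for the Haiku model. The tiny `conv_net`, its parameter shapes, and the input sizes are assumptions chosen for illustration. Option A keeps the extra leading axis and vmaps over it; option B folds that axis into the batch dimension and unfolds it afterwards. Both should yield the same values.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv_net(params, x):
    """Hypothetical stand-in CNN: one 3x3 conv + ReLU + global average pool.

    x: (N, H, W, C) in NHWC layout; returns (N, C_out).
    """
    y = lax.conv_general_dilated(
        x,
        params["w"],
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    y = jax.nn.relu(y + params["b"])
    return y.mean(axis=(1, 2))


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w": 0.1 * jax.random.normal(key_w, (3, 3, 3, 8)),  # HWIO: 3x3 kernel, 3 -> 8 channels
    "b": jnp.zeros((8,)),
}
# Extra leading axis E=4 on top of the batch axis N=2: shape (E, N, H, W, C).
x = jax.random.normal(key_x, (4, 2, 16, 16, 3))

# Option A: keep the extra axis and vmap over it (params are shared).
vmapped = jax.vmap(conv_net, in_axes=(None, 0))(params, x)

# Option B: fold E into N, run once, then restore the (E, N, ...) layout.
folded = conv_net(params, x.reshape(-1, *x.shape[2:]))
folded = folded.reshape(x.shape[0], x.shape[1], -1)

print(vmapped.shape, jnp.allclose(vmapped, folded, atol=1e-5))
```

With Haiku, the vmapped function would typically be the `apply` of an `hk.transform`ed model, e.g. `jax.vmap(model.apply, in_axes=(None, None, 0))` for arguments `(params, rng, x)`.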
Thanks in advance!