Porting LLaMA 2 to JAX, issues with dynamic slices 😢 #17315
Unanswered
Artur-Galstyan asked this question in Q&A
You might find this reference helpful:
Hi everyone!
I'm porting LLaMA 2 to JAX using Equinox, but I'm having trouble with dynamic slices and the fact that the model keeps re-JITting.
Consider this:
If I keep `start_pos` the same, everything is blazingly fast (0.009 seconds per run). Nice. But as soon as it changes, a run takes close to 4 seconds, because all the attention layers get recompiled.
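The original snippet isn't shown, so here is a minimal sketch of the likely mechanism, assuming `start_pos` is used as a Python-level slice bound and therefore has to be a static argument to `jax.jit`. `attend_prefix` is a hypothetical stand-in for an attention layer. `jax.jit` caches one compiled version per combination of input shapes and static argument values, so every new `start_pos` forces a fresh trace and compile:

```python
import jax
import jax.numpy as jnp

# Python-side counter: the body of a jitted function only runs in Python
# while JAX traces it, so this counts traces (and compilations), not calls.
trace_count = 0

def attend_prefix(keys, start_pos):
    # Hypothetical stand-in for an attention layer that slices the cache.
    global trace_count
    trace_count += 1
    # Slicing with a Python int creates an intermediate array whose shape
    # depends on start_pos, so start_pos must be static (known at trace time).
    return keys[: start_pos + 1].sum(axis=0)

attend_prefix_jit = jax.jit(attend_prefix, static_argnums=1)

keys = jnp.ones((16, 4))
for pos in [3, 3, 3, 4, 5]:
    attend_prefix_jit(keys, pos)

print(trace_count)  # 3: one trace per distinct start_pos (3, 4, 5)
```

Repeated calls with the same `start_pos` hit the cache, which matches the fast case; each new value pays the full compile cost, which matches the slow case.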
Here's what I mean (for brevity, only the relevant snippet):
The problem is the `keys` array (and `values` too, but since they have the same shapes, I'll only refer to `keys`). As `start_pos` grows, so does the `keys` array, e.g.:
etc.
I do, however, know the maximum size it can have, but I'm not sure how to use that information to make everything fast. Does anyone have an idea?
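One common way to use a known maximum length (a sketch, not necessarily what the reply's reference describes) is to preallocate the cache at full size, write new keys/values into it with `jax.lax.dynamic_update_slice` at a traced `start_pos`, and mask out every slot that is in the future or not yet filled. The shapes then never change, so the function compiles once. This uses plain JAX rather than Equinox, assumes a single sequence with no batch dimension, and `MAX_SEQ_LEN`, `N_HEADS`, `HEAD_DIM` and `attend_with_cache` are hypothetical names:

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration; a real model takes them from its config.
MAX_SEQ_LEN, N_HEADS, HEAD_DIM = 32, 2, 8

trace_count = 0  # counts traces, as in a plain Python side effect

@jax.jit
def attend_with_cache(cache_k, cache_v, q, xk, xv, start_pos):
    """One attention step against a fixed-size KV cache (hypothetical helper).

    cache_k, cache_v: (MAX_SEQ_LEN, N_HEADS, HEAD_DIM), allocated once
    q, xk, xv:        (seqlen, N_HEADS, HEAD_DIM) for the new tokens
    start_pos:        traced integer scalar, so new values do not retrace
    """
    global trace_count
    trace_count += 1
    # Write the new keys/values at start_pos; the cache shape never changes.
    cache_k = jax.lax.dynamic_update_slice(cache_k, xk, (start_pos, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, xv, (start_pos, 0, 0))

    # Scores against all MAX_SEQ_LEN slots: (N_HEADS, seqlen, MAX_SEQ_LEN).
    scores = jnp.einsum("qhd,khd->hqk", q, cache_k) / jnp.sqrt(HEAD_DIM)

    # Query i sits at absolute position start_pos + i and may only attend to
    # key positions <= that, which hides both future and unfilled slots.
    q_pos = start_pos + jnp.arange(q.shape[0])
    k_pos = jnp.arange(MAX_SEQ_LEN)
    mask = k_pos[None, :] <= q_pos[:, None]  # (seqlen, MAX_SEQ_LEN)
    scores = jnp.where(mask[None], scores, -jnp.inf)

    probs = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("hqk,khd->qhd", probs, cache_v)
    return out, cache_k, cache_v

# Usage: decode one token at a time; start_pos changes, shapes do not.
key = jax.random.PRNGKey(0)
cache_k = jnp.zeros((MAX_SEQ_LEN, N_HEADS, HEAD_DIM))
cache_v = jnp.zeros_like(cache_k)
for pos in range(5):
    key, kq, kk, kv = jax.random.split(key, 4)
    q, xk, xv = (jax.random.normal(k, (1, N_HEADS, HEAD_DIM)) for k in (kq, kk, kv))
    out, cache_k, cache_v = attend_with_cache(cache_k, cache_v, q, xk, xv, pos)

print(trace_count)  # 1: compiled once, reused for every start_pos
```

The trade-off is that each step does attention work over all `MAX_SEQ_LEN` slots instead of only the filled ones, in exchange for a single compilation. A prompt of a different length (a different `seqlen`) would still trigger one extra compile.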