Calls to decode slow down over time during parallel generation #3629
-
Is this expected? I assume it is, but I just thought I'd ask. The effect doesn't seem noticeable with single-sequence generation, but the total sequence lengths involved there are also a lot smaller. For example, running a 70B model with 8 parallel sequences and a shared prompt of 1,230 tokens:
Output from my hacked version of … (per-call timing output omitted)
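For anyone who wants to reproduce this kind of measurement, here is a minimal sketch in Python of timing each decode call. `decode_fn` is a hypothetical stand-in for whatever runs one batched decode step; it is not part of any real API.

```python
import time

def time_decode_steps(decode_fn, n_steps):
    """Minimal sketch: time successive decode calls to observe the
    slowdown as the KV cache fills up. decode_fn is a hypothetical
    callable that runs one batched decode step (e.g. one new token
    for each of the 8 parallel sequences)."""
    timings = []
    for step in range(n_steps):
        t0 = time.perf_counter()
        decode_fn()
        timings.append((step, time.perf_counter() - t0))
    return timings
```

Plotting step number against elapsed time should show roughly linear growth, which matches the explanation in the reply below.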
-
If the prompt is shared then the context is just `1230 + 40*8 = 1550` tokens: the 1,230 shared prompt tokens plus, at that point, 40 generated tokens in each of the 8 sequences.
The computation increases with the sequence length since the KQ and KQV operations grow with the number of tokens.
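To make that growth concrete, here is a back-of-the-envelope sketch of how the per-step KQ work scales with the number of cached tokens. The head count and head dimension are assumed for illustration, not taken from the run above.

```python
# Back-of-the-envelope sketch: per-step KQ work grows linearly with
# the number of tokens already in the KV cache. Head counts and
# dimensions are illustrative assumptions.
def kq_flops_per_step(n_kv, n_batch=8, n_head=64, head_dim=128):
    # For each of the n_batch new tokens, every head takes a dot
    # product of length head_dim against all n_kv cached K vectors.
    return 2 * n_batch * n_head * head_dim * n_kv

for n_kv in (1230, 1550, 3000, 6000):
    print(f"n_kv={n_kv}: {kq_flops_per_step(n_kv) / 1e9:.2f} GFLOPs")
```

Doubling the cached context doubles the KQ (and, analogously, KQV) work per decode call, which is why generation slows down as the sequences get longer.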
The current KV cache implementation computes KQ and KQV for all sequences in the batch in a single pass, masking the attention accordingly. The benefit of this is that we avoid the overhead of splitting the batch into separate attention streams and launching multiple kernels. The drawback is that we go through some extra cross-sequence computations that are technically not needed and are thrown away by the mask.
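To illustrate the trade-off, here is a toy NumPy sketch, not the actual ggml code: a single attention head, with random slot ownership standing in for a real shared prompt. The new tokens from all sequences attend over the whole unified cache in one KQ matmul, and the mask discards the cross-sequence scores that were computed anyway.

```python
import numpy as np

# Toy sketch of a unified KV cache with attention masking.
# n_seq new tokens (one per sequence) attend over all n_kv cached
# tokens in a single KQ matmul; cross-sequence scores are computed
# anyway and then masked out.
n_seq, n_kv, d = 8, 1550, 128
rng = np.random.default_rng(0)

q = rng.standard_normal((n_seq, d))          # one new token per sequence
k = rng.standard_normal((n_kv, d))           # unified cache for all sequences
v = rng.standard_normal((n_kv, d))
seq_id = rng.integers(0, n_seq, size=n_kv)   # which sequence owns each slot

kq = q @ k.T / np.sqrt(d)                    # (n_seq, n_kv): one big matmul
mask = seq_id[None, :] == np.arange(n_seq)[:, None]  # own-sequence slots only
kq = np.where(mask, kq, -np.inf)             # masked entries are thrown away

w = np.exp(kq - kq.max(axis=-1, keepdims=True))      # softmax over the cache
w /= w.sum(axis=-1, keepdims=True)
kqv = w @ v                                  # (n_seq, d) attention output
```

The single `q @ k.T` is the one-pass benefit; the entries set to `-inf` are the wasted cross-sequence work that the mask throws away.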