[CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching #262
base: main
Conversation
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
👋 Hi! Thank you for contributing to vLLM support on Spyre.
Or this can be done with
Now you are good to go 🚀
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Force-pushed from a7e7ae9 to 49d92f5
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Great news: this runs on Spyre 🎉 I just ran
cc: @tdoublep @JRosenkranz @joerunde @nikolaospapandreou @sducouedic
bot:test
bot:test
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
bot:test
6/7 tests passed on the Spyre card! Looks like the failure is a known issue unrelated to this PR. 🥳
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
bot:test
bot:test
bot:test
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
@@ -139,4 +139,4 @@
    print("-----------------------------------")

if not any_differ:
    print("\nAll results match!\n")
    print("\nAll results match!\n")
nit (should revert changes to this file)
print("\nAll results match!\n") | |
print("\nAll results match!\n") | |
Worth adding some debug logs to the optimizations?
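For illustration only, here is a minimal sketch of what such a debug log could look like; the helper name, its arguments, and the logger setup are assumptions, not code from this PR:

```python
import logging

logger = logging.getLogger(__name__)

def log_block_padding(req_id: str, n_blocks: int, n_req_blocks: int) -> None:
    # hypothetical helper: report how many placeholder blocks (reusing block id 0)
    # were prepended to this request's block table
    n_pad_blocks = n_blocks - n_req_blocks
    if n_pad_blocks > 0:
        logger.debug(
            "request %s: padded block table with %d placeholder block(s) (block id 0)",
            req_id, n_pad_blocks)
```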
# filling block table with padding blocks (reusing id 0)
blocks = self.req_ids2blocks[req_id].copy()
for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
    blocks.appendleft(0)
Perhaps this is an incredibly stupid question, but why is it ok to use block id 0? Does it make a difference whether it is free (i.e. it's in self.block_pool) or not?
Based on Josh's comment in an internal issue,
"The only requirement when padding this is that we choose a block ID that exists in the pool of allotted block ids at server start (this way it will map to a real location in the memory space). In this case, when performing paged attention compute, the placeholder block will be part of compute (we will still take a performance hit during decode - at least until heterogeneous tkv is available), but will not take up any extra space as part of the block allotment at the time of server start."
IIUC I think that's why block 0 makes sense?
Although yeah I'm not sure what will happen if block 0 is actually being used by another request?
I managed to create a scenario where the block table looks like this for 2 requests:
block table: [deque([0, 2]), deque([0, 4])]
From what I can see, the output is still correct, although anyone looking at the block table could be confused.
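For context, a small self-contained sketch (request names and block counts are made up for illustration) of how two requests can end up sharing placeholder block 0, mirroring the appendleft(0) loop quoted above:

```python
from collections import deque

# hypothetical per-request block tables before padding (real block ids 2 and 4)
req_ids2blocks = {"req-a": deque([2]), "req-b": deque([4])}

def padded_block_table(blocks: deque, n_blocks: int) -> deque:
    # left-pad with placeholder block id 0 until the table holds n_blocks entries
    padded = blocks.copy()
    while len(padded) < n_blocks:
        padded.appendleft(0)
    return padded

# both requests need 2 blocks to line up with the batch, so both reuse block 0
block_table = [padded_block_table(b, 2) for b in req_ids2blocks.values()]
print(block_table)  # [deque([0, 2]), deque([0, 4])]
```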
@@ -884,8 +892,10 @@ def _prepare_prompt(
# applies left padding to align with tkv of current decode batch
# and right padding to align with the next block boundary
input_tokens, position_ids, mask =\
Still processing the changes, but wondering if the comment above needs rewording...
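One possible reading of that comment, as a hedged sketch; the block size, the variable names, and the exact split between left and right padding are assumptions for illustration, not the PR's implementation:

```python
BLOCK_SIZE = 64  # assumed block size, purely for illustration

def padding_lengths(prompt_len: int, tkv: int, block_size: int = BLOCK_SIZE):
    # right padding: extend the prompt to the next block boundary
    padded_prompt_len = -(-prompt_len // block_size) * block_size
    right_pad = padded_prompt_len - prompt_len
    # left padding: the remaining distance to the decode batch's tkv; with this PR,
    # whole left-padding blocks become placeholder entries (block id 0) in the
    # block table rather than being part of the prefill compute
    left_pad = max(tkv - padded_prompt_len, 0)
    return left_pad, right_pad

print(padding_lengths(prompt_len=70, tkv=256))  # (128, 58)
```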
[CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching
Implement an optimization idea by @JRosenkranz: run prefill only up to the next multiple of the block size and then, during decode, pad the block table with a (valid) block id. This reduces prefill compute and does not waste any valid block ids when whole blocks are padded to make the tkv homogeneous.

Solves #255
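As a rough illustration of the compute saving described above (the block size and token counts are assumed, not taken from the PR):

```python
BLOCK_SIZE = 64  # assumed block size

def prefill_tokens(prompt_len: int, tkv: int, optimized: bool) -> int:
    # without the optimization, the prompt is left-padded all the way to the
    # batch's tkv before prefill; with it, prefill stops at the next block
    # boundary and the gap is covered by placeholder blocks (block id 0)
    next_boundary = -(-prompt_len // BLOCK_SIZE) * BLOCK_SIZE
    return next_boundary if optimized else tkv

print(prefill_tokens(70, 256, optimized=False))  # 256 tokens of prefill compute
print(prefill_tokens(70, 256, optimized=True))   # 128 tokens of prefill compute
```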