[CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching #262

Open · wants to merge 15 commits into main from ysc-homog-tkv-opt-joshua
Conversation

@yannicks1 (Collaborator) commented Jun 24, 2025:


Implements the optimization idea by @JRosenkranz: do prefill only up to the next multiple of the block size, then during decode pad the block table with a (valid) block id. This reduces prefill compute and does not waste any valid block ids when whole blocks are padded to make the tkv homogeneous.

solves #255
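
For readers skimming the PR, a minimal sketch of the prefill side of the idea (BLOCK_SIZE and the helper name are placeholders for illustration, not identifiers from this PR): instead of padding a new prompt all the way up to the decode batch's tkv, prefill only covers the prompt right-padded to the next block boundary, and the remaining whole blocks are later represented by padding block ids.

import math

BLOCK_SIZE = 64  # hypothetical block size, for illustration only

def prefill_padded_length(prompt_len: int) -> int:
    # Length actually run through prefill: the prompt right-padded
    # to the next multiple of the block size.
    return math.ceil(prompt_len / BLOCK_SIZE) * BLOCK_SIZE

# Example: a 70-token prompt joining a decode batch with tkv = 256.
# Prefill now covers 128 tokens instead of the full 256; the remaining
# two blocks are covered by padding block ids during decode.
print(prefill_padded_length(70))  # 128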

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure that your code passes all the linting checks, otherwise your PR can't be merged. To do so, first install the linting requirements, then run format.sh and commit the changes. This can be done with uv directly:

uv sync --frozen --group lint --active --inexact

Or this can be done with pip:

uv pip compile --group lint > requirements-lint.txt
pip install -r requirements-lint.txt
bash format.sh

Now you are good to go 🚀

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
@yannicks1 yannicks1 self-assigned this Jun 26, 2025
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
@yannicks1 yannicks1 force-pushed the ysc-homog-tkv-opt-joshua branch 2 times, most recently from a7e7ae9 to 49d92f5 Compare June 27, 2025 22:40
yannicks1 and others added 3 commits June 30, 2025 08:02
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
@yannicks1 (Collaborator, Author) commented:

Great news: this runs on Spyre 🎉

I just ran cb_spyre_inference.py, which (with the parameters on this branch) exercises all of the new functionality.

cc: @tdoublep @JRosenkranz @joerunde @nikolaospapandreou @sducouedic

@yannicks1 (Collaborator, Author) commented:

bot:test
TEST_FILE=tests/e2e/test_spyre_cb.py MARKERS="spyre"

@yannicks1 (Collaborator, Author) commented:

bot:test
TEST_FILE=tests/e2e/test_spyre_cb.py MARKERS="spyre"

yannicks1 and others added 4 commits July 10, 2025 09:42
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
@yannicks1 (Collaborator, Author) commented:

bot:test
TEST_FILE=tests/e2e/test_spyre_cb.py MARKERS="spyre"

@yannicks1 (Collaborator, Author) commented:

6/7 tests passed on the Spyre card! Looks like the failure is a known issue unrelated to this PR. 🥳

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
@yannicks1 (Collaborator, Author) commented:

bot:test
TEST_FILE=tests/e2e/test_spyre_cb.py MARKERS="spyre"

@yannicks1 (Collaborator, Author) commented:

bot:test
MARKERS="spyre"

@yannicks1 (Collaborator, Author) commented:

bot:test
TEST_FILE=tests/e2e/test_spyre_cb_scheduler_step.py MARKERS="spyre"

Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
@yannicks1 yannicks1 changed the title [do not merge][CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching [CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching Jul 23, 2025
@yannicks1 yannicks1 marked this pull request as ready for review July 23, 2025 16:24
@@ -139,4 +139,4 @@
print("-----------------------------------")

if not any_differ:
print("\nAll results match!\n")
print("\nAll results match!\n")
@prashantgupta24 (Collaborator) commented Jul 23, 2025:

nit (should revert changes to this file)

Suggested change
print("\nAll results match!\n")
print("\nAll results match!\n")

Comment from a Collaborator:

worth adding some debug logs to the optimizations?

# Fill the block table with padding blocks at the front (reusing block id 0)
# so this request's table covers the same number of blocks as the batch tkv.
blocks = self.req_ids2blocks[req_id].copy()
for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
    blocks.appendleft(0)
Comment from a Collaborator:

Perhaps this is an incredibly stupid question, but why is it OK to use block id 0? Does it make a difference whether it is free (i.e. it's in self.block_pool) or not?

Comment from a Collaborator:
Based on Josh's comment in an internal issue,

"The only requirement when padding this is that we choose a block ID that exists in the pool of allotted block ids at server start (this way it will map to a real location in the memory space). In this case, when performing paged attention compute, the placeholder block will be part of compute (we will still take a performance hit during decode - at least until heterogeneous tkv is available), but will not take up any extra space as part of the block allotment at the time of server start."

IIUC I think that's why block 0 makes sense?
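
To make that concrete, here is a rough sketch of the constraint Josh describes (allotted_block_ids and the helper name are hypothetical, not this repo's API): the padding id just has to be one of the block ids allotted at server start, so that it maps to a real location in memory; paged attention will still read it (the performance hit mentioned above), but no extra block is consumed on behalf of the padded request.

from collections import deque

def pad_block_table(blocks: deque, n_blocks: int, allotted_block_ids: set) -> deque:
    # Reuse an id allotted at server start (here simply 0) as a placeholder:
    # it maps to real memory, so attention compute over it is valid, but it
    # does not take an extra block out of this request's allotment.
    pad_id = 0
    assert pad_id in allotted_block_ids, "pad id must map to a real block"
    padded = blocks.copy()
    for _ in range(n_blocks - len(padded)):
        padded.appendleft(pad_id)
    return padded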

Comment from a Collaborator:

Although yeah I'm not sure what will happen if block 0 is actually being used by another request?

@prashantgupta24 (Collaborator) commented Jul 23, 2025:
I managed to create a scenario where the block table looks like this for 2 requests:

block table: [deque([0, 2]), deque([0, 4])]

From what I can see the output is still correct, although anyone looking at the block table could be confused.
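
A minimal reproduction of that block table with plain deques (block ids 2 and 4 stand in for the two requests' real blocks; this mirrors the scenario above rather than the scheduler code):

from collections import deque

block_table = []
for real_block in (2, 4):       # one real block per request
    blocks = deque([real_block])
    blocks.appendleft(0)        # block id 0 reused as padding by both requests
    block_table.append(blocks)

print(block_table)  # [deque([0, 2]), deque([0, 4])]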

@@ -884,8 +892,10 @@ def _prepare_prompt(
# applies left padding to align with tkv of current decode batch
# and right padding to align with the next block boundary
input_tokens, position_ids, mask =\
Comment from a Collaborator:
Still processing the changes, but wondering if the comment above needs rewording...
