
[CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching #262

Open · wants to merge 16 commits into base: main
66 changes: 34 additions & 32 deletions tests/e2e/test_spyre_cb_scheduler_steps.py
@@ -92,26 +92,26 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
},
{
# Prefill sequence 2
- # total blocks in use: 4 - 2 + 2 = 4
+ # total blocks in use: 4 - 2 + 1 = 3
"step": 67,
"tkv": 128, # Tkv doesn't increase because it is a prefill
"waiting": [],
"running": ["2", "1"],
"request_outputs": ["2"],
- # 5 - 2 (seq 0) + 3 (prefill (2 blocks) + decodes (1 block))
- "n_reserved_blocks": 6,
- "n_used_blocks": 4
+ # 5 - 2 (seq 0) + 2 (prefill (1 block) + decodes (1 block))
+ "n_reserved_blocks": 5,
+ "n_used_blocks": 3
},
{
# Decode sequences 1 and 2
- # total blocks in use: 4 + 2 = 6
+ # total blocks in use: 3 + 2 = 5
"step": 68,
"tkv": 129,
"waiting": [],
"running": ["2", "1"],
"request_outputs": ["2", "1"],
"n_reserved_blocks": 6,
"n_used_blocks": 6
"n_reserved_blocks": 5,
"n_used_blocks": 5
},
{
# Sequence 1 finishes at step 69
@@ -122,18 +122,18 @@
"running": ["2"],
"request_outputs": ["2", "1"],
"finished_requests": ["1"],
"n_reserved_blocks": 6,
"n_used_blocks": 6
"n_reserved_blocks": 5,
"n_used_blocks": 5
},
{
# Decode sequence 2
- # total blocks in use: 6 - 3 - 1 (remove padded block) = 2
+ # total blocks in use: 5 - 3 = 2
"step": 70,
"tkv": 67, # tkv is reset by 64 due to removing the padded block
"waiting": [],
"running": ["2"],
"request_outputs": ["2"],
- # 6 - 3 (seq 1 left) - 1 (removing the padded block)
+ # 5 - 3 (seq 1 left)
"n_reserved_blocks": 2,
"n_used_blocks": 2
},
@@ -256,15 +256,15 @@ def test_prompts_misaligned_with_tkv_boundaries(
},
{
# Prefill sequence 2
- # total blocks in use: 4 - 2 + 2 = 4
+ # total blocks in use: 4 - 2 + 1 = 3
"step": 59,
"tkv": 120, # Tkv doesn't increase because it is a prefill
"waiting": [],
"running": ["2", "1"],
"request_outputs": ["2"],
- # 5 - 2 (seq 0) + 2 (prefill (2 block) + 8 decodes in 2nd block)
- "n_reserved_blocks": 5,
- "n_used_blocks": 4
+ # 5 - 2 (seq 0) + 1 (prefill (1 block) + 8 decodes in 1st block)
+ "n_reserved_blocks": 4,
+ "n_used_blocks": 3
},
{
# Decode sequences 1 and 2
@@ -273,8 +273,8 @@
"waiting": [],
"running": ["2", "1"],
"request_outputs": ["2", "1"],
"n_reserved_blocks": 5,
"n_used_blocks": 4
"n_reserved_blocks": 4,
"n_used_blocks": 3
},
{
# Sequence 2 finishes at step 68
@@ -285,18 +285,18 @@
"running": ["1"],
"request_outputs": ["2", "1"],
"finished_requests": ["2"],
"n_reserved_blocks": 5,
"n_used_blocks": 4
"n_reserved_blocks": 4,
"n_used_blocks": 3
},
{
# Decode sequences 1
- # total blocks in use: 4 - 2 + 1 = 3
+ # total blocks in use: 3 - 1 + 1 = 3
"step": 68,
"tkv": 129,
"waiting": [],
"running": ["1"],
"request_outputs": ["1"],
"n_reserved_blocks": 3, # 5 - 2 (seq 2)
"n_reserved_blocks": 3, # 4 - 1 (seq 2)
"n_used_blocks": 3
},
{
@@ -667,26 +667,26 @@ def test_requested_tokens_not_fitting_remaining_space(
},
{
# Prefill sequence 1
- # total blocks in use: 2 + 2
+ # total blocks in use: 2 + 1
"step": 2,
"tkv": 128,
"waiting": ["2"],
"running": ["1", "0"],
"request_outputs": ["1"],
- # prefill (2 blocks) + 56 decodes (1 block)
- "n_reserved_blocks": 7,
- "n_used_blocks": 4
+ # prefill (1 block) + 56 decodes (1 block)
+ "n_reserved_blocks": 6,
+ "n_used_blocks": 3
},
{
# Decode sequences 0 and 1
- # total blocks in use: 4 + 2 (decodes)
+ # total blocks in use: 3 + 2 (decodes)
"step": 3,
"tkv": 129,
"waiting": ["2"],
"running": ["1", "0"],
"request_outputs": ["1", "0"],
"n_reserved_blocks": 7,
"n_used_blocks": 6
"n_reserved_blocks": 6,
"n_used_blocks": 5
},
{
# Sequence 1 finishes at step 58
@@ -697,19 +697,19 @@
"running": ["0"],
"request_outputs": ["1", "0"],
"finished_requests": ["1"],
"n_reserved_blocks": 7,
"n_used_blocks": 6
"n_reserved_blocks": 6,
"n_used_blocks": 5
},
{
# Decode sequence 0
# Cannot prefill sequence 2: 185 + 80 = 265 > 256
- # total blocks in use: 6 - 3 = 3
+ # total blocks in use: 5 - 2 = 3
"step": 59,
"tkv": 185,
"waiting": ["2"],
"running": ["0"],
"request_outputs": ["0"],
"n_reserved_blocks": 4, # 7 - 3 (seq 1)
"n_reserved_blocks": 4, # 6 - 2 (seq 1)
"n_used_blocks": 3
},
{
@@ -830,6 +830,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
# total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
available_blocks = 8
max_num_seqs = 4

checked_steps = [
{
"step": 0,
@@ -962,6 +963,7 @@ def test_requests_use_more_than_available_blocks(
# total number of blocks needed if scheduled together : 4 * (1 + 1) = 8
available_blocks = 4
max_num_seqs = 4

checked_steps = [
{
"step": 0,
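All of the expectation changes above follow one accounting rule: at prefill time, a sequence reserves enough KV-cache blocks to hold it at its maximum length, minus any leading blocks that would contain only left padding. A minimal sketch of that arithmetic, assuming the 64-token block size these tests run with (the prompt and token counts below are illustrative, not the tests' exact values):

```python
import math

BLOCK_SIZE = 64  # block size assumed by these scheduler-step tests

def reserved_blocks(tkv: int, prompt_len: int, max_new_tokens: int) -> int:
    # Blocks needed to hold the sequence at its maximum length...
    n_total = math.ceil((tkv + max_new_tokens - 1) / BLOCK_SIZE)
    # ...minus the leading blocks that would hold only left padding.
    n_fully_padded = (tkv - prompt_len) // BLOCK_SIZE
    return n_total - n_fully_padded

# A short prompt joining a batch whose tkv is 128 now reserves 2 blocks
# instead of 3: one of its two prefill blocks would have been pure padding.
print(reserved_blocks(tkv=128, prompt_len=60, max_new_tokens=10))  # -> 2
```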
68 changes: 45 additions & 23 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -809,6 +809,14 @@ def _prepare_prompt(
if new_batch:
self.tkv = block_padding

+ # optimization: cut out fully padded blocks on the left
+ n_pad_blocks = (self.tkv - max_prompt_len) // self.block_size
+ block_padding -= n_pad_blocks * self.block_size
+ left_padding = self.tkv - n_pad_blocks * self.block_size
+ if n_pad_blocks > 0:
+     logger.debug("Prefill reduced by %d blocks due to optimization.",
+                  n_pad_blocks)

# Internal state is managed here.
slot_mapping = []
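Illustrative numbers for the hunk above, assuming block_size = 64 and that block_padding arrives rounded up to the next block boundary, as the surrounding code computes it (the tkv and prompt length are made up):

```python
block_size = 64
tkv, max_prompt_len = 131, 50

block_padding = -(-tkv // block_size) * block_size   # 192: next block boundary
n_pad_blocks = (tkv - max_prompt_len) // block_size  # 1 fully padded block
block_padding -= n_pad_blocks * block_size           # 128: one block cheaper
left_padding = tkv - n_pad_blocks * block_size       # 67 tokens of left padding

# Prefill now runs over 128 positions instead of 192, and the new sequence's
# effective tkv (67) stays block-aligned with the decode batch's tkv (131):
assert left_padding % block_size == tkv % block_size  # 3 == 3
```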

@@ -819,13 +827,17 @@
new_tokens = (request_data.sampling_params.max_tokens
if request_data.sampling_params is not None else 0)
n = self.tkv + new_tokens - 1
- n_reserved_blocks = math.ceil(n / self.block_size)
+ # subtract the padding blocks from the reserved blocks
+ n_fully_padded_blocks = math.floor(
+     (self.tkv - len(request_data.prompt_token_ids)) /
+     self.block_size)
+ n_reserved_blocks = math.ceil(
+     n / self.block_size) - n_fully_padded_blocks
self.req_ids2reserved_blocks[
request_data.req_id] = n_reserved_blocks

# retrieve initial (unpadded) tokens
prompt_tokens = request_data.prompt_token_ids
- left_padding = self.tkv - len(prompt_tokens)
input_token_list.append(
torch.tensor(prompt_tokens,
dtype=torch.long,
@@ -859,7 +871,7 @@
sampling_params=sampling_params,
generator=generator,
output_token_ids=[],
- left_padding=left_padding)
+ left_padding=self.tkv - max_prompt_len)
prashantgupta24 (Collaborator), Jul 24, 2025:
I still don't understand this: previously we left-padded each request based on self.tkv - len(prompt_tokens), so why do we now pad based only on max_prompt_len? A comment would be appreciated 🙏

prashantgupta24 (Collaborator), Jul 24, 2025:
Also, if possible, a short visualization of how the tkv and blocks behave under this optimization would be much appreciated. The original demo in which you showed how CB works with the table visualization helped me a lot.

Author (Collaborator):
Ah, I understand your confusion. First of all, since we have only one sequence per prefill, max_prompt_len == len(prompt_tokens), so what you are saying is also valid. The reason I changed it is that I introduced the variable left_padding outside the loop (line 815), so I had to remove it inside the loop and used self.tkv - max_prompt_len, though I could equally have taken self.tkv - len(prompt_tokens).

This reminds me of something I have wanted to do for a long time: it looks like we will stick to prefill batch size 1 for the foreseeable future, so I will clean up this loop over a single element to make things much more readable.

Author (Collaborator):
Did the refactoring in #335; it would be nice to get it in asap, then rebase this PR to make it more readable. @prashantgupta24 @maxdebayser
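A tiny sketch of the equivalence settled on above, with a hypothetical prompt length and a single sequence per prefill, as the code currently assumes:

```python
import torch

# One sequence per prefill step, so the batch maximum is that sequence's length.
prompt_token_ids = list(range(37))  # hypothetical 37-token prompt
input_token_list = [torch.tensor(prompt_token_ids, dtype=torch.long)]
max_prompt_len = max(t.shape[0] for t in input_token_list)

tkv = 128  # tkv of the current decode batch
# The two expressions discussed above coincide for a single-sequence prefill:
assert tkv - max_prompt_len == tkv - len(prompt_token_ids)  # same left padding
```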

self.requests[req_id] = req_state
self.input_batch.add_request(req_state)
self.prefill_batch.add_request(req_state)
@@ -881,11 +893,15 @@
block_table = None

# get position ids and attention mask
- # applies left padding to align with tkv of current decode batch
+ # applies left padding to ensure that the tkv of the new sequence
+ # tkv_prefill aligns with tkv of current decode batch tkv_decode:
+ # tkv_prefill % block_size = tkv_decode % block_size
# and right padding to align with the next block boundary
input_tokens, position_ids, mask =\
-     self.pad_input_ids(input_token_list, min_pad_length=block_padding)
- mask = mask.unsqueeze(1)
+     self.pad_input_ids(input_token_list,
+                        min_pad_left=left_padding,
+                        min_pad_right=block_padding)
+ mask = mask.unsqueeze(1).contiguous()

# not needed for prefill
current_tkv_mask = None
@@ -931,16 +947,28 @@ def _prepare_decode(
}
req_ids = self.input_batch.sorted_requests_ids

+ n_blocks = 0
+ for req_id in req_ids:
+     # adding new blocks if needed
+     if self.tkv % self.block_size == 0:
+         self.req_ids2blocks[req_id].append(self.block_pool.popleft())
+     n_blocks = max(n_blocks, len(self.req_ids2blocks[req_id]))

for req_id in req_ids:
# TODO: Will this always just be one token ID if there's no spec
# or jump decoding?

req_state: CachedRequestState = self.requests[req_id]
- # adding new blocks if needed
- if self.tkv // self.block_size + 1 > len(
-         self.req_ids2blocks[req_id]):
-     self.req_ids2blocks[req_id].append(self.block_pool.popleft())
- block_table.append(self.req_ids2blocks[req_id])

+ # filling block table with padding blocks to make it rectangular
+ # Note: the padding block id 0 here is chosen arbitrarily; it can
+ # be any allocated block id on the Spyre card (it has to be in range
+ # [0, self.n_blocks - 1]). It can even be a block id that holds
+ # actual KV cache for another (or the same) sequence.
+ blocks = self.req_ids2blocks[req_id].copy()
+ for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
+     blocks.appendleft(0)
Collaborator:
Perhaps this is an incredibly stupid question, but why is it OK to use block id 0? Does it make a difference whether it is free (i.e. still in self.block_pool) or not?

Collaborator:
Based on Josh's comment in an internal issue:

"The only requirement when padding this is that we choose a block ID that exists in the pool of allotted block ids at server start (this way it will map to a real location in the memory space). In this case, when performing paged attention compute, the placeholder block will be part of compute (we will still take a performance hit during decode - at least until heterogeneous tkv is available), but will not take up any extra space as part of the block allotment at the time of server start."

IIUC, I think that's why block 0 makes sense?

Collaborator:
Although, yeah, I'm not sure what will happen if block 0 is actually being used by another request?

prashantgupta24 (Collaborator), Jul 23, 2025:
I managed to create a scenario where the block table looks like this for 2 requests:

block table: [deque([0, 2]), deque([0, 4])]

From what I can see, the output is still correct, although anyone looking at the block table could be confused.

Author (Collaborator):
Yes, it is exactly as @prashantgupta24 stated: the padding id has to be real allocated memory, but it can be either used or free. It felt natural to use 0 for padding; otherwise I would have to check that the padding id is in self.block_pool (e.g. <= self.n_blocks - 1). Reserving an extra physical memory block just for padding feels like overkill/waste to me. But I will add some comments to make this clearer :)

Collaborator:
So it makes no difference whether the block is free or not; it just has to have been allocated, and that is always true for block 0, right? Thanks, that's all the questions I had.

Author (Collaborator):
Exactly!
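A self-contained sketch of the rectangular padding discussed in this thread (the helper name and block ids are hypothetical; in the real code the rows come from self.req_ids2blocks):

```python
from collections import deque

PAD_BLOCK_ID = 0  # any block id allocated at server start would work

def pad_block_table(rows: list[deque[int]]) -> list[deque[int]]:
    """Left-pad shorter rows with a placeholder block id so every row has
    the same number of blocks. The placeholder maps to real memory, so
    paged attention can read it; per the thread, the left-padded positions
    are excluded from the result, so outputs stay correct at the cost of
    some extra decode compute."""
    n_blocks = max(len(row) for row in rows)
    padded = []
    for row in rows:
        blocks = row.copy()
        while len(blocks) < n_blocks:
            blocks.appendleft(PAD_BLOCK_ID)
        padded.append(blocks)
    return padded

# One request has a block fewer than the other; its row is padded on the left.
print(pad_block_table([deque([2]), deque([3, 4])]))
# [deque([0, 2]), deque([3, 4])]
```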


+ block_table.append(blocks)

# slot_mapping for all blocks of sequence
start_slot = block_table[-1][-1] * self.block_size
@@ -1027,29 +1055,23 @@ def reduce_left_padding(self) -> None:
for req in requests:
req.left_padding -= offset

- # free blocks
- for _ in range(n_padded_blocks):
-     freed_block_id = self.req_ids2blocks[req.req_id].popleft()
-     logger.debug("Freeing block with id: %s", freed_block_id)
-     self.block_pool.append(freed_block_id)
-     self.req_ids2reserved_blocks[req.req_id] -= 1

# update tkv
self.tkv -= offset

def pad_input_ids(
self,
input_ids_list: list[torch.Tensor],
- min_pad_length: int = 0,
+ min_pad_left: int = 0,
+ min_pad_right: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

# left padding to align with tkv of current decode batch
input_tokens_left, position_ids_left, mask_left =\
-     super().pad_input_ids(input_ids_list, min_pad_length=self.tkv)
+     super().pad_input_ids(input_ids_list, min_pad_length=min_pad_left)

# right padding to align with the next block boundary
left_pad_len = input_tokens_left.shape[1]
- n_pads_right = min_pad_length - left_pad_len
+ n_pads_right = min_pad_right - left_pad_len

# set number of right pads for the next model forward pass:
# need to be excluded before sampling tokens
@@ -1059,7 +1081,7 @@ def pad_input_ids(
# apply right padding to input_tokens, position_ids and mask
logger.info(
"Right padding request of length %d tokens to %d tokens.",
- left_pad_len, min_pad_length)
+ left_pad_len, min_pad_right)

input_tokens_right = torch.tensor(
[[self.pad_token_id for i in range(n_pads_right)]],
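Putting the new signature together, pad_input_ids now pads in two stages. A minimal sketch of just the token padding, with simplified names and an assumed pad token of 0 (the real method also builds position ids and the attention mask):

```python
import torch

PAD = 0  # stand-in for self.pad_token_id

def pad_two_stage(prompt: list[int], min_pad_left: int,
                  min_pad_right: int) -> torch.Tensor:
    # 1) left padding aligns the new sequence with the decode batch's tkv
    tokens = [PAD] * (min_pad_left - len(prompt)) + prompt
    # 2) right padding fills up to the next block boundary
    tokens = tokens + [PAD] * (min_pad_right - len(tokens))
    return torch.tensor([tokens], dtype=torch.long)

# Continuing the numbers from the earlier sketch: a 50-token prompt is
# left-padded to 67 tokens and right-padded to 128 (two 64-token blocks).
out = pad_two_stage(list(range(50)), min_pad_left=67, min_pad_right=128)
print(out.shape)  # torch.Size([1, 128])
```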