[CB] Reduce wastage in prefill compute and pad blocks in homogeneous continuous batching #262
@@ -809,6 +809,14 @@ def _prepare_prompt(
if new_batch:
    self.tkv = block_padding

# optimization: cut out fully padded blocks on the left
n_pad_blocks = (self.tkv - max_prompt_len) // self.block_size
block_padding -= n_pad_blocks * self.block_size
left_padding = self.tkv - n_pad_blocks * self.block_size
if n_pad_blocks > 0:
    logger.debug("Prefill reduced by %d blocks due to optimization.",
                 n_pad_blocks)

# Internal state is managed here.
slot_mapping = []
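A minimal sketch of the arithmetic in this hunk, with made-up values (block_size, tkv, max_prompt_len and the initial block_padding are illustrative, not taken from a real run):

```python
block_size = 64
tkv = 256              # tkv of the current decode batch
max_prompt_len = 70    # longest prompt in this prefill step
block_padding = 256    # prefill length before the optimization (4 blocks)

# whole blocks on the left that would contain only padding
n_pad_blocks = (tkv - max_prompt_len) // block_size    # 186 // 64 = 2
block_padding -= n_pad_blocks * block_size             # 256 - 128 = 128
left_padding = tkv - n_pad_blocks * block_size         # 256 - 128 = 128

# the 70-token prompt is prefilled over 2 blocks (128 positions) instead of
# 4 blocks (256 positions); the 2 dropped blocks would have been pure padding
assert (n_pad_blocks, block_padding, left_padding) == (2, 128, 128)
```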
@@ -819,13 +827,17 @@ def _prepare_prompt(
new_tokens = (request_data.sampling_params.max_tokens
              if request_data.sampling_params is not None else 0)
n = self.tkv + new_tokens - 1
n_reserved_blocks = math.ceil(n / self.block_size)
# subtract the padding blocks from the reserved blocks
n_fully_padded_blocks = math.floor(
    (self.tkv - len(request_data.prompt_token_ids)) /
    self.block_size)
n_reserved_blocks = math.ceil(
    n / self.block_size) - n_fully_padded_blocks
self.req_ids2reserved_blocks[
    request_data.req_id] = n_reserved_blocks

# retrieve initial (unpadded) tokens
prompt_tokens = request_data.prompt_token_ids
left_padding = self.tkv - len(prompt_tokens)
input_token_list.append(
    torch.tensor(prompt_tokens,
                 dtype=torch.long,
@@ -859,7 +871,7 @@ def _prepare_prompt(
sampling_params=sampling_params,
generator=generator,
output_token_ids=[],
left_padding=left_padding)
left_padding=self.tkv - max_prompt_len)
Review thread:

- I still don't understand this - previously we were left padding based on
- Also, if possible, a short visualization of how the tkv and blocks actually behave based on this optimization would be much appreciated as well! The original demo in which you showed how CB works with the table visualization helped me a lot.
- Ahh, I understand your confusion. First of all, as we have only one sequence per prefill,
  But this reminds me of something I wanted to do for a long time. It looks like we will stick to prefill batch size 1 for the foreseeable future, therefore I will clean up this loop over one element to make things much more readable.
- Did the refactoring in #335, would be nice to get it in asap; then rebase this PR and make it more readable. @prashantgupta24 @maxdebayser
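As a rough sketch of the kind of visualization requested above (not part of the PR; block_size, tkv, the prompt and max_tokens below are made-up illustration values):

```python
import math

block_size = 4
tkv = 12                     # tkv of the current decode batch (3 blocks)
prompt = [1, 2, 3, 4, 5]     # new 5-token prompt joining the batch
max_tokens = 6               # max_tokens of the new request

# one fully padded block on the left can be cut away
n_pad_blocks = (tkv - len(prompt)) // block_size   # (12 - 5) // 4 = 1
left_padding = tkv - n_pad_blocks * block_size     # 12 - 4 = 8

# prefill view of the new sequence: left padded to 8 positions (2 blocks),
# which stays block-aligned with the decode batch because only whole
# blocks were removed (8 % 4 == 12 % 4)
padded = ["PAD"] * (left_padding - len(prompt)) + [str(t) for t in prompt]
for b in range(0, left_padding, block_size):
    print(f"block {b // block_size}: {padded[b:b + block_size]}")
# block 0: ['PAD', 'PAD', 'PAD', '1']
# block 1: ['2', '3', '4', '5']

# reserved blocks for the request: blocks needed up to tkv + max_tokens - 1,
# minus the fully padded block that is never allocated
n_reserved = math.ceil((tkv + max_tokens - 1) / block_size) - n_pad_blocks
print("reserved blocks:", n_reserved)   # ceil(17 / 4) - 1 = 4
```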
self.requests[req_id] = req_state
self.input_batch.add_request(req_state)
self.prefill_batch.add_request(req_state)
@@ -881,11 +893,15 @@ def _prepare_prompt(
block_table = None

# get position ids and attention mask
# applies left padding to align with tkv of current decode batch
# applies left padding to ensure that the tkv of the new sequence
# tkv_prefill aligns with tkv of current decode batch tkv_decode:
# tkv_prefill % block_size = tkv_decode % block_size
# and right padding to align with the next block boundary
input_tokens, position_ids, mask =\
    self.pad_input_ids(input_token_list, min_pad_length=block_padding)
mask = mask.unsqueeze(1)
    self.pad_input_ids(input_token_list,
                       min_pad_left=left_padding,
                       min_pad_right=block_padding)
mask = mask.unsqueeze(1).contiguous()

# not needed for prefill
current_tkv_mask = None
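The modular-alignment comment above can be checked mechanically; a small sketch (the function name and test ranges are illustrative, not from the codebase):

```python
def aligned_after_trim(tkv_decode: int, prompt_len: int, block_size: int) -> bool:
    """After dropping whole blocks of left padding, the prefill tkv must
    still match the decode tkv modulo block_size."""
    n_pad_blocks = (tkv_decode - prompt_len) // block_size
    tkv_prefill = tkv_decode - n_pad_blocks * block_size
    return tkv_prefill % block_size == tkv_decode % block_size

# holds for any prompt that fits into the decode batch's tkv, because only
# multiples of block_size are ever removed
assert all(
    aligned_after_trim(tkv, p, 64)
    for tkv in range(64, 513, 64)
    for p in range(1, tkv + 1)
)
```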
@@ -931,16 +947,28 @@ def _prepare_decode(
}
req_ids = self.input_batch.sorted_requests_ids

n_blocks = 0
for req_id in req_ids:
    # adding new blocks if needed
    if self.tkv % self.block_size == 0:
        self.req_ids2blocks[req_id].append(self.block_pool.popleft())
    n_blocks = max(n_blocks, len(self.req_ids2blocks[req_id]))

for req_id in req_ids:
    # TODO: Will this always just be one token ID if there's no spec
    # or jump decoding?

    req_state: CachedRequestState = self.requests[req_id]
    # adding new blocks if needed
    if self.tkv // self.block_size + 1 > len(
            self.req_ids2blocks[req_id]):
        self.req_ids2blocks[req_id].append(self.block_pool.popleft())
    block_table.append(self.req_ids2blocks[req_id])

    # filling block table with padding blocks to make it rectangular
    # Note: the padding block id 0 here is chosen arbitrarily, it can
    # be any allocated block id on the Spyre card (has to be in range
    # [0, self.n_blocks - 1]). Further, it can also be a block id that
    # holds actual KV cache for another (or the same) sequence.
    blocks = self.req_ids2blocks[req_id].copy()
    for i in range(n_blocks - len(self.req_ids2blocks[req_id])):
        blocks.appendleft(0)
Review thread:

- Perhaps this is an incredibly stupid question, but why is it ok to use block id 0? Does it make a difference if it is free (i.e. it's in self.block_pool) or not?
- Based on Josh's comment in an internal issue: "The only requirement when padding this is that we choose a block ID that exists in the pool of allotted block ids at server start (this way it will map to a real location in the memory space). In this case, when performing paged attention compute, the placeholder block will be part of compute (we will still take a performance hit during decode - at least until heterogeneous tkv is available), but will not take up any extra space as part of the block allotment at the time of server start." IIUC, I think that's why block 0 makes sense?
- Although, yeah, I'm not sure what will happen if block 0 is actually being used by another request?
- I managed to create a scenario where the block table looks like this for 2 requests:
  From what I can see the output still is correct, although anyone looking at the block table could be confused.
- Yes, it is exactly how @prashantgupta24 stated it. The padding id has to be real allocated memory, but it can be used or free. It felt natural to use 0 for padding, otherwise I would have to check if the padding id is in the
- So it makes no difference if the block is free or not, it just has to have been allocated, and that is always true for block 0, right? Thanks, that's all the questions I had.
- Exactly!
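A standalone sketch of the block-table padding being discussed (block ids and table contents are invented for illustration; in the runner the entries come from req_ids2blocks and block_pool):

```python
from collections import deque

# two sequences with different numbers of KV cache blocks
req_ids2blocks = {
    "req-a": deque([7, 8, 9]),   # older sequence: 3 blocks
    "req-b": deque([12]),        # newer sequence: 1 block
}

n_blocks = max(len(b) for b in req_ids2blocks.values())
block_table = []
for req_id, req_blocks in req_ids2blocks.items():
    blocks = req_blocks.copy()
    # prepend placeholder block id 0 until the table is rectangular;
    # block 0 only has to be a block allocated at server start, it may
    # be free or hold real KV cache for some sequence
    for _ in range(n_blocks - len(req_blocks)):
        blocks.appendleft(0)
    block_table.append(list(blocks))

print(block_table)   # [[7, 8, 9], [0, 0, 12]]
```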
    block_table.append(blocks)

    # slot_mapping for all blocks of sequence
    start_slot = block_table[-1][-1] * self.block_size
@@ -1027,29 +1055,23 @@ def reduce_left_padding(self) -> None:
for req in requests:
    req.left_padding -= offset

    # free blocks
    for _ in range(n_padded_blocks):
        freed_block_id = self.req_ids2blocks[req.req_id].popleft()
        logger.debug("Freeing block with id: %s", freed_block_id)
        self.block_pool.append(freed_block_id)
        self.req_ids2reserved_blocks[req.req_id] -= 1

# update tkv
self.tkv -= offset
def pad_input_ids(
    self,
    input_ids_list: list[torch.Tensor],
    min_pad_length: int = 0,
    min_pad_left: int = 0,
    min_pad_right: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    # left padding to align with tkv of current decode batch
    input_tokens_left, position_ids_left, mask_left =\
        super().pad_input_ids(input_ids_list, min_pad_length=self.tkv)
        super().pad_input_ids(input_ids_list, min_pad_length=min_pad_left)

    # right padding to align with the next block boundary
    left_pad_len = input_tokens_left.shape[1]
    n_pads_right = min_pad_length - left_pad_len
    n_pads_right = min_pad_right - left_pad_len

    # set number of right pads for the next model forward pass:
    # need to be excluded before sampling tokens

@@ -1059,7 +1081,7 @@ def pad_input_ids(
    # apply right padding to input_tokens, position_ids and mask
    logger.info(
        "Right padding request of length %d tokens to %d tokens.",
        left_pad_len, min_pad_length)
        left_pad_len, min_pad_right)

    input_tokens_right = torch.tensor(
        [[self.pad_token_id for i in range(n_pads_right)]],
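For reference, a self-contained sketch of the two-stage padding that the new min_pad_left / min_pad_right parameters express: left-pad up to the decode-batch alignment, then right-pad up to the next block boundary. The helper below is a simplification under assumptions (1-D tensors, pad id 0), not the runner's actual pad_input_ids:

```python
import torch


def pad_two_stage(prompt: torch.Tensor, min_pad_left: int, min_pad_right: int,
                  pad_token_id: int = 0) -> tuple[torch.Tensor, int]:
    # left padding to align with the tkv of the current decode batch
    n_left = max(min_pad_left - prompt.shape[0], 0)
    left_padded = torch.cat(
        [torch.full((n_left,), pad_token_id, dtype=torch.long), prompt])
    # right padding to align with the next block boundary; n_pads_right
    # has to be excluded again before sampling tokens
    n_pads_right = min_pad_right - left_padded.shape[0]
    right_padded = torch.cat(
        [left_padded,
         torch.full((n_pads_right,), pad_token_id, dtype=torch.long)])
    return right_padded.unsqueeze(0), n_pads_right


tokens, n_right = pad_two_stage(torch.arange(1, 6),
                                min_pad_left=8, min_pad_right=16)
print(tokens.shape, n_right)   # torch.Size([1, 16]) 8
```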