[TPU] fix kv_cache_update kernel block size choosing logic #21007
Conversation
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@bythew3i @tengyifei could you take a look?
Code Review

This pull request fixes a TPU CI test by adjusting the KV cache update logic. A critical issue was identified where the application could crash with a ZeroDivisionError when running attention-free models on TPU, due to page_size_bytes being zero. A fix has been suggested to handle this edge case.
num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block)
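Here prev_power_of_2 rounds its argument down to the nearest power of two, keeping the slice count hardware-friendly. A minimal sketch of what such a helper computes (vLLM's actual implementation may differ):

def prev_power_of_2(n: int) -> int:
    # Largest power of two <= n; assumes n >= 1.
    return 1 << (n.bit_length() - 1)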
This function will raise a ZeroDivisionError if page_size_bytes is 0. This can happen for attention-free models, where head_size is 0, leading to page_size_bytes being 0.

To prevent a crash, we should handle this case by returning a default value, since this value is not used for attention-free models.
Suggested change:

-num_slices_per_block = vmem_limit // page_size_bytes
-assert num_slices_per_block > 0, "Number of slices should be positive"
-num_slices_per_block = prev_power_of_2(num_slices_per_block)
+if page_size_bytes == 0:
+    # For models without KV cache (e.g. attention-free), page size is 0.
+    # The return value is not used in this case, so we can return a default.
+    return 64
+num_slices_per_block = vmem_limit // page_size_bytes
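Assembled, the guarded logic might read as follows (a sketch only; the function name and parameter list here are assumptions for illustration, not vLLM's exact API):

def get_num_slices_per_block(page_size_bytes: int,
                             vmem_limit: int = 16 * 1024 * 1024) -> int:
    # For models without KV cache (e.g. attention-free), page size is 0.
    # The return value is unused in that case, so any default works.
    if page_size_bytes == 0:
        return 64
    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    return prev_power_of_2(num_slices_per_block)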
Signed-off-by: Chengji Yao <chengjiyao@google.com>

Force-pushed from 1c81513 to 4de8313.
Could you summarize what problem this PR is addressing?
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
return (block_size * num_combined_kv_heads * padded_head_size *
        kv_cache_dtype_bits // 8)
What is the 8 in kv_cache_dtype_bits // 8?
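For reference, the 8 converts bits to bytes (8 bits per byte), so the product of bit-widths becomes a byte count. A worked example with purely illustrative values (none of these numbers come from the PR):

# Illustrative values only: 16-token blocks, 8 combined KV heads,
# head size padded to 128, bfloat16 cache (16 bits per element).
block_size = 16
num_combined_kv_heads = 8
padded_head_size = 128
kv_cache_dtype_bits = 16
page_size_bytes = (block_size * num_combined_kv_heads * padded_head_size *
                   kv_cache_dtype_bits // 8)  # 32768 bytes = 32 KiB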
-vmem_limit = 32 * 1024 * 1024
+# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
+# calculate num_slices_per_block based on 16MB in case any register spills.
+vmem_limit = 16 * 1024 * 1024
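With the illustrative 32 KiB page size from the earlier example, the 16 MiB budget would give (an arithmetic sketch; the numbers are assumptions, not from the PR):

vmem_limit = 16 * 1024 * 1024                         # 16 MiB = 16777216 bytes
page_size_bytes = 32768                               # illustrative 32 KiB page
num_slices_per_block = vmem_limit // page_size_bytes  # 512
num_slices_per_block = prev_power_of_2(num_slices_per_block)  # 512, already a power of two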
I wonder why we reduce the vmem_limit from 32MB to 16MB?
Purpose
Fix TPU CI test.
Test Plan
pytest -s -v tests/v1/tpu/test_basic.py
Test Result
Passed.
(Optional) Documentation Update