
[TPU] fix kv_cache_update kernel block size choosing logic #21007


Merged

merged 3 commits into vllm-project:main on Jul 16, 2025

Conversation

@yaochengji yaochengji (Collaborator) commented on Jul 15, 2025

Purpose

Fix the failing TPU CI test by correcting the kv_cache_update kernel's block-size choosing logic.

Test Plan

pytest -s -v tests/v1/tpu/test_basic.py

Test Result

Passed.

(Optional) Documentation Update

Signed-off-by: Chengji Yao <chengjiyao@google.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@yaochengji yaochengji (Collaborator, Author)

@bythew3i @tengyifei could you take a look?

@mergify mergify bot added the v1 and tpu (Related to Google TPUs) labels on Jul 15, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes a TPU CI test by adjusting the KV cache update logic. A critical issue was identified where the application could crash with a ZeroDivisionError when running attention-free models on TPU, due to page_size_bytes being zero. A fix has been suggested to handle this edge case.

Comment on lines 1869 to 1871
num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block)
critical

This function will raise a ZeroDivisionError if page_size_bytes is 0. This can happen for attention-free models where head_size is 0, leading to page_size_bytes being 0.

To prevent a crash, we should handle this case by returning a default value, since this value is not used for attention-free models.

Suggested change

- num_slices_per_block = vmem_limit // page_size_bytes
- assert num_slices_per_block > 0, "Number of slices should be positive"
- num_slices_per_block = prev_power_of_2(num_slices_per_block)
+ if page_size_bytes == 0:
+     # For models without KV cache (e.g. attention-free), page size is 0.
+     # The return value is not used in this case, so we can return a default.
+     return 64
+ num_slices_per_block = vmem_limit // page_size_bytes
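
For context, here is a minimal, self-contained sketch of what the helper could look like with the suggested zero guard applied. The names prev_power_of_2, page_size_bytes, and the 16 MB working budget come from snippets quoted in this thread; the helper's name and signature are assumptions for illustration, not the exact vLLM code.

```python
def prev_power_of_2(n: int) -> int:
    """Largest power of two that is <= n (assumes n > 0)."""
    return 1 << (n.bit_length() - 1)


def num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    """Sketch of the block-size choice with the suggested zero guard."""
    # Use a 16 MiB budget (half the 32 MiB pallas default) to leave
    # headroom for register spills, per the comment quoted later in
    # this thread.
    vmem_limit = 16 * 1024 * 1024
    if page_size_bytes == 0:
        # Attention-free models have no KV cache, so the page size is 0
        # and the result is unused; return a default instead of dividing
        # by zero.
        return 64
    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    return prev_power_of_2(num_slices_per_block)
```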

Signed-off-by: Chengji Yao <chengjiyao@google.com>
@yaochengji yaochengji added the ready (ONLY add when PR is ready to merge / full CI is needed) label on Jul 15, 2025
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@yaochengji yaochengji force-pushed the chengji/fix-ci-test branch from 1c81513 to 4de8313 on July 15, 2025 at 22:54
@mgoin mgoin enabled auto-merge (squash) July 16, 2025 00:46
@vanbasten23 vanbasten23 (Collaborator)

Could you summarize what problem this PR is addressing?


kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
return (block_size * num_combined_kv_heads * padded_head_size *
kv_cache_dtype_bits // 8)
what is 8 in the kv_cache_dtype_bits // 8?
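
For reference, the // 8 converts bits to bytes (the dtype width is given in bits, while the page size is in bytes). A small sketch of the arithmetic; the parameter values below are illustrative assumptions, not values taken from the PR:

```python
# Illustrative page-size calculation: the dtype width is in bits, so
# divide by 8 (bits per byte) to get a size in bytes.
block_size = 16                 # tokens per block (assumed)
num_combined_kv_heads = 16      # K and V heads combined (assumed)
padded_head_size = 128          # head dim padded for the TPU (assumed)
kv_cache_dtype_bits = 16        # e.g. bfloat16 is 16 bits wide

page_size_bytes = (block_size * num_combined_kv_heads * padded_head_size *
                   kv_cache_dtype_bits // 8)
print(page_size_bytes)          # 65536 bytes, i.e. 64 KiB per page
```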

- vmem_limit = 32 * 1024 * 1024
+ # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
+ # calculate num_slices_per_block based on 16MB in case any register spills.
+ vmem_limit = 16 * 1024 * 1024
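
For a rough sense of scale, a back-of-the-envelope example of how the 16 MB budget turns into a slice count, reusing the division-plus-previous-power-of-two logic quoted earlier in the thread; the page size here is an illustrative assumption:

```python
# Illustrative only: how many page-sized slices fit in a 16 MiB budget,
# rounded down to a power of two.
vmem_limit = 16 * 1024 * 1024             # 16 MiB working budget
page_size_bytes = 96 * 1024               # hypothetical 96 KiB page

num_slices_per_block = vmem_limit // page_size_bytes          # 170
num_slices_per_block = 1 << (num_slices_per_block.bit_length() - 1)
print(num_slices_per_block)               # 128, the previous power of two
```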

I wonder why we reduce the vmem_limit from 32mb to 16mb?

@mgoin mgoin merged commit 85431bd into vllm-project:main Jul 16, 2025
68 checks passed
Labels
ready (ONLY add when PR is ready to merge / full CI is needed), tpu (Related to Google TPUs), v1
4 participants