[TPU] fix kv_cache_update kernel block size choosing logic #21007

Merged: 3 commits, Jul 16, 2025
Changes from 1 commit
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/pallas.py
@@ -329,4 +329,6 @@
 def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         kv_cache_dtype: torch.dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
-    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
+    padded_head_size = cdiv(head_size,
+                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+    return block_size * num_kv_heads * padded_head_size * kv_cache_dtype.itemsize

Check failure on line 334 in vllm/v1/attention/backends/pallas.py
GitHub Actions / pre-commit: Ruff (E501)
vllm/v1/attention/backends/pallas.py:334:81: E501 Line too long (81 > 80)
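For context, the fix rounds head_size up to the TPU alignment boundary before sizing a page. A minimal, runnable sketch of that math, assuming TPU_HEAD_SIZE_ALIGNMENT = 128 for illustration (not necessarily vLLM's actual constant):

```python
TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed alignment value, for illustration only

def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the cdiv helper used in the diff."""
    return -(a // -b)

def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        itemsize: int) -> int:
    """Sketch of the patched page-size computation (itemsize in bytes)."""
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    return block_size * num_kv_heads * padded_head_size * itemsize

# head_size=80 pads up to 128: with block_size=16, num_kv_heads=8, and a
# 2-byte dtype, a page is 16 * 8 * 128 * 2 = 32768 bytes, not 20480.
assert get_page_size_bytes(16, 8, 80, 2) == 32768
```

Sizing pages by the unpadded head_size undercounted the memory the kernel actually touches, which is what made the old block-size choice too aggressive.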
5 changes: 3 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
     out of scalar registers. Thus this function will limit the number of
     slices to 64.
     """
-    # Conservative VMEM usage limit: 32 MiB
-    vmem_limit = 32 * 1024 * 1024
+    # The default vmem_limit_bytes of a Pallas kernel is 32 MB. Here we
+    # calculate num_slices_per_block based on 16 MB in case of register spills.
+    vmem_limit = 16 * 1024 * 1024
Collaborator:
I wonder why we reduced the vmem_limit from 32 MB to 16 MB?

Collaborator (Author):
As noted in the comment, there might be register spills.

    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    num_slices_per_block = prev_power_of_2(num_slices_per_block)
Comment on lines 1869 to 1871
Contributor:
critical

This function will raise a ZeroDivisionError if page_size_bytes is 0. This can happen for attention-free models where head_size is 0, leading to page_size_bytes being 0.

To prevent a crash, we should handle this case by returning a default value, since this value is not used for attention-free models.

Suggested change:
-    num_slices_per_block = vmem_limit // page_size_bytes
-    assert num_slices_per_block > 0, "Number of slices should be positive"
-    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if page_size_bytes == 0:
+        # For models without KV cache (e.g. attention-free), page size is 0.
+        # The return value is not used in this case, so we can return a default.
+        return 64
+    num_slices_per_block = vmem_limit // page_size_bytes
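Putting the PR's 16 MB budget and the reviewer's zero-page guard together, here is a self-contained sketch of the slice-count selection; prev_power_of_2 and the closing 64-slice cap are written out as assumed helpers for illustration, and the real vLLM code may differ:

```python
def prev_power_of_2(n: int) -> int:
    """Assumed helper: largest power of two <= n (requires n >= 1)."""
    return 1 << (n.bit_length() - 1)

def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
    """Sketch of the selection logic with the suggested zero-page guard."""
    # Half of the default 32 MB Pallas vmem_limit_bytes, leaving headroom
    # for register spills.
    vmem_limit = 16 * 1024 * 1024
    if page_size_bytes == 0:
        # Attention-free models have no KV cache pages; the result is unused.
        return 64
    num_slices_per_block = vmem_limit // page_size_bytes
    assert num_slices_per_block > 0, "Number of slices should be positive"
    num_slices_per_block = prev_power_of_2(num_slices_per_block)
    # Cap at 64 to avoid running out of scalar registers, per the docstring.
    return min(num_slices_per_block, 64)

# With the 32768-byte page from the earlier sketch: 16 MiB // 32 KiB = 512,
# prev_power_of_2(512) = 512, and the cap brings the result down to 64.
assert _get_num_slices_per_kv_cache_update_block(32768) == 64
```

Note how the padded page size from the pallas.py change feeds directly into this division: a larger (correct) page size yields fewer slices per block, which is the point of the fix.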
