-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
[TPU] fix kv_cache_update kernel block size choosing logic #21007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: | |||||||||||||||||||
out of scalar registers. Thus this function will limit the number of | ||||||||||||||||||||
slices to 64. | ||||||||||||||||||||
""" | ||||||||||||||||||||
# Conservative VMEM usage limit: 32 MiB | ||||||||||||||||||||
vmem_limit = 32 * 1024 * 1024 | ||||||||||||||||||||
# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we | ||||||||||||||||||||
# calculate num_slices_per_block based on 16MB in case any register spills. | ||||||||||||||||||||
vmem_limit = 16 * 1024 * 1024 | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder why we reduce the vmem_limit from 32mb to 16mb? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As noted in the comment, there might be register spills. |
||||||||||||||||||||
num_slices_per_block = vmem_limit // page_size_bytes | ||||||||||||||||||||
assert num_slices_per_block > 0, "Number of slices should be positive" | ||||||||||||||||||||
num_slices_per_block = prev_power_of_2(num_slices_per_block) | ||||||||||||||||||||
Comment on lines
1869
to
1871
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function will raise a To prevent a crash, we should handle this case by returning a default value, since this value is not used for attention-free models.
Suggested change
|
||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is 8 in the
kv_cache_dtype_bits // 8
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 bytes = 8 bits