Commit 1c81513

[TPU] address comments
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent f805446 commit 1c81513

File tree: 1 file changed (+5 −3 lines)


vllm/v1/attention/backends/pallas.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -368,10 +368,12 @@ def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
     """Returns the size in bytes of one page of the KV cache."""
     padded_head_size = cdiv(head_size,
                             TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+    num_combined_kv_heads = num_kv_heads * 2
+
+    # NOTE: for the implicit padding in XLA
     packing = get_dtype_packing(kv_cache_dtype)
-    # for the implicit padding in XLA
-    padded_head_size = max(padded_head_size, packing)
+    num_combined_kv_heads = cdiv(num_kv_heads * 2, packing) * packing
+
     kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
-    num_combined_kv_heads = num_kv_heads * 2
     return (block_size * num_combined_kv_heads * padded_head_size *
             kv_cache_dtype_bits // 8)
```
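
For context, the change moves the accounting for XLA's implicit padding from the head size to the combined K/V head count: rather than clamping `padded_head_size` to at least the packing factor, the number of combined heads is rounded up to a multiple of it. Below is a minimal, self-contained sketch of the resulting computation. The `TPU_HEAD_SIZE_ALIGNMENT` value and the bodies of `cdiv`, `dtype_bits`, and `get_dtype_packing` are assumptions for illustration, not vLLM's actual definitions.

```python
TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed value; the real constant lives in pallas.py

def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)

def dtype_bits(kv_cache_dtype: str) -> int:
    # Assumed stand-in: bit width of one KV cache element.
    return {"float32": 32, "bfloat16": 16, "fp8": 8}[kv_cache_dtype]

def get_dtype_packing(kv_cache_dtype: str) -> int:
    # Assumed stand-in: how many elements XLA packs into one 32-bit word.
    return 32 // dtype_bits(kv_cache_dtype)

def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: str) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    # Pad the head size up to the TPU alignment boundary.
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    # K and V are stored combined, so the head count doubles; XLA then
    # implicitly pads that count up to a multiple of the packing factor.
    packing = get_dtype_packing(kv_cache_dtype)
    num_combined_kv_heads = cdiv(num_kv_heads * 2, packing) * packing
    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
    return (block_size * num_combined_kv_heads * padded_head_size *
            kv_cache_dtype_bits // 8)

# bfloat16: packing = 32 // 16 = 2, so 2 combined heads need no padding:
# 16 * 2 * 128 * 16 // 8 = 8192
print(get_page_size_bytes(16, 1, 128, "bfloat16"))  # -> 8192

# fp8: packing = 32 // 8 = 4, so the 2 combined heads pad up to 4 and the
# implicit padding doubles the page: 16 * 4 * 128 * 8 // 8 = 8192
print(get_page_size_bytes(16, 1, 128, "fp8"))  # -> 8192
```

The fp8 example suggests why the accounting matters: for sub-32-bit dtypes with few KV heads, the padding XLA inserts along the head dimension can dominate the page size, and rounding the combined head count (rather than clamping the head size) appears to track that footprint more directly.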
