Commit 85431bd

[TPU] fix kv_cache_update kernel block size choosing logic (#21007)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent: c11013d

2 files changed: +51 −3 lines

vllm/v1/attention/backends/pallas.py
Lines changed: 48 additions & 1 deletion

@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
     return kv_cache
 
 
+# We can move this function to a common utils file if it's also useful for
+# other hardware.
+def dtype_bits(dtype: torch.dtype):
+    if dtype.is_floating_point:
+        try:
+            return torch.finfo(dtype).bits
+        except TypeError:
+            pass
+    elif dtype.is_complex:
+        if dtype is torch.complex32:
+            return 32
+        elif dtype is torch.complex64:
+            return 64
+        elif dtype is torch.complex128:
+            return 128
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        # torch.iinfo does not support int4, int2, bits8, ...
+        except TypeError:
+            pass
+    str_dtype = str(dtype)
+    # Support sub-byte types such as torch.int4, torch.int5, torch.uint5, ...
+    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
+        return int(str_dtype[-1])
+    raise TypeError(f"Getting the bit width of {dtype} is not supported")
+
+
+def get_dtype_packing(dtype):
+    bits = dtype_bits(dtype)
+    if 32 % bits != 0:
+        raise ValueError(
+            f"The bit width must divide 32, but got bits={bits}, "
+            f"dtype={dtype}")
+    return 32 // bits
+
+
 def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         kv_cache_dtype: torch.dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
-    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
+    padded_head_size = cdiv(head_size,
+                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+    num_combined_kv_heads = num_kv_heads * 2
+
+    # NOTE: account for the implicit padding applied by XLA.
+    packing = get_dtype_packing(kv_cache_dtype)
+    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
+
+    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
+    return (block_size * num_combined_kv_heads * padded_head_size *
+            kv_cache_dtype_bits // 8)
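
For context, here is a rough standalone sketch of the new page-size computation. The cdiv helper is reimplemented locally, TPU_HEAD_SIZE_ALIGNMENT is assumed to be 128 for illustration (the real constant lives in the Pallas backend and may differ), and the example configuration is arbitrary:

import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division, standing in for the cdiv helper used in the diff.
    return -(a // -b)

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed value, for illustration only

# Example config: bfloat16 KV cache, 8 KV heads, head_size 96, block_size 32.
dtype = torch.bfloat16
bits = torch.finfo(dtype).bits                    # 16
packing = 32 // bits                              # 2 values per 32-bit word
padded_head_size = cdiv(96, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT  # 128
num_combined_kv_heads = 8 * 2                     # K and V heads stored together
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing          # 16
page_size_bytes = 32 * num_combined_kv_heads * padded_head_size * bits // 8
print(page_size_bytes)                            # 131072 bytes per page

# The old formula, block_size * num_kv_heads * head_size * itemsize,
# would give 32 * 8 * 96 * 2 = 49152 bytes for the same config, since it
# ignores head-size padding and the combined K/V layout.

Because the corrected page size is larger, the slice-count calculation in tpu_model_runner.py below divides its VMEM budget by a bigger number and so picks a smaller block size for the kv_cache_update kernel.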

vllm/v1/worker/tpu_model_runner.py
Lines changed: 3 additions & 2 deletions

@@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
     out of scalar registers. Thus this function will limit the number of
     slices to 64.
     """
-    # Conservative VMEM usage limit: 32 MiB
-    vmem_limit = 32 * 1024 * 1024
+    # The default vmem_limit_bytes of a Pallas kernel is 32 MB. Compute
+    # num_slices_per_block against 16 MB to leave headroom for register spills.
+    vmem_limit = 16 * 1024 * 1024
     num_slices_per_block = vmem_limit // page_size_bytes
     assert num_slices_per_block > 0, "Number of slices should be positive"
     num_slices_per_block = prev_power_of_2(num_slices_per_block)
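
Continuing the sketch above, this is roughly how the lowered 16 MB budget turns into a slice count. prev_power_of_2 is a local stand-in for the helper referenced in the diff, and the final min() is only an illustration of the 64-slice cap mentioned in the function's docstring; the actual implementation may differ:

def prev_power_of_2(n: int) -> int:
    # Largest power of two that is <= n.
    return 1 << (n.bit_length() - 1) if n > 0 else 0

page_size_bytes = 131072                 # from the example above
vmem_limit = 16 * 1024 * 1024            # 16 MB budget after this change
num_slices_per_block = vmem_limit // page_size_bytes          # 128
num_slices_per_block = prev_power_of_2(num_slices_per_block)  # 128
num_slices_per_block = min(num_slices_per_block, 64)          # cap from the docstring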
