From 50e7f2d16d6957cc62d05f67746b2d0758f61a04 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Tue, 15 Jul 2025 18:06:24 +0000 Subject: [PATCH 1/3] [TPU] fix kv cache update kernel block size Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 4 +++- vllm/v1/worker/tpu_model_runner.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 32ef5dc2e36..313c6752fed 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -329,4 +329,6 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype) -> int: """Returns the size in bytes of one page of the KV cache.""" - return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize + padded_head_size = cdiv(head_size, + TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + return block_size * num_kv_heads * padded_head_size * kv_cache_dtype.itemsize diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 83a80bd865b..a1d66d1fa40 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: out of scalar registers. Thus this function will limit the number of slices to 64. """ - # Conservative VMEM usage limit: 32 MiB - vmem_limit = 32 * 1024 * 1024 + # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we + # calculate num_slices_per_block based on 16MB in case any register spills. + vmem_limit = 16 * 1024 * 1024 num_slices_per_block = vmem_limit // page_size_bytes assert num_slices_per_block > 0, "Number of slices should be positive" num_slices_per_block = prev_power_of_2(num_slices_per_block) From f805446b34fda05235a2fbab074ea5156bab984f Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Tue, 15 Jul 2025 19:47:58 +0000 Subject: [PATCH 2/3] [TPU] address comments Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 45 +++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 313c6752fed..74ddedddb30 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -326,9 +326,52 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, return kv_cache +# We can move this function to a common utils file if it's also useful for other +# hardware. +def dtype_bits(dtype: torch.dtype): + if dtype.is_floating_point: + try: + return torch.finfo(dtype).bits + except TypeError: + pass + elif dtype.is_complex: + if dtype is torch.complex32: + return 32 + elif dtype is torch.complex64: + return 64 + elif dtype is torch.complex128: + return 128 + else: + try: + return torch.iinfo(dtype).bits + # torch.iinfo cannot support int4, int2, bits8... + except TypeError: + pass + str_dtype = str(dtype) + # support torch.int4, torch.int5, torch.uint5... 
+    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
+        return int(str_dtype[-1])
+    raise TypeError(f"Getting the bit width of {dtype} is not supported")
+
+
+def get_dtype_packing(dtype):
+    bits = dtype_bits(dtype)
+    if 32 % bits != 0:
+        raise ValueError(
+            f"The bit width must evenly divide 32, but got bits={bits}, "
+            f"dtype={dtype}")
+    return 32 // bits
+
+
 def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         kv_cache_dtype: torch.dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
     padded_head_size = cdiv(head_size,
                             TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-    return block_size * num_kv_heads * padded_head_size * kv_cache_dtype.itemsize
+    packing = get_dtype_packing(kv_cache_dtype)
+    # for the implicit padding in XLA
+    padded_head_size = max(padded_head_size, packing)
+    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
+    num_combined_kv_heads = num_kv_heads * 2
+    return (block_size * num_combined_kv_heads * padded_head_size *
+            kv_cache_dtype_bits // 8)

From 4de8313908bbc39a791e100a54a3d46887f7aeb5 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Tue, 15 Jul 2025 22:40:04 +0000
Subject: [PATCH 3/3] [TPU] address comments

Signed-off-by: Chengji Yao
---
 vllm/v1/attention/backends/pallas.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 74ddedddb30..b7fc1ffeb65 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -368,10 +368,12 @@ def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
     """Returns the size in bytes of one page of the KV cache."""
     padded_head_size = cdiv(head_size,
                             TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+    num_combined_kv_heads = num_kv_heads * 2
+
+    # NOTE: for the implicit padding in XLA
     packing = get_dtype_packing(kv_cache_dtype)
-    # for the implicit padding in XLA
-    padded_head_size = max(padded_head_size, packing)
+    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
+
     kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
-    num_combined_kv_heads = num_kv_heads * 2
     return (block_size * num_combined_kv_heads * padded_head_size *
             kv_cache_dtype_bits // 8)
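
The arithmetic these patches introduce can be sanity-checked with a short
standalone sketch (not part of the patches themselves). Assumptions: the value
of TPU_HEAD_SIZE_ALIGNMENT is taken to be 128, cdiv and prev_power_of_2 are
re-implemented locally rather than imported from vLLM's utilities, and the
64-slice cap follows the _get_num_slices_per_kv_cache_update_block docstring
shown in PATCH 1/3.

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed value for this sketch


def cdiv(a: int, b: int) -> int:
    # Ceiling division.
    return -(a // -b)


def prev_power_of_2(n: int) -> int:
    # Largest power of two <= n, for n > 0.
    return 1 << (n.bit_length() - 1)


def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype_bits: int) -> int:
    # Mirrors get_page_size_bytes after PATCH 3/3: pad the head size to the
    # TPU alignment, store K and V heads together (hence * 2), and pad the
    # combined head count to the dtype's packing within a 32-bit word.
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    packing = 32 // dtype_bits
    num_combined_kv_heads = cdiv(num_kv_heads * 2, packing) * packing
    return (block_size * num_combined_kv_heads * padded_head_size *
            dtype_bits // 8)


def num_slices_per_block(page_bytes: int) -> int:
    # Mirrors _get_num_slices_per_kv_cache_update_block after PATCH 1/3:
    # budget 16 MiB (half of the default 32 MiB Pallas vmem_limit_bytes),
    # round down to a power of two, and cap at 64 slices.
    vmem_limit = 16 * 1024 * 1024
    num_slices = prev_power_of_2(vmem_limit // page_bytes)
    return min(num_slices, 64)


# Example: block_size=32, 8 KV heads, head_size=128, bfloat16 (16 bits).
page = page_size_bytes(block_size=32, num_kv_heads=8, head_size=128,
                       dtype_bits=16)
assert page == 131072                    # 32 * 16 * 128 * 2 bytes = 128 KiB
assert num_slices_per_block(page) == 64  # 16 MiB // 128 KiB = 128 -> capped at 64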