diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 32ef5dc2e36..b7fc1ffeb65 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, return kv_cache +# We can move this function to a common utils file if it's also useful for other +# hardware. +def dtype_bits(dtype: torch.dtype): + if dtype.is_floating_point: + try: + return torch.finfo(dtype).bits + except TypeError: + pass + elif dtype.is_complex: + if dtype is torch.complex32: + return 32 + elif dtype is torch.complex64: + return 64 + elif dtype is torch.complex128: + return 128 + else: + try: + return torch.iinfo(dtype).bits + # torch.iinfo cannot support int4, int2, bits8... + except TypeError: + pass + str_dtype = str(dtype) + # support torch.int4, torch.int5, torch.uint5... + if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"): + return int(str_dtype[-1]) + raise TypeError(f"Getting the bit width of {dtype} is not supported") + + +def get_dtype_packing(dtype): + bits = dtype_bits(dtype) + if 32 % bits != 0: + raise ValueError( + f"The bit width must be divisible by 32, but got bits={bits}, " + "dtype={dtype}") + return 32 // bits + + def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype) -> int: """Returns the size in bytes of one page of the KV cache.""" - return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize + padded_head_size = cdiv(head_size, + TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + num_combined_kv_heads = num_kv_heads * 2 + + # NOTE: for the implicit padding in XLA + packing = get_dtype_packing(kv_cache_dtype) + num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing + + kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) + return (block_size * num_combined_kv_heads * padded_head_size * + kv_cache_dtype_bits // 8) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 83a80bd865b..a1d66d1fa40 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1863,8 +1863,9 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: out of scalar registers. Thus this function will limit the number of slices to 64. """ - # Conservative VMEM usage limit: 32 MiB - vmem_limit = 32 * 1024 * 1024 + # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we + # calculate num_slices_per_block based on 16MB in case any register spills. + vmem_limit = 16 * 1024 * 1024 num_slices_per_block = vmem_limit // page_size_bytes assert num_slices_per_block > 0, "Number of slices should be positive" num_slices_per_block = prev_power_of_2(num_slices_per_block)