
Commit f805446

[TPU] address comments
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 50e7f2d commit f805446

File tree

1 file changed: +44 −1 lines changed


vllm/v1/attention/backends/pallas.py

Lines changed: 44 additions & 1 deletion
@@ -326,9 +326,52 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
     return kv_cache
 
 
+# We can move this function to a common utils file if it's also useful for
+# other hardware.
+def dtype_bits(dtype: torch.dtype):
+    if dtype.is_floating_point:
+        try:
+            return torch.finfo(dtype).bits
+        except TypeError:
+            pass
+    elif dtype.is_complex:
+        if dtype is torch.complex32:
+            return 32
+        elif dtype is torch.complex64:
+            return 64
+        elif dtype is torch.complex128:
+            return 128
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        # torch.iinfo does not support int4, int2, bits8, ...
+        except TypeError:
+            pass
+    str_dtype = str(dtype)
+    # support torch.int4, torch.int5, torch.uint5, ...
+    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
+        return int(str_dtype[-1])
+    raise TypeError(f"Getting the bit width of {dtype} is not supported")
+
+
+def get_dtype_packing(dtype):
+    bits = dtype_bits(dtype)
+    if 32 % bits != 0:
+        raise ValueError(
+            f"The bit width must evenly divide 32, but got bits={bits}, "
+            f"dtype={dtype}")
+    return 32 // bits
+
+
 def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         kv_cache_dtype: torch.dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
     padded_head_size = cdiv(head_size,
                             TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
-    return block_size * num_kv_heads * padded_head_size * kv_cache_dtype.itemsize
+    packing = get_dtype_packing(kv_cache_dtype)
+    # for the implicit padding in XLA
+    padded_head_size = max(padded_head_size, packing)
+    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
+    num_combined_kv_heads = num_kv_heads * 2
+    return (block_size * num_combined_kv_heads * padded_head_size *
+            kv_cache_dtype_bits // 8)
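
The key differences from the old one-liner are that sizes are counted in bits rather than via itemsize (so dtypes narrower than a byte can be handled), the head size is padded up to the packing factor for XLA's implicit padding, and K and V heads are counted together via num_combined_kv_heads. The snippet below is a self-contained sketch of that arithmetic, not the vllm code itself: TPU_HEAD_SIZE_ALIGNMENT = 128 is an assumed value, and page_size_bytes and cdiv here are illustrative stand-ins for the functions defined in vllm/v1/attention/backends/pallas.py.

# Standalone sketch mirroring the page-size arithmetic introduced above.
import torch

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed alignment, for illustration only


def cdiv(a: int, b: int) -> int:
    # ceiling division, same role as vllm's cdiv helper
    return -(-a // b)


def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    kv_cache_dtype: torch.dtype) -> int:
    # Count bits rather than itemsize, and pad the head size up to the
    # per-32-bit-word packing factor, mirroring the implicit padding in XLA.
    bits = (torch.finfo(kv_cache_dtype).bits
            if kv_cache_dtype.is_floating_point
            else torch.iinfo(kv_cache_dtype).bits)
    packing = 32 // bits
    padded_head_size = cdiv(head_size,
                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    padded_head_size = max(padded_head_size, packing)
    num_combined_kv_heads = num_kv_heads * 2  # K and V heads stored together
    return block_size * num_combined_kv_heads * padded_head_size * bits // 8


# bfloat16: 16 bits -> packing 2; float8_e4m3fn (PyTorch >= 2.1): 8 bits -> packing 4
print(page_size_bytes(16, 8, 128, torch.bfloat16))       # 65536
print(page_size_bytes(16, 8, 128, torch.float8_e4m3fn))  # 32768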

0 commit comments
