[TPU] support fp8 kv cache quantization #19292

Open · wants to merge 1 commit into main
9 changes: 4 additions & 5 deletions vllm/engine/arg_utils.py
@@ -1363,11 +1363,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if current_platform.is_rocm() or (
current_platform.is_cuda()
and current_platform.is_device_capability(100)) or (
current_platform.device_name
== "hpu"): # handle hpu also for OOT platform
if current_platform.is_rocm() or (current_platform.is_cuda(
) and current_platform.is_device_capability(100)) or (
current_platform.device_name == "hpu"
) or current_platform.is_tpu(): # handle hpu also for OOT platform
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (
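For context, a minimal usage sketch (not part of the diff; the model name and token counts are illustrative) of how a user would request the fp8 KV cache that this change whitelists for TPU, assuming the standard `kv_cache_dtype` engine argument:

from vllm import LLM, SamplingParams

# kv_cache_dtype="fp8" is the knob this PR enables on TPU; on the Pallas
# backend it maps to torch.float8_e4m3fn (see pallas.py below).
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # illustrative model, not from the PR
    kv_cache_dtype="fp8",
    max_model_len=1024,
)
outputs = llm.generate(["Hello from a TPU v5e"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)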
4 changes: 3 additions & 1 deletion vllm/platforms/tpu.py
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"

supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
supported_quantization: list[str] = [
"fp8", "tpu_int8", "compressed-tensors"
]

additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
61 changes: 53 additions & 8 deletions vllm/v1/attention/backends/pallas.py
@@ -24,6 +24,19 @@
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128

# Note: TPU can use fp8 as the KV cache storage dtype, but it doesn't support
# converting from uint8 to fp32 directly. That's why its dtype mapping differs
# from the GPU one.
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}


class PallasAttentionBackend(AttentionBackend):

@@ -156,8 +169,6 @@ def __init__(
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")

@@ -170,6 +181,14 @@
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
if tpu_version < 5:
raise NotImplementedError(
"FP8 KV cache dtype is only supported when TPU version"
" is 5 or higher.")
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip())

def forward(
self,
@@ -204,7 +223,6 @@ def forward(
output = torch.ones_like(query)
return output

assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -225,10 +243,21 @@
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key, value, kv_cache, slot_mapping,
key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices)

attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)

if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
raise ValueError(
"k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
@@ -246,6 +275,8 @@
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)

if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -261,18 +292,32 @@ def write_to_kv_cache(
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: Optional[torch.dtype] = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None:
""" Write the key and values to the KV cache.

Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size,
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to keep values within the quantized dtype's range
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)

kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)

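As a side note, here is a standalone sketch (my own, not part of the diff) of the quantize-on-write arithmetic that write_to_kv_cache performs above, under the assumption that the ragged_paged_attention kernel dequantizes by multiplying the fp8 values back by k_scale/v_scale:

import torch

def quantize_kv(x: torch.Tensor, scale: float,
                dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Mirror of the write path: upcast, divide by the per-layer scale,
    # clamp to the fp8 representable range, then cast to the storage dtype.
    info = torch.finfo(dtype)
    x = x.to(torch.float32) / scale
    x = torch.clamp(x, info.min, info.max)
    return x.to(dtype)

key = torch.randn(4, 8, 128, dtype=torch.bfloat16)  # [tokens, kv_heads, head_size]
k_scale = 0.05                                       # illustrative scale only
k_fp8 = quantize_kv(key, k_scale)
k_restored = k_fp8.to(torch.float32) * k_scale       # approximate dequantization
print((key.float() - k_restored).abs().max())        # worst-case quantization error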
11 changes: 6 additions & 5 deletions vllm/v1/worker/tpu_model_runner.py
@@ -30,9 +30,10 @@
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
is_pin_memory_available, prev_power_of_2)
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
prev_power_of_2)
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
PallasAttentionBackend,
PallasMetadata,
get_page_size_bytes)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -140,11 +141,11 @@ def __init__(
if cache_config.cache_dtype == "auto":
model_dtype = self.dtype
if isinstance(model_dtype, str):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
self.kv_cache_dtype = model_dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype

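Finally, a small sketch (standalone; it duplicates the TPU map for self-containment) of the dtype resolution the runner now performs: "auto" falls back to the model dtype, while named dtypes go through the TPU-specific map, so "fp8" resolves to a real torch.float8_* type rather than the uint8 storage that the diff's note says TPU cannot convert back to fp32 directly:

import torch

TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def resolve_kv_cache_dtype(cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps the KV cache in the model dtype; anything else is looked up
    # in the TPU map, matching the runner logic above.
    if cache_dtype == "auto":
        return model_dtype
    return TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_dtype.lower().strip()]

print(resolve_kv_cache_dtype("fp8", torch.bfloat16))   # torch.float8_e4m3fn
print(resolve_kv_cache_dtype("auto", torch.bfloat16))  # torch.bfloat16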