diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ae5eb46fa96..638b9e67a08 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1363,11 +1363,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if current_platform.is_rocm() or (
-                    current_platform.is_cuda()
-                    and current_platform.is_device_capability(100)) or (
-                        current_platform.device_name
-                        == "hpu"):  # handle hpu also for OOT platform
+            if current_platform.is_rocm() or (current_platform.is_cuda(
+            ) and current_platform.is_device_capability(100)) or (
+                current_platform.device_name == "hpu"
+            ) or current_platform.is_tpu():  # handle hpu also for OOT platform
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 5ec3be908e7..febc6ae4662 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
     simple_compile_backend: str = "openxla"
 
-    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
+    supported_quantization: list[str] = [
+        "fp8", "tpu_int8", "compressed-tensors"
+    ]
 
     additional_env_vars: list[str] = [
         "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index b7fc1ffeb65..4c03090dd20 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -24,6 +24,19 @@
 # TPU requires the head size to be a multiple of 128.
 TPU_HEAD_SIZE_ALIGNMENT = 128
 
+# Note: TPU can use fp8 as storage dtype but doesn't support converting from
+# uint8 to fp32 directly. That's why it has a dtype mapping different from GPU.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -156,8 +169,6 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
 
@@ -170,6 +181,14 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            if tpu_version < 5:
+                raise NotImplementedError(
+                    "FP8 KV cache dtype is only supported when TPU version"
+                    " is 5 or higher.")
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
 
     def forward(
         self,
@@ -204,7 +223,6 @@ def forward(
             output = torch.ones_like(query)
             return output
 
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -225,10 +243,21 @@ def forward(
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                 attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
-
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
         output = torch.ops.xla.ragged_paged_attention(
             query,
             kv_cache,
@@ -246,6 +275,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
         )
 
         if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -261,18 +292,32 @@ def write_to_kv_cache(
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
     num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """ Write the key and values to the KV cache.
 
     Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
         num_slices_per_kv_cache_update_block: int
     """
     _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
     head_size = cdiv(head_size,
                      TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp avoids values outside the quantized dtype's range
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index ad62d204381..948a87b34e3 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -30,9 +30,10 @@
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available, prev_power_of_2)
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
+                        prev_power_of_2)
+from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
+                                               PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -140,11 +141,11 @@ def __init__(
         if cache_config.cache_dtype == "auto":
             model_dtype = self.dtype
             if isinstance(model_dtype, str):
-                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             else:
                 self.kv_cache_dtype = model_dtype
         else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
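
For context on what the new `write_to_kv_cache` arguments do, here is a minimal, self-contained sketch of the scale, clamp, and cast sequence applied before fp8 KV entries are stored. `quantize_kv_for_cache` is a hypothetical helper written only for illustration (it is not part of this patch), and it assumes a PyTorch build with `torch.float8_e4m3fn` support:

```python
# Illustration only: a hypothetical helper mirroring the quantize-on-write
# path that write_to_kv_cache now applies when a quantized KV dtype is set.
import torch


def quantize_kv_for_cache(
    x: torch.Tensor,
    scale: float,
    quant_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """Divide by the per-layer scale, clamp to the dtype's range, then cast."""
    info = torch.finfo(quant_dtype)
    x = x.to(torch.float32) / scale          # apply k_scale / v_scale
    x = torch.clamp(x, info.min, info.max)   # keep values representable in fp8
    return x.to(quant_dtype)


# [num_tokens, num_kv_heads, head_size], as passed to write_to_kv_cache
key = torch.randn(16, 8, 128, dtype=torch.bfloat16)
key_fp8 = quantize_kv_for_cache(key, scale=0.5)
print(key_fp8.dtype)  # torch.float8_e4m3fn
```

End to end, this path is exercised by setting `kv_cache_dtype="fp8"` (or `--kv-cache-dtype fp8` on the CLI) on a TPU v5 or newer host, which the relaxed checks in the Pallas backend's `__init__` now allow.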