
Commit fff63b2

[TPU] support fp8 kv cache quantization
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 8267f99 commit fff63b2

File tree

4 files changed: +30 additions, -10 deletions

vllm/engine/arg_utils.py
vllm/platforms/tpu.py
vllm/utils.py
vllm/v1/attention/backends/pallas.py


vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1314,7 +1314,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if current_platform.is_rocm():
+            if current_platform.is_rocm() or current_platform.is_tpu():
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
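
For context, _is_v1_supported_oracle only keeps the V1 engine enabled with an fp8 KV cache when the platform or attention backend can handle it; this hunk adds TPU alongside ROCm. Below is a minimal sketch of that fp8 branch with the current_platform and envs lookups replaced by plain arguments (the function name and parameters are illustrative, not vLLM APIs):

# Sketch of the fp8-KV-cache branch of _is_v1_supported_oracle shown above.
# The stand-in booleans replace the real current_platform / envs checks;
# only the branching structure mirrors the actual code.
def fp8_kv_cache_supported(is_rocm: bool, is_tpu: bool,
                           will_use_fa: bool,
                           flash_attn_supports_fp8: bool) -> bool:
    supported = False
    if is_rocm or is_tpu:  # TPU newly accepted by this commit
        supported = True
    elif will_use_fa:      # CUDA + FlashAttention path
        supported = flash_attn_supports_fp8
    return supported

# A TPU deployment now passes regardless of FlashAttention capabilities:
assert fp8_kv_cache_supported(is_rocm=False, is_tpu=True,
                              will_use_fa=False,
                              flash_attn_supports_fp8=False)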

vllm/platforms/tpu.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
     simple_compile_backend: str = "openxla"
 
-    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
+    supported_quantization: list[str] = [
+        "fp8", "tpu_int8", "compressed-tensors"
+    ]
 
     additional_env_vars: list[str] = [
         "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

vllm/utils.py

Lines changed: 4 additions & 3 deletions
@@ -175,10 +175,11 @@
     "half": torch.half,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
-    "fp8": torch.uint8,
-    "fp8_e4m3": torch.uint8,
-    "fp8_e5m2": torch.uint8,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
     "int8": torch.int8,
+    "uint8": torch.uint8,
 }
 
 TORCH_DTYPE_TO_NUMPY_DTYPE = {
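
The fp8 cache-dtype strings now resolve to PyTorch's native float8 dtypes instead of a raw uint8 container, and a plain "uint8" entry is added. A quick check of what the lookup returns (requires a PyTorch build with float8 dtypes, 2.1 or newer):

# Element size is unchanged at 1 byte; only the dtype identity is now the
# native float8 type rather than torch.uint8.
import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

assert STR_DTYPE_TO_TORCH_DTYPE["fp8"] is torch.float8_e4m3fn
assert STR_DTYPE_TO_TORCH_DTYPE["fp8_e4m3"] is torch.float8_e4m3fn
assert STR_DTYPE_TO_TORCH_DTYPE["fp8_e5m2"] is torch.float8_e5m2
assert STR_DTYPE_TO_TORCH_DTYPE["uint8"] is torch.uint8
assert torch.float8_e4m3fn.itemsize == 1 == torch.uint8.itemsize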

vllm/v1/attention/backends/pallas.py

Lines changed: 22 additions & 5 deletions
@@ -13,7 +13,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv, next_power_of_2
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, next_power_of_2
 
 logger = init_logger(__name__)
 
@@ -137,8 +137,6 @@ def __init__(
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
 
@@ -151,6 +149,14 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            if tpu_version < 5:
+                raise NotImplementedError(
+                    "FP8 KV cache dtype is only supported when TPU version"
+                    " is 5 or higher.")
+            self.kv_cache_quantized_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
 
     def forward(
         self,
@@ -179,15 +185,16 @@ def forward(
             output = torch.ones_like(query)
             return output
 
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
 
         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
-            write_to_kv_cache(key, value, kv_cache, slot_mapping)
+            write_to_kv_cache(key, value, kv_cache, slot_mapping,
+                              self.kv_cache_quantized_dtype,
+                              layer._k_scale_float, layer._v_scale_float)
 
         output = torch.ops.xla.ragged_paged_attention(
             query,
@@ -206,6 +213,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=1 / layer._k_scale_float,
+            v_scale=1 / layer._v_scale_float,
         )
 
         return output.reshape(num_tokens, hidden_size)
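
Note the scaling convention: write_to_kv_cache now receives layer._k_scale_float / layer._v_scale_float and multiplies the keys/values by them, while the ragged_paged_attention kernel is handed the reciprocals (k_scale=1 / layer._k_scale_float). Assuming the kernel multiplies the stored KV by its k_scale/v_scale arguments to dequantize, the two factors cancel, as in this tiny numeric sketch with made-up values and the fp8 cast omitted:

# Round trip of the scaling convention above (illustrative numbers only).
k_scale_float = 0.25                 # stands in for layer._k_scale_float
x = 3.0                              # one key element
stored = x * k_scale_float           # what write_to_kv_cache puts in the cache
kernel_k_scale = 1 / k_scale_float   # what forward() passes to the kernel
assert stored * kernel_k_scale == x  # dequantization recovers the original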
@@ -216,6 +225,9 @@ def write_to_kv_cache(
     value: torch.Tensor,
     kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """ Write the key and values to the KV cache.
 
@@ -230,6 +242,11 @@ def write_to_kv_cache(
 
     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
+    if kv_cache_quantized_dtype is not None:
+        key = key * k_scale
+        key = key.to(kv_cache_quantized_dtype)
+        value = value * v_scale
+        value = value.to(kv_cache_quantized_dtype)
 
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
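
Putting the write_to_kv_cache change together: when a quantized dtype is configured, keys and values are scaled and cast to fp8 before being concatenated into the cache. A standalone sketch of that step with made-up shapes and scales, run on CPU rather than through the Pallas kernel:

# Standalone sketch of the new quantize-then-cast step in write_to_kv_cache;
# shapes and scales are illustrative, and the helper name is not a vLLM API.
from typing import Optional

import torch


def quantize_kv(key: torch.Tensor, value: torch.Tensor,
                kv_cache_quantized_dtype: Optional[torch.dtype],
                k_scale: float, v_scale: float):
    if kv_cache_quantized_dtype is not None:
        key = (key * k_scale).to(kv_cache_quantized_dtype)
        value = (value * v_scale).to(kv_cache_quantized_dtype)
    return key, value


key = torch.randn(4, 2, 128, dtype=torch.bfloat16)  # [tokens, kv_heads, head_size]
value = torch.randn_like(key)
qk, qv = quantize_kv(key, value, torch.float8_e4m3fn, k_scale=1.0, v_scale=1.0)
print(qk.dtype, qk.element_size())  # torch.float8_e4m3fn 1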
