
Commit a28a2dc

[TPU] support fp8 kv cache quantization
Signed-off-by: Chengji Yao <chengjiyao@google.com>
1 parent 01513a3 commit a28a2dc

File tree: 4 files changed, +66 −19 lines changed


vllm/engine/arg_utils.py

Lines changed: 4 additions & 5 deletions
@@ -1363,11 +1363,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if current_platform.is_rocm() or (
-                    current_platform.is_cuda()
-                    and current_platform.is_device_capability(100)) or (
-                        current_platform.device_name
-                        == "hpu"):  # handle hpu also for OOT platform
+            if current_platform.is_rocm() or (current_platform.is_cuda(
+            ) and current_platform.is_device_capability(100)) or (
+                    current_platform.device_name == "hpu"
+            ) or current_platform.is_tpu():  # handle hpu also for OOT platform
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
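
With current_platform.is_tpu() added to this check, a non-"auto" KV cache dtype no longer disqualifies the V1 engine on TPU. A minimal usage sketch, assuming a vLLM build that contains this commit; the model name and lengths are placeholders, not part of the change:

    # Hedged example: request an fp8 KV cache on a TPU host.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
        max_model_len=1024,
        kv_cache_dtype="fp8",  # resolved to torch.float8_e4m3fn on TPU
    )
    outputs = llm.generate(["Hello, TPU!"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

The same switch is exposed on the command line as --kv-cache-dtype fp8.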

vllm/platforms/tpu.py

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,9 @@ class TpuPlatform(Platform):
     device_control_env_var: str = "TPU_VISIBLE_CHIPS"
     simple_compile_backend: str = "openxla"
 
-    supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
+    supported_quantization: list[str] = [
+        "fp8", "tpu_int8", "compressed-tensors"
+    ]
 
     additional_env_vars: list[str] = [
         "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

vllm/v1/attention/backends/pallas.py

Lines changed: 53 additions & 8 deletions
@@ -24,6 +24,19 @@
 # TPU requires the head size to be a multiple of 128.
 TPU_HEAD_SIZE_ALIGNMENT = 128
 
+# Note: TPU can use fp8 as the storage dtype but doesn't support converting
+# from uint8 to fp32 directly, hence a dtype mapping different from the GPU one.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -156,8 +169,6 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
 
@@ -170,6 +181,14 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            if tpu_version < 5:
+                raise NotImplementedError(
+                    "FP8 KV cache dtype is only supported when TPU version"
+                    " is 5 or higher.")
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())
 
     def forward(
         self,
@@ -204,7 +223,6 @@ def forward(
             output = torch.ones_like(query)
             return output
 
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -225,10 +243,21 @@ def forward(
             # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                 attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
-
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
         output = torch.ops.xla.ragged_paged_attention(
             query,
             kv_cache,
@@ -246,6 +275,8 @@ def forward(
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
             soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
         )
 
         if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
@@ -261,18 +292,32 @@ def write_to_kv_cache(
     slot_mapping: torch.Tensor,
     num_slices_per_kv_cache_update_block: int,
     num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ) -> None:
     """ Write the key and values to the KV cache.
 
     Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
         kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
        num_slices_per_kv_cache_update_block: int
    """
    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
    head_size = cdiv(head_size,
                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp is added here to keep values within the quantized
+        # dtype's representable range.
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                  head_size)

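Taken together, the new write path quantizes keys and values with a per-tensor scale before they land in the fp8 cache, and the ragged paged attention kernel is handed k_scale/v_scale to undo that scaling at read time. A standalone sketch of the same round trip (a CPU-side illustration of the math, not the Pallas kernel itself):

    # Hedged sketch: per-tensor fp8 quantization as done in write_to_kv_cache,
    # plus the dequantization the attention kernel effectively performs.
    import torch

    def quantize(x: torch.Tensor, scale: float, dtype: torch.dtype) -> torch.Tensor:
        info = torch.finfo(dtype)
        # Scale, clamp into the representable range, then narrow to fp8.
        return torch.clamp(x.to(torch.float32) / scale, info.min, info.max).to(dtype)

    def dequantize(q: torch.Tensor, scale: float) -> torch.Tensor:
        # The kernel applies the equivalent of this using k_scale/v_scale.
        return q.to(torch.float32) * scale

    x = torch.randn(4, 8)
    scale = 0.05
    q = quantize(x, scale, torch.float8_e4m3fn)
    err = (dequantize(q, scale) - x).abs().max()
    print(q.dtype, float(err))
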
vllm/v1/worker/tpu_model_runner.py

Lines changed: 6 additions & 5 deletions
@@ -30,9 +30,10 @@
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available, prev_power_of_2)
-from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
+from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
+                        prev_power_of_2)
+from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
+                                               PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -140,11 +141,11 @@ def __init__(
         if cache_config.cache_dtype == "auto":
             model_dtype = self.dtype
             if isinstance(model_dtype, str):
-                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             else:
                 self.kv_cache_dtype = model_dtype
         else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
 
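The runner now resolves the cache dtype through the TPU-specific mapping, so a string such as "fp8" becomes a real float8 torch dtype instead of the uint8 storage type used on GPU. A hedged sketch of that resolution, using a hypothetical helper that mirrors the branches above and assumes the model dtype is already a torch.dtype:

    # Hedged sketch: resolve a cache dtype string the way TPUModelRunner does.
    import torch

    TPU_STR_DTYPE_TO_TORCH_DTYPE = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
        "int8": torch.int8,
        "uint8": torch.uint8,
    }

    def resolve_kv_cache_dtype(cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
        # "auto" keeps the model dtype; anything else goes through the mapping.
        if cache_dtype == "auto":
            return model_dtype
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

    print(resolve_kv_cache_dtype("auto", torch.bfloat16))  # torch.bfloat16
    print(resolve_kv_cache_dtype("fp8", torch.bfloat16))   # torch.float8_e4m3fn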