
[TPU] support fp8 kv cache quantization #19292


Draft · wants to merge 2 commits into main

Conversation

@yaochengji (Collaborator) commented Jun 6, 2025

Purpose

To support fp8 kv cache quantization on TPU.

Test Plan

chengjiyao/Llama-3.1-8B-Instruct-FP8-KV was created based on https://docs.vllm.ai/en/stable/features/quantization/quantized_kvcache.html

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import os

from vllm import LLM, SamplingParams

prompts = [
    "A robot may not injure a human being",
    "It is only with the heart that one can see rightly;",
    "The greatest glory in living lies not in never falling,",
]
answers = [
    " or, through inaction, allow a human being to come to harm.",
    " what is essential is invisible to the eye.",
    " but in rising every time we fall.",
]
N = 1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)


def main():
    parser = argparse.ArgumentParser(description="TPU offline inference example")
    parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
    args = parser.parse_args()

    llm_args = {
        "model": "chengjiyao/Llama-3.1-8B-Instruct-FP8-KV",
        "max_num_batched_tokens": 64,
        "max_num_seqs": 4,
        "max_model_len": 128,
        "kv_cache_dtype": "fp8",
    }
    if args.use_spmd:
        os.environ["VLLM_XLA_USE_SPMD"] = "1"
        # The number of chips can only be hardcoded for now: calling
        # xr.global_runtime_device_count() before initializing the SPMD env in
        # torch_xla would mess up the distributed environment.
        llm_args["tensor_parallel_size"] = 8
        # Use the unquantized Llama model, which has num_kv_heads = 8.
        llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"

    # Set `enforce_eager=True` to avoid ahead-of-time compilation.
    # In real workloads, `enforce_eager` should be `False`.
    llm = LLM(**llm_args)
    outputs = llm.generate(prompts, sampling_params)
    print("-" * 50)
    for output, answer in zip(outputs, answers):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    main()
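
To reproduce, the script can be run directly on a TPU VM, e.g. python tpu_fp8_kv_example.py (the filename here is illustrative); passing --use-spmd on an 8-chip host exercises the SPMD path with tensor_parallel_size=8.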

Test Result

The model quality doesn't look good, mainly due to the static per-tensor KV cache quantization; I can get similar results on GPU.

--------------------------------------------------
Prompt: 'A robot may not injure a human being'
Generated text: ' to perform the following action:\nthe robot may not injure a human being to'
--------------------------------------------------
Prompt: 'It is only with the heart that one can see rightly;'
Generated text: ' the one that is only with the heart that one can see rightly; the one'
--------------------------------------------------
Prompt: 'The greatest glory in living lies not in never falling,'
Generated text: ' but in rising at the awkwardly given to the best owed darlings'
--------------------------------------------------
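
The PR description attributes most of the degradation above to static per-tensor quantization, which collapses each whole tensor onto a single pre-calibrated scale. A minimal sketch of that scheme (illustrative only, assuming a recent PyTorch with native float8 dtypes):

import torch

def quantize_per_tensor_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Static per-tensor quantization: one offline-calibrated scale per tensor."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    return (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

def dequantize_fp8(q: torch.Tensor, scale: float) -> torch.Tensor:
    return q.to(torch.float32) * scale

# Round-tripping shows the error introduced by relying on a single scale.
x = torch.randn(8, dtype=torch.float32)
x_hat = dequantize_fp8(quantize_per_tensor_fp8(x, scale=0.05), scale=0.05)
print((x - x_hat).abs().max())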


github-actions bot commented Jun 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Hello @yaochengji, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini/gemini-code-assist here, providing a summary of this pull request. This PR introduces support for FP8 KV cache quantization specifically for the TPU platform within vLLM. The changes involve enabling FP8 as a supported quantization type for TPUs, updating the internal mapping of string dtype names to PyTorch dtypes to correctly use torch.float8 types, and modifying the Pallas attention backend to handle writing and reading FP8 quantized KV cache data, including applying necessary scales. A test plan using a specific Llama-3.1 model with FP8 KV cache is included in the description, along with the successful test results.

Highlights

  • TPU FP8 KV Cache Support: Adds the necessary logic and configuration to enable using FP8 quantization for the Key/Value cache when running on TPU platforms.
  • FP8 Dtype Mapping: Updates the internal utility function that maps string representations of dtypes (like 'fp8') to their corresponding PyTorch dtype objects, now correctly using torch.float8_e4m3fn and torch.float8_e5m2.
  • Pallas Attention Backend Modifications: The Pallas attention backend for TPUs is updated to handle FP8 KV cache. This includes removing a previous limitation, adding a check for TPU version 5 or higher (as FP8 KV cache requires it), and passing quantization scales to the KV cache write and attention kernel operations.
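
Concretely, the dtype-mapping highlight corresponds to an update along these lines (an illustrative excerpt rather than the PR's full table; vLLM conventionally treats plain "fp8" as the e4m3 variant):

import torch

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,       # previously torch.uint8
    "fp8_e4m3": torch.float8_e4m3fn,  # previously torch.uint8
    "fp8_e5m2": torch.float8_e5m2,    # previously torch.uint8
    "uint8": torch.uint8,             # new explicit entry
}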

Changelog

Click here to see the changelog
  • vllm/engine/arg_utils.py
    • Includes TPU platform in the condition for supporting the v1 attention backend when FP8 attention is enabled (line 1317).
  • vllm/platforms/tpu.py
    • Adds 'fp8' to the list of supported quantization types for the TpuPlatform (line 39).
  • vllm/utils.py
    • Updates STR_DTYPE_TO_TORCH_DTYPE to map 'fp8', 'fp8_e4m3', and 'fp8_e5m2' to the correct torch.float8 dtypes instead of torch.uint8 (lines 178-180).
    • Adds a mapping for 'uint8' to torch.uint8 (line 182).
  • vllm/v1/attention/backends/pallas.py
    • Imports STR_DTYPE_TO_TORCH_DTYPE from vllm.utils (line 16).
    • Removes the NotImplementedError check for kv_cache_dtype != "auto" in the constructor (lines 140-141).
    • Adds a check in the constructor to ensure TPU version is 5 or higher when kv_cache_dtype is not 'auto', storing the quantized dtype (lines 152-159).
    • Removes an assertion that _k_scale_float and _v_scale_float are 1.0 (line 182).
    • Modifies the call to write_to_kv_cache to pass the quantized dtype and scales (lines 195-197).
    • Adds k_scale and v_scale parameters to the torch.ops.xla.ragged_paged_attention call (lines 216-217).
    • Updates the write_to_kv_cache function signature to accept optional kv_cache_quantized_dtype, k_scale, and v_scale (lines 228-230).
    • Adds logic within write_to_kv_cache to scale and convert key and value tensors to the specified quantized dtype if provided (lines 245-249).
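
To make the Pallas backend changes concrete, the quantize-then-write step can be sketched roughly as follows (the signature mirrors the changelog entries above; the body is a simplified illustration, not the PR's exact code):

from typing import Optional

import torch

def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> None:
    if kv_cache_quantized_dtype is not None:
        # k_scale/v_scale are passed in as 1 / layer._*_scale_float, so the
        # multiplication below divides by the calibrated per-tensor scale
        # before casting down to the fp8 storage dtype.
        key = (key.float() * k_scale).to(kv_cache_quantized_dtype)
        value = (value.float() * v_scale).to(kv_cache_quantized_dtype)
    # ... scatter key/value into the paged KV cache at slot_mapping (kernel call)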
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify bot added the v1 and tpu (Related to Google TPUs) labels on Jun 6, 2025
@@ -175,10 +175,11 @@
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
@yaochengji (Collaborator, Author) commented:

@heheda12345 do you know why all the fp8 went to torch.uint8?

A Member replied:

IIRC this was for support on older versions of torch/triton and hardware that didn't natively support fp8 types or operations

@yaochengji (Collaborator, Author) replied:

What about hardware that can support fp8?

The Member replied:

I think STR_DTYPE_TO_TORCH_DTYPE is only used for kv_cache_dtype, so as I said above, we only need to refer to the storage type rather than the actual precision type. For instance, you can see the usage of __nv_fp8_storage_t and how it uses uint8 as a storage container:

__nv_fp8_storage_t res =

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively adds support for fp8 KV cache quantization on TPUs. The changes are logical and primarily involve enabling this feature in the TPU platform configuration and the Pallas attention backend.

The update to STR_DTYPE_TO_TORCH_DTYPE in vllm/utils.py to use actual torch.float8 types instead of torch.uint8 is a good correctness improvement.

One critical concern has been identified regarding potential division by zero when calculating inverse scales if dynamic KV cache scaling is enabled and input tensors happen to be all zeros. While the default configuration (calculate_kv_scales=False) avoids this, the code should be robust to this scenario.

Overall, the changes are clear and well-aligned with the PR's objective. Addressing the identified critical issue will ensure greater robustness.

Summary of Findings

  • Potential Division by Zero in Scale Calculation: In vllm/v1/attention/backends/pallas.py, when calculating k_scale = 1 / layer._k_scale_float (and similarly for v_scale), if layer._k_scale_float is 0.0, this will lead to a division by zero. This can occur if dynamic KV cache scaling (calculate_kv_scales=True) is active and an input tensor (key or value) is all zeros. This could result in NaNs in the attention mechanism.
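
For illustration, one possible mitigation is an epsilon guard on the inverse-scale computation (a sketch of a hypothetical fix, not code from the PR; the eps value is arbitrary):

def safe_inverse_scale(scale: float, eps: float = 1e-8) -> float:
    """Return 1/scale, guarding against a zero scale from an all-zero tensor."""
    return 1.0 / max(scale, eps)

# e.g. k_scale = safe_inverse_scale(layer._k_scale_float)
#      v_scale = safe_inverse_scale(layer._v_scale_float)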

Merge Readiness

The pull request makes good progress in adding FP8 KV cache support for TPUs. However, there is a critical issue related to potential division by zero in scale calculations that needs to be addressed before merging. Once this is resolved, the PR should be in a much better state for merging. As an AI, I am not authorized to approve pull requests; please ensure further review and approval from designated maintainers.

@yaochengji requested a review from mgoin on June 6, 2025 18:00
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@mgoin (Member) commented Jul 7, 2025

Hey @yaochengji I ran some benchmarks on GPU (L40s) using the Xformers backend with FP8 kv cache and found the quality degradation to not be so severe. Even when using uncalibrated scales (i.e. just 1.0) I see <2% relative accuracy loss on GSM8k. With the calibrated scales from your checkpoint I see ~0.2% relative accuracy loss.

# Baseline
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,max_model_len=4096 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
Processed prompts: 100%|███████████| 1319/1319 [02:28<00:00,  8.89it/s, est. speed input: 7750.81 toks/s, output: 863.94 toks/s]
Running generate_until requests: 100%|██████████| 1319/1319 [02:28<00:00,  8.88it/s]
2025-07-07:19:24:56 INFO     [loggers.evaluation_tracker:280] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,max_model_len=4096,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7832|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7582|±  |0.0118|

# FP8 KV Cache with uncalibrated scales
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,max_model_len=4096,kv_cache_dtype=fp8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
Processed prompts: 100%|███████████| 1319/1319 [02:06<00:00, 10.41it/s, est. speed input: 9079.00 toks/s, output: 1008.75 toks/s]
Running generate_until requests: 100%|██████████| 1319/1319 [02:06<00:00, 10.40it/s]
2025-07-07:19:30:07 INFO     [loggers.evaluation_tracker:280] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,max_model_len=4096,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7688|±  |0.0116|
|     |       |strict-match    |     5|exact_match|↑  |0.7453|±  |0.0120|

# FP8 KV Cache with calibrated scales
lm_eval --model vllm --model_args pretrained=chengjiyao/Llama-3.1-8B-Instruct-FP8-KV,max_model_len=4096,kv_cache_dtype=fp8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
Processed prompts: 100%|███████████| 1319/1319 [02:07<00:00, 10.35it/s, est. speed input: 9026.33 toks/s, output: 1003.69 toks/s]
Running generate_until requests: 100%|██████████| 1319/1319 [02:07<00:00, 10.34it/s]
vllm (pretrained=chengjiyao/Llama-3.1-8B-Instruct-FP8-KV,max_model_len=4096,kv_cache_dtype=fp8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7817|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7521|±  |0.0119|

@yaochengji (Collaborator, Author) commented:

> Hey @yaochengji I ran some benchmarks on GPU (L40s) using the Xformers backend with FP8 kv cache

@mgoin thanks for spending time evaluating this fp8 implementation. I also evaluated it on TPU today.

The uncalibrated one looks good on my side.

lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,max_model_len=4096,kv_cache_dtype=fp8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7817|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7642|±  |0.0117|

But the calibrated model's result is quite abnormal.

lm_eval --model vllm --model_args pretrained=chengjiyao/Llama-3.1-8B-Instruct-FP8-KV,max_model_len=4096,kv_cache_dtype=fp8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.0250|±  |0.0043|
|     |       |strict-match    |     5|exact_match|↑  |0.0008|±  |0.0008|

I will mark it as draft first and do some investigation.

@yaochengji marked this pull request as draft on July 8, 2025 04:19

mergify bot commented Jul 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yaochengji.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Jul 8, 2025
Labels: needs-rebase, tpu (Related to Google TPUs), v1