[TPU] support fp8 kv cache quantization #19292
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add 🚀 …
Hello @yaochengji, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! Gemini/gemini-code-assist here, providing a summary of this pull request. This PR introduces support for FP8 KV cache quantization specifically for the TPU platform within vLLM. The changes involve enabling FP8 as a supported quantization type for TPUs, updating the internal mapping of string dtype names to PyTorch dtypes to correctly use torch.float8 types, and modifying the Pallas attention backend to handle writing and reading FP8 quantized KV cache data, including applying the necessary scales. A test plan using a specific Llama-3.1 model with FP8 KV cache is included in the description, along with the test results.
Highlights
- TPU FP8 KV Cache Support: Adds the necessary logic and configuration to enable using FP8 quantization for the Key/Value cache when running on TPU platforms.
- FP8 Dtype Mapping: Updates the internal utility function that maps string representations of dtypes (like 'fp8') to their corresponding PyTorch dtype objects, now correctly using torch.float8_e4m3fn and torch.float8_e5m2.
- Pallas Attention Backend Modifications: The Pallas attention backend for TPUs is updated to handle FP8 KV cache. This includes removing a previous limitation, adding a check for TPU version 5 or higher (as FP8 KV cache requires it), and passing quantization scales to the KV cache write and attention kernel operations. A sketch of that constructor check follows this list.
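For illustration, here is a minimal sketch of the constructor-time check described above. It is not the PR's actual code: the tpu_version argument is a placeholder for however the backend queries the TPU generation, and the function name is hypothetical; STR_DTYPE_TO_TORCH_DTYPE is the mapping from vllm.utils.

```python
# Hedged sketch of the "TPU v5+ required for FP8 KV cache" check.
from typing import Optional

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE


def resolve_kv_cache_dtype(kv_cache_dtype: str,
                           tpu_version: int) -> Optional[torch.dtype]:
    """Return the quantized KV-cache dtype, or None for the 'auto' path."""
    if kv_cache_dtype == "auto":
        return None  # no quantization; the cache keeps the model dtype
    if tpu_version < 5:
        raise NotImplementedError(
            "FP8 KV cache quantization requires TPU v5 or newer.")
    # e.g. "fp8" -> torch.float8_e4m3fn after this PR's utils.py update
    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
```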
Changelog
- vllm/engine/arg_utils.py
  - Includes the TPU platform in the condition for supporting the v1 attention backend when FP8 attention is enabled (line 1317).
- vllm/platforms/tpu.py
  - Adds 'fp8' to the list of supported quantization types for the TpuPlatform (line 39).
- vllm/utils.py
  - Updates STR_DTYPE_TO_TORCH_DTYPE to map 'fp8', 'fp8_e4m3', and 'fp8_e5m2' to the correct torch.float8 dtypes instead of torch.uint8 (lines 178-180).
  - Adds a mapping for 'uint8' to torch.uint8 (line 182).
- vllm/v1/attention/backends/pallas.py
  - Imports STR_DTYPE_TO_TORCH_DTYPE from vllm.utils (line 16).
  - Removes the NotImplementedError check for kv_cache_dtype != "auto" in the constructor (lines 140-141).
  - Adds a check in the constructor to ensure the TPU version is 5 or higher when kv_cache_dtype is not 'auto', storing the quantized dtype (lines 152-159).
  - Removes an assertion that _k_scale_float and _v_scale_float are 1.0 (line 182).
  - Modifies the call to write_to_kv_cache to pass the quantized dtype and scales (lines 195-197).
  - Adds k_scale and v_scale parameters to the torch.ops.xla.ragged_paged_attention call (lines 216-217).
  - Updates the write_to_kv_cache function signature to accept optional kv_cache_quantized_dtype, k_scale, and v_scale (lines 228-230).
  - Adds logic within write_to_kv_cache to scale and convert the key and value tensors to the specified quantized dtype if provided (lines 245-249); see the sketch after this changelog.
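To make the last two bullets concrete, here is a minimal, hedged sketch of the quantize-on-write step. The function name and argument handling are illustrative rather than the PR's actual code, and k_scale / v_scale are assumed to already be the inverse per-tensor scales discussed in the review below.

```python
# Minimal sketch of scaling and casting K/V to the quantized storage dtype
# before the cache write; in the real backend the result would be handed to
# the Pallas KV-cache write kernel together with the same scales.
from typing import Optional, Tuple

import torch


def quantize_kv_for_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache_quantized_dtype: Optional[torch.dtype],
    k_scale: float,
    v_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if kv_cache_quantized_dtype is None:
        return key, value  # "auto" path: write K/V unchanged
    key = (key.to(torch.float32) * k_scale).to(kv_cache_quantized_dtype)
    value = (value.to(torch.float32) * v_scale).to(kv_cache_quantized_dtype)
    return key, value
```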
@@ -175,10 +175,11 @@
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
@heheda12345 do you know why all the fp8 entries went to torch.uint8?
IIRC this was for support on older versions of torch/triton and hardware that didn't natively support fp8 types or operations
What about hardware that can support fp8?
I think STR_DTYPE_TO_TORCH_DTYPE is only used for kv_cache_dtype, so like I said above we only need to refer to the storage type rather than the actual precision type. For instance, you can see the usage of __nv_fp8_storage_t and how it uses uint8 as a storage container:

__nv_fp8_storage_t res = …
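A small PyTorch illustration of that storage-vs-precision point (a toy example, not code from vLLM or from the CUDA header referenced above): an fp8 tensor can be reinterpreted as its raw uint8 storage without changing the underlying bytes.

```python
import torch

x = torch.tensor([0.5, 1.5, -2.0]).to(torch.float8_e4m3fn)
raw = x.view(torch.uint8)                      # same bytes, uint8 "container"
roundtrip = raw.view(torch.float8_e4m3fn).to(torch.float32)
print(raw)        # fp8 bit patterns stored as uint8
print(roundtrip)  # tensor([ 0.5000,  1.5000, -2.0000])
```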
Code Review
This pull request effectively adds support for fp8 KV cache quantization on TPUs. The changes are logical and primarily involve enabling this feature in the TPU platform configuration and the Pallas attention backend.
The update to STR_DTYPE_TO_TORCH_DTYPE in vllm/utils.py to use actual torch.float8 types instead of torch.uint8 is a good correctness improvement.
One critical concern has been identified regarding potential division by zero when calculating inverse scales if dynamic KV cache scaling is enabled and the input tensors happen to be all zeros. While the default configuration (calculate_kv_scales=False) avoids this, the code should be robust to this scenario.
Overall, the changes are clear and well-aligned with the PR's objective. Addressing the identified critical issue will ensure greater robustness.
Summary of Findings
- Potential Division by Zero in Scale Calculation: In vllm/v1/attention/backends/pallas.py, when calculating k_scale = 1 / layer._k_scale_float (and similarly for v_scale), if layer._k_scale_float is 0.0 this will lead to a division by zero. This can occur if dynamic KV cache scaling (calculate_kv_scales=True) is active and an input tensor (key or value) is all zeros, which could result in NaNs in the attention mechanism. One possible guard is sketched below.
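For reference, one way to guard against the zero-scale case, assuming the stored per-tensor scales are plain Python floats (a sketch, not the PR's actual fix):

```python
def safe_inverse_scale(scale_float: float) -> float:
    # A zero scale means the observed tensor was all zeros; dividing by it
    # would produce inf/NaN downstream, so fall back to a neutral scale.
    return 1.0 if scale_float == 0.0 else 1.0 / scale_float


k_scale = safe_inverse_scale(0.0)   # -> 1.0 instead of ZeroDivisionError
v_scale = safe_inverse_scale(0.25)  # -> 4.0
```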
Merge Readiness
The pull request makes good progress in adding FP8 KV cache support for TPUs. However, there is a critical issue related to potential division by zero in scale calculations that needs to be addressed before merging. Once this is resolved, the PR should be in a much better state for merging. As an AI, I am not authorized to approve pull requests; please ensure further review and approval from designated maintainers.
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Force-pushed from 1f7c757 to fff63b2.
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Hey @yaochengji I ran some benchmarks on GPU (L40s) using the Xformers backend with FP8 kv cache and found the quality degradation to not be so severe, even when using uncalibrated scales (i.e. just …)
@mgoin thanks for spending time evaluating this fp8 implementation. I also evaluated it on TPU today. The uncalibrated one looks good on my side, but the calibrated model's results are quite abnormal. I will mark it as draft first and do some investigation.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
To support fp8 kv cache quantization on TPU.
Test Plan
chengjiyao/Llama-3.1-8B-Instruct-FP8-KV was created based on https://docs.vllm.ai/en/stable/features/quantization/quantized_kvcache.html
Test Result
The model quality doesn't look good, mainly due to its static per-tensor KV cache quantization, and I can get similar results on GPU.
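For readers who want to reproduce a similar run, a minimal usage sketch is below. The model name comes from the test plan above; every other argument is an illustrative assumption, and it would need to run on a host where the vLLM TPU backend is available.

```python
from vllm import LLM, SamplingParams

# Load the FP8-KV checkpoint from the test plan and enable the fp8 KV cache
# path added by this PR (requires TPU v5 or newer per the Pallas backend check).
llm = LLM(
    model="chengjiyao/Llama-3.1-8B-Instruct-FP8-KV",
    kv_cache_dtype="fp8",
)

outputs = llm.generate("Hello, my name is",
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```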