Skip to content

Support DeepSeekV3-style block FP8 quantization with CT #20279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Jun 30, 2025

Purpose

Relies on recent support in compressed-tensors (neuralmagic/compressed-tensors#372) and llm-compressor (vllm-project/llm-compressor#1607) to produce the models.

This PR implements W8A8 FP8 block quantization support for compressed-tensors models. This is focused on supporting the DeepSeekV3-style format, which has 128x128 block weights and 1x128 block activations (really per-token-group).

Most of the logic is ported directly from fp8.py and I hope to refactor the utilities to be shared eventually.

Test Plan

Manual testing with newly produced models. I'll add lm-eval in another PR

Test Result

Dense

CT result:

lm_eval --model vllm --model_args pretrained=mgoin/Qwen3-0.6B-FP8-BLOCK --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=mgoin/Qwen3-0.6B-FP8-BLOCK,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3973|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.3995|±  |0.0135|

Ref:

lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-0.6B-FP8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
vllm (pretrained=Qwen/Qwen3-0.6B-FP8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3973|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.3995|±  |0.0135|

MoE

CT result:

vllm (pretrained=mgoin/Qwen3-30B-A3B-FP8-BLOCK,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8158|±  |0.0107|
|     |       |strict-match    |     5|exact_match|↑  |0.8923|±  |0.0085|

Ref:

vllm (pretrained=Qwen/Qwen3-30B-A3B-FP8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8158|±  |0.0107|
|     |       |strict-match    |     5|exact_match|↑  |0.8923|±  |0.0085|

Signed-off-by: mgoin <mgoin64@gmail.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @mgoin, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly extends vLLM's quantization capabilities by integrating support for DeepSeekV3-style block FP8 quantization. It involves substantial changes to how quantized weights are handled, from their creation and loading to their application during inference, ensuring compatibility with new model types and optimizing performance through specialized kernel utilization.

Highlights

  • DeepSeekV3-style Block FP8 Support: This pull request introduces comprehensive support for DeepSeekV3-style block FP8 quantization, enabling vLLM to load and utilize models quantized with this specific scheme.
  • Enhanced Weight Management: New logic has been added to handle block-quantized weights, including the introduction of BlockQuantScaleParameter for managing block-wise scales, and specific processing steps like FNUZ normalization and ROCm padding during weight loading.
  • Optimized Kernel Dispatch: The system now dispatches to a specialized torch.ops.vllm.apply_w8a8_block_fp8_linear operation for block FP8 linear layers, leveraging optimized kernels for efficient execution based on the detected quantization strategy.
  • Refactored Quantization Logic: Internal quantization parameter management has been streamlined by storing weight_block_size directly on the layer object and consistently passing QuantizationArgs objects, improving type safety and code clarity.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for DeepSeekV3-style block FP8 quantization, primarily by integrating the BLOCK quantization strategy into the compressed-tensors framework. Key changes include refactoring how weight_block_size is handled across linear layers and quantization methods, implementing new weight creation and processing logic for block-quantized weights, and ensuring the correct kernel dispatch for block FP8 linear operations. The changes are well-structured and address the core requirements for this new quantization scheme.

Comment on lines +60 to +144
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
weight_loader: Callable, **kwargs):
maybe_create_device_identity()

output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.weight_block_size = None

if self.strategy == QuantizationStrategy.BLOCK:
tp_size = get_tensor_model_parallel_world_size()
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
block_n, block_k = (
layer.weight_block_size[0],
layer.weight_block_size[1],
)
# Required by row parallel
if (tp_size > 1
and input_size // input_size_per_partition == tp_size
and input_size_per_partition % block_k != 0):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}.")

# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)

# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
weight_scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)

# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new create_weights method is a substantial addition that correctly handles the creation of weight and scale parameters for block quantization. The divisibility checks for block_n and block_k are critical for ensuring proper alignment with tensor parallelism and merged weights, preventing potential errors during computation. The use of BlockQuantScaleParameter for block strategy is also correct.

Comment on lines +237 to +247
if layer.weight_block_size is not None:
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=layer.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This if block correctly dispatches to torch.ops.vllm.apply_w8a8_block_fp8_linear when layer.weight_block_size is set, indicating that block quantization is active. This ensures that the appropriate kernel is used for the forward pass, leveraging the new block quantization capabilities.

Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin mgoin marked this pull request as ready for review July 1, 2025 00:54
Copy link

mergify bot commented Jul 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase deepseek Related to DeepSeek models labels Jul 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant