Skip to content

Support DeepSeekV3-style block FP8 quantization #372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Jun 30, 2025

Quite a few things packed into one here, but the goal is to support the 128x128 weight and 1x128 input quantization adopted by deepseekv3 and qwen3 models. See examples: https://huggingface.co/deepseek-ai/DeepSeek-V3 and https://huggingface.co/Qwen/Qwen3-0.6B-FP8

  • Added BLOCK static quantization paths for weight quantization.
  • Added GROUP dynamic quantization paths for per-token-group input quantization. I feel like this is more understandable than the "1x128" block input quantization deepseek uses.
  • I’ve updated all of the places where block_structure was previously treated as an “NxM” string so that it now uses a Python list of two integers (e.g. [128, 128]). I added a pydantic validator that can convert this automatically for old checkpoints that use the string.

Here is the scheme I am proposing to support this:

# Block‐wise FP8 (deepseekv3-style quantization):
# static 128x128 per‐block weights and 
# dynamic per‐token‐group activations
FP8_BLOCK = dict(
    weights=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.BLOCK,
        symmetric=True,
        dynamic=False,
        block_structure=[128, 128],
    ),
    input_activations=QuantizationArgs(
        num_bits=8,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        symmetric=True,
        dynamic=True,
        observer=None,
        group_size=128,
    ),
)

mgoin added 2 commits June 30, 2025 19:26
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
symmetric: bool = True
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
block_structure: Optional[List[int]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
block_structure: Optional[List[int]] = None
block_structure: Optional[Tuple[int, int]] = None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel towards keeping it a list since I don't think we can distinguish between list and tuple in the final json config

QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
QuantizationStrategy.TENSOR_GROUP,
):
QuantizationStrategy.GROUP,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly an aesthetic choice, but it might have aesthetic consequences if vllm wants to support fused input-weight quantization. Ex if input_quant_strategy == group and weight_quant_strategy == group

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to add some validation on quant_scheme related to this as well

mgoin added 4 commits July 1, 2025 00:44
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <michael@neuralmagic.com>
Copy link
Collaborator

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you produce a test model to nm-testing andd it to this PR?

@@ -111,11 +111,15 @@ def dequantize(
elif scale.ndim == 2:
if scale.shape[1] == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
else:
elif scale.shape[0] == 1:
group_size = int(x_q.shape[1] / scale.shape[1])
args = QuantizationArgs(
strategy=QuantizationStrategy.GROUP, group_size=group_size
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a docstring explaining why this falls into the else condition?
Otherwise, I think this has grown complicated enough to easily fall prey to a bug

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsikka Do you have a sense of why this logic exists at all/ what cases its used? I'm not sure if inferring quant strat is really safe practice

@kylesayrs kylesayrs self-assigned this Jul 8, 2025
@@ -154,6 +154,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove

@@ -111,11 +111,15 @@ def dequantize(
elif scale.ndim == 2:
if scale.shape[1] == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
else:
elif scale.shape[0] == 1:
group_size = int(x_q.shape[1] / scale.shape[1])
args = QuantizationArgs(
strategy=QuantizationStrategy.GROUP, group_size=group_size
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsikka Do you have a sense of why this logic exists at all/ what cases its used? I'm not sure if inferring quant strat is really safe practice

@kylesayrs kylesayrs assigned shanjiaz and unassigned kylesayrs Jul 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants