
Support DeepSeekV3-style block FP8 quantization with CT #20279

Open · wants to merge 2 commits into base: main
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/linear.py
@@ -759,12 +759,10 @@ def weight_loader_v2(self,
         tp_size = get_tensor_model_parallel_world_size()

         if isinstance(param, BlockQuantScaleParameter):
-            from vllm.model_executor.layers.quantization.fp8 import (
-                Fp8LinearMethod, Fp8MoEMethod)
             assert self.quant_method is not None
-            assert isinstance(self.quant_method,
-                              (Fp8LinearMethod, Fp8MoEMethod))
-            weight_block_size = self.quant_method.quant_config.weight_block_size
+            # Assume the weight block size has been set by quant method
+            assert hasattr(self, "weight_block_size")
+            weight_block_size = self.weight_block_size
             assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
@@ -934,8 +932,10 @@ def weight_loader_v2(self,
         # Note(simon): This is needed for Qwen3's fp8 quantization.
         if isinstance(param, BlockQuantScaleParameter):
             assert self.quant_method is not None
-            assert hasattr(self.quant_method, "quant_config")
-            weight_block_size = self.quant_method.quant_config.weight_block_size
+            # Assume the weight block size has been set by the quant method
+            assert hasattr(self, "weight_block_size")
+            weight_block_size = self.weight_block_size
+            assert weight_block_size is not None
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n
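Before moving on to the compressed-tensors changes, here is a small, self-contained sketch of the arithmetic both hunks above rely on: a weight shard described in output-channel units is mapped onto the per-block scale tensor by ceiling-dividing the offset and size by the block size, so partial blocks still get a scale. The helper name and the 128-wide block are illustrative only (DeepSeek-V3-style checkpoints use 128x128 blocks); this is not code from the PR.

```python
def scale_shard(shard_offset: int, shard_size: int, block_n: int = 128):
    """Map a weight shard (offset/size in output channels) to the matching
    slice of the block-wise scale tensor, rounding partial blocks up."""
    scale_offset = (shard_offset + block_n - 1) // block_n  # ceil division
    scale_size = (shard_size + block_n - 1) // block_n
    return scale_offset, scale_size


# Example: a fused gate/up projection where each shard spans 2816 rows.
print(scale_shard(0, 2816))     # -> (0, 22)
print(scale_shard(2816, 2816))  # -> (22, 22)
```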
Changes in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -11,7 +11,6 @@
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
-from pydantic import BaseModel

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -221,7 +220,8 @@ def _check_scheme_supported(self,
         else:
             return False

-    def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
+    def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs,
+                        input_quant: QuantizationArgs):

         if weight_quant is None or input_quant is None:
             return False
@@ -241,8 +241,8 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
         return (is_tensor_group_quant and is_float_type and is_4_bits
                 and is_group_size_16 and is_symmetric)

-    def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
-                         input_quant: BaseModel):
+    def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs,
+                         input_quant: QuantizationArgs):

         is_weight_only = weight_quant is not None and input_quant is None
         is_tensor_group_quant = (
@@ -256,8 +256,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         return (is_weight_only and is_tensor_group_quant and is_float_type
                 and is_4_bits and is_group_size_16 and is_symmetric)

-    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -270,8 +270,8 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_tensor and weight_quant.symmetric and is_static

-    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
-                               input_quant: BaseModel) -> bool:
+    def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
+                               input_quant: QuantizationArgs) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
         weight_strategy = (
             weight_quant.strategy == QuantizationStrategy.TENSOR.value
@@ -284,8 +284,8 @@ def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
         # Only symmetric weight quantization supported.
         return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

-    def _is_fp8_w8a8(self, weight_quant: BaseModel,
-                     input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8(self, weight_quant: QuantizationArgs,
+                     input_quant: QuantizationArgs) -> bool:
         # Confirm weights and activations quantized.
         if weight_quant is None or input_quant is None:
             return False
@@ -295,11 +295,12 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
                              and input_quant.type == QuantizationType.FLOAT)
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
         if not (is_floating_point and is_symmetric_weight and is_static_weight
-                and is_per_tensor_or_channel_weight):
+                and is_tensor_or_channel_or_block_weight):
             return False

         # Dynamic quantization is always supported if weights supported.
@@ -312,13 +313,16 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
             input_quant.strategy == QuantizationStrategy.TENSOR)
         return is_symmetric_activation and is_per_tensor_activation

-    def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel,
-                          input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs,
+                          input_quant: QuantizationArgs) -> bool:
+        # Block quantization is not supported for SM90 CUTLASS yet
+        is_block_quant = weight_quant.strategy == QuantizationStrategy.BLOCK
         return (self._check_scheme_supported(90, error=False, match_exact=True)
-                and self._is_fp8_w8a8(weight_quant, input_quant))
+                and self._is_fp8_w8a8(weight_quant, input_quant)
+                and not is_block_quant)

-    def _is_fp8_w8a16(self, weight_quant: BaseModel,
-                      input_quant: BaseModel) -> bool:
+    def _is_fp8_w8a16(self, weight_quant: QuantizationArgs,
+                      input_quant: QuantizationArgs) -> bool:
         # Confirm weights quantized.
         if weight_quant is None:
             return False
@@ -328,20 +332,22 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel,
             return False

         # Confirm weight scheme is supported.
+        is_floating_point = weight_quant.type == QuantizationType.FLOAT
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
-            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL,
+            QuantizationStrategy.BLOCK
         ])
-        if not (is_symmetric_weight and is_static_weight  # noqa: SIM103
-                and is_per_tensor_or_channel_weight):
+        if not (is_floating_point and is_symmetric_weight  # noqa: SIM103
+                and is_static_weight and is_tensor_or_channel_or_block_weight):
             return False

         # All conditions satisfied.
         return True

-    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
-                                input_quant: BaseModel) -> bool:
+    def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs,
+                                input_quant: QuantizationArgs) -> bool:
         input_quant_none = input_quant is None
         is_channel_group = (
             weight_quant.strategy == QuantizationStrategy.CHANNEL.value
@@ -351,8 +357,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel,
         return (is_channel_group and input_quant_none and is_static)

     def _get_scheme_from_parts(
-            self, weight_quant: BaseModel,
-            input_quant: BaseModel) -> "CompressedTensorsScheme":
+            self, weight_quant: QuantizationArgs,
+            input_quant: QuantizationArgs) -> "CompressedTensorsScheme":

         # Detect If Mixed Precision
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
@@ -392,7 +398,7 @@ def _get_scheme_from_parts(
                 CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
                 return CompressedTensorsW8A8Fp8(
-                    strategy=weight_quant.strategy,
+                    weight_quant=weight_quant,
                     is_static_input_scheme=(input_quant
                                             and not input_quant.dynamic))
             else:
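The last hunk passes the full `weight_quant` args into `CompressedTensorsW8A8Fp8` instead of just the strategy string. A plausible motivation is that the scheme can then recover the block shape as well as the strategy. The class below is a hypothetical sketch of that idea, not the actual CompressedTensorsW8A8Fp8 implementation, and the `block_structure` attribute is an assumption about the compressed-tensors schema.

```python
from typing import Optional

from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy)


class BlockAwareFp8SchemeSketch:
    """Hypothetical stand-in showing why a scheme wants QuantizationArgs:
    the strategy alone cannot tell it the block shape."""

    def __init__(self, weight_quant: QuantizationArgs,
                 is_static_input_scheme: bool):
        self.strategy = weight_quant.strategy
        self.is_static_input_scheme = is_static_input_scheme
        # Assumed field: compressed-tensors exposes the block shape on the
        # QuantizationArgs (e.g. [128, 128]) for BLOCK-strategy weights.
        self.weight_block_size: Optional[list] = (
            getattr(weight_quant, "block_structure", None)
            if weight_quant.strategy == QuantizationStrategy.BLOCK else None)
```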