
Commit a103303

Edwardf0t1 authored and Chen-zexi committed
Enable ModelOpt Llama4 fp8 checkpoint deployment (vllm-project#20419)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 1d5c616 commit a103303

File tree

5 files changed: +501, -35 lines changed


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 31 additions & 6 deletions
@@ -81,6 +81,16 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 
+    def uses_weight_scale_2_pattern(self) -> bool:
+        """
+        Returns True if this quantization method uses 'weight_scale_2' pattern
+        for per-tensor weight scales (e.g., FP4 variants), False otherwise.
+
+        This method should be overridden by subclasses that use the
+        'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
+        """
+        return False
+
     @staticmethod
     def maybe_make_prepare_finalize(
             moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
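
For context, a minimal standalone sketch (not part of the diff, no vLLM imports; the class names are illustrative only) of the dispatch this hook enables:

class _BaseMethod:
    def uses_weight_scale_2_pattern(self) -> bool:
        return False  # FP8-style checkpoints use plain "weight_scale"


class _Fp4LikeMethod(_BaseMethod):
    def uses_weight_scale_2_pattern(self) -> bool:
        return True  # FP4-style checkpoints use "weight_scale_2"


def per_tensor_scale_suffix(method: _BaseMethod) -> str:
    # Pick the checkpoint tensor suffix treated as the per-tensor weight scale.
    return ("weight_scale_2"
            if method.uses_weight_scale_2_pattern() else "weight_scale")


assert per_tensor_scale_suffix(_BaseMethod()) == "weight_scale"
assert per_tensor_scale_suffix(_Fp4LikeMethod()) == "weight_scale_2"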
@@ -1081,12 +1091,23 @@ def weight_loader(self,
 
         # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
         if "ModelOpt" in quant_method_name:
-            if ('weight_scale_2' in weight_name
-                    or 'input_scale' in weight_name):
-                self._load_per_tensor_weight_scale(shard_id=shard_id,
-                                                   param=param,
-                                                   loaded_weight=loaded_weight,
-                                                   expert_id=expert_id)
+            # Determine per-tensor weight scale patterns based on variant
+            # Use the dedicated method instead of brittle string matching
+            uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern(
+            )
+
+            # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name if uses_weight_scale_2 else
+                "weight_scale" in weight_name) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
         elif "weight" in weight_name:
             self._load_model_weight_or_group_weight_scale(
                 shard_id=shard_id,
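
To make the new branch concrete, a small check of which names end up in _load_per_tensor_weight_scale under each variant (the tensor names below are illustrative, not taken from a specific checkpoint):

def is_per_tensor(weight_name: str, uses_weight_scale_2: bool) -> bool:
    # Same boolean as per_tensor_conditions in the hunk above.
    return ("weight_scale_2" in weight_name if uses_weight_scale_2 else
            "weight_scale" in weight_name) or "input_scale" in weight_name


# FP8 (uses_weight_scale_2=False): plain weight scales and input scales match.
assert is_per_tensor("experts.w13_weight_scale", False)
assert is_per_tensor("experts.w2_input_scale", False)
# FP4 (uses_weight_scale_2=True): only "*_scale_2" tensors and input scales match.
assert is_per_tensor("experts.w13_weight_scale_2", True)
assert not is_per_tensor("experts.w13_weight_scale", True)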
@@ -1558,3 +1579,7 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
     dispatch_key=current_platform.dispatch_key,
     tags=(torch.Tag.needs_fixed_stride_order, ),
 )
+
+# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
+# to avoid expensive runtime reflection in model loading code
+FusedMoE.weight_loader.supports_moe_loading = True  # type: ignore[attr-defined]
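
The loader-side counterpart is not shown in this diff; the snippet below is only an assumed usage pattern for how the marker could be consumed instead of inspecting the function signature:

def supports_moe_params(weight_loader) -> bool:
    # Cheap attribute lookup; ordinary (non-MoE) loaders fall back to False.
    return getattr(weight_loader, "supports_moe_loading", False)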

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 261 additions & 5 deletions
@@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def __init__(
         self,
         is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                            " the format is experimental and could change.")
@@ -69,34 +73,63 @@ def get_config_filenames(cls) -> list[str]:
     def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
         quant_config = cls.get_from_keys(config, ["quantization"])
         quant_method = quant_config["quant_algo"]
+        kv_cache_quant_method = cls.get_from_keys(
+            config, ["quantization"]).get("kv_cache_quant_algo")
+        exclude_modules = cls.get_from_keys(
+            config, ["quantization"]).get("exclude_modules")
+
         if quant_method not in QUANT_ALGOS:
             raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                              " quantizations in vLLM. Please check the "
                              "`hf_quant_config.json` file for your model's "
                              "quant configuration.")
         is_checkpoint_fp8_serialized = ("FP8" in quant_method)
 
-        return cls(is_checkpoint_fp8_serialized)
+        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
+                   exclude_modules)
+
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+
+        This method handles both regular models and multimodal models that use
+        the language_model prefix. For multimodal models, it checks if the
+        module name (without the language_model prefix) is in the exclude list.
+        """
+        if self.exclude_modules is None:
+            return False
+
+        # Check if any excluded module matches the prefix
+        for module in self.exclude_modules:
+            if (module in prefix
+                    or (prefix.startswith("language_model.")
+                        and module in prefix.removeprefix("language_model."))):
+                return True
+        return False
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
         return None
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer static quantization.
     Supports loading FP8 checkpoints with static weight scale and
-    activation scale. Future support might be added for dynamic
+    activation scale. Future support might be added for dynamic
     scales.
 
     Limitations:
     1. Only support per-tensor quantization due to torch._scaled_mm support.
-    2. Only support float8_e4m3fn datatype
+    2. Only support float8_e4m3fn datatype
     Args: quant_config: The ModelOpt quantization config.
     """
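
For illustration, here is how the new fields and the exclusion check behave (the constructor values and module names are made up; in practice they come from the checkpoint's hf_quant_config.json via from_config above):

cfg = ModelOptFp8Config(
    is_checkpoint_fp8_serialized=True,
    kv_cache_quant_method="FP8",            # maps to "kv_cache_quant_algo"
    exclude_modules=["lm_head", "router"],  # modules left unquantized
)

# Substring match, with or without the multimodal "language_model." prefix:
assert cfg.is_layer_excluded("language_model.model.layers.0.feed_forward.router")
assert not cfg.is_layer_excluded("model.layers.0.self_attn.qkv_proj")
# With exclude_modules=None (the default), nothing is excluded.
assert not ModelOptFp8Config().is_layer_excluded("lm_head")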

@@ -172,6 +205,223 @@ def apply(
                                 bias=bias)
 
 
+class ModelOptFp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    activation scale.
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+            cutlass_fp8_supported)
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        # Use FP8 dtype if checkpoint is serialized
+        weight_dtype = (torch.float8_e4m3fn
+                        if self.quant_config.is_checkpoint_fp8_serialized else
+                        params_dtype)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(num_experts,
+                             2 * intermediate_size_per_partition,
+                             hidden_size,
+                             dtype=weight_dtype),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(num_experts,
+                             hidden_size,
+                             intermediate_size_per_partition,
+                             dtype=weight_dtype),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    1.0,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data,
+                                     requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        from vllm._custom_ops import scaled_fp8_quant
+        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+            per_tensor_dequantize)
+
+        # Handle scale parameters
+        if hasattr(layer,
+                   "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales
+            # then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][start:start +
+                                                        intermediate_size, :],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+
+                        (
+                            layer.w13_weight[expert_id][start:start +
+                                                        intermediate_size, :],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight,
+                                             max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert
+                layer.w13_weight_scale = Parameter(max_w13_scales,
+                                                   requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False)
+
+        if hasattr(layer,
+                   "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
+                                              requires_grad=False)
+        # Input scales must be equal for each expert in fp8 MoE layers.
+        if hasattr(layer,
+                   "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
+                                              requires_grad=False)
+        if hasattr(layer,
+                   "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
+                                             requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
+
+        # Expert selection
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_experts)
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )
+
+
 class ModelOptNvFp4Config(QuantizationConfig):
     """Config class for ModelOpt FP4."""
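
The per-expert requantization above relies on vLLM's per_tensor_dequantize and scaled_fp8_quant helpers. A self-contained sketch of the same idea in plain torch, assuming simple symmetric per-tensor scaling, shows why the max of the w1/w3 scales is the one that survives:

import torch


def combine_w13_scales(w1_q: torch.Tensor, s1: torch.Tensor,
                       w3_q: torch.Tensor, s3: torch.Tensor):
    # Dequantize each shard with its original per-tensor scale.
    w1 = w1_q.to(torch.float32) * s1
    w3 = w3_q.to(torch.float32) * s3
    # The fused kernel expects one scale per expert, so the shared scale must
    # cover the larger dynamic range of the two shards: take the max.
    s = torch.maximum(s1, s3)
    # Requantize both shards with the shared scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w1_q = (w1 / s).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    w3_q = (w3 / s).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w1_q, w3_q, s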

@@ -274,7 +524,7 @@ def __init__(self, quant_config: Union[ModelOptFp8Config,
 class ModelOptNvFp4LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer NVFP4.
     Supports loading NVFP4 checkpoints with the following structure:
-
+
     input_scale: torch.float32, scalar ,
     weight: NVFP4(represented as byte) Shape: [1, X, y/2]
     weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
@@ -455,7 +705,7 @@ def apply(
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
-    Args:
+    Args:
     quant_config: NVFP4 Quant Config
     """
 
@@ -472,6 +722,12 @@ def __init__(self, quant_config: ModelOptNvFp4Config):
                 " quantization. Please use Blackwell and"
                 " above.")
 
+    def uses_weight_scale_2_pattern(self) -> bool:
+        """
+        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
+        """
+        return True
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
