@@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def __init__(
         self,
         is_checkpoint_fp8_serialized: bool = False,
+        kv_cache_quant_method: Optional[str] = None,
+        exclude_modules: Optional[list[str]] = None,
     ) -> None:
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                            " the format is experimental and could change.")
@@ -69,34 +73,63 @@ def get_config_filenames(cls) -> list[str]:
     def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
         quant_config = cls.get_from_keys(config, ["quantization"])
         quant_method = quant_config["quant_algo"]
+        kv_cache_quant_method = cls.get_from_keys(
+            config, ["quantization"]).get("kv_cache_quant_algo")
+        exclude_modules = cls.get_from_keys(
+            config, ["quantization"]).get("exclude_modules")
+
         if quant_method not in QUANT_ALGOS:
             raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
                              " quantizations in vLLM. Please check the "
                              "`hf_quant_config.json` file for your model's "
                              "quant configuration.")
         is_checkpoint_fp8_serialized = ("FP8" in quant_method)

-        return cls(is_checkpoint_fp8_serialized)
+        return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
+                   exclude_modules)
+
+    def is_layer_excluded(self, prefix: str) -> bool:
+        """
+        Check if a layer should be excluded from quantization.
+
+        This method handles both regular models and multimodal models that use
+        the language_model prefix. For multimodal models, it checks if the
+        module name (without the language_model prefix) is in the exclude list.
+        """
+        if self.exclude_modules is None:
+            return False
+
+        # Check if any excluded module matches the prefix
+        for module in self.exclude_modules:
+            if (module in prefix
+                    or (prefix.startswith("language_model.")
+                        and module in prefix.removeprefix("language_model."))):
+                return True
+        return False

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
+            if self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
         return None


 class ModelOptFp8LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer static quantization.
     Supports loading FP8 checkpoints with static weight scale and
-    activation scale. Future support might be added for dynamic
+    activation scale. Future support might be added for dynamic
     scales.

     Limitations:
     1. Only support per-tensor quantization due to torch._scaled_mm support.
-    2. Only support float8_e4m3fn datatype
+    2. Only support float8_e4m3fn datatype
     Args: quant_config: The ModelOpt quantization config.
     """

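To make the new config plumbing concrete, here is a small sketch of the dict shape `from_config` reads (the keys mirror the `get_from_keys(config, ["quantization"])` lookups and the `.get("kv_cache_quant_algo")` / `.get("exclude_modules")` accesses above) and of how `is_layer_excluded` matches prefixes. The concrete values, and the assumption that "FP8" is listed in `QUANT_ALGOS`, are illustrative:

    # Illustrative hf_quant_config.json-style payload; values are assumptions.
    hf_quant_config = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",
            "exclude_modules": ["lm_head"],
        }
    }
    cfg = ModelOptFp8Config.from_config(hf_quant_config)

    # Matching is substring-based and tolerates a multimodal
    # "language_model." wrapper prefix.
    cfg.is_layer_excluded("lm_head")                           # True
    cfg.is_layer_excluded("language_model.lm_head")            # True
    cfg.is_layer_excluded("model.layers.0.mlp.gate_up_proj")   # False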
@@ -172,6 +205,223 @@ def apply(
             bias=bias)


+class ModelOptFp8MoEMethod(FusedMoEMethodBase):
+    """MoE method for ModelOpt FP8.
+    Supports loading FP8 checkpoints with static weight scale and
+    activation scale.
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+            cutlass_fp8_supported)
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        # Use FP8 dtype if checkpoint is serialized
+        weight_dtype = (torch.float8_e4m3fn
+                        if self.quant_config.is_checkpoint_fp8_serialized else
+                        params_dtype)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(num_experts,
+                             2 * intermediate_size_per_partition,
+                             hidden_size,
+                             dtype=weight_dtype),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(num_experts,
+                             hidden_size,
+                             intermediate_size_per_partition,
+                             dtype=weight_dtype),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    1.0,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        layer.w13_weight = Parameter(layer.w13_weight.data,
+                                     requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        from vllm._custom_ops import scaled_fp8_quant
+        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+            per_tensor_dequantize)
+
+        # Handle scale parameters
+        if hasattr(layer,
+                   "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales
+            # then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][start:start +
+                                                        intermediate_size, :],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+
+                        (
+                            layer.w13_weight[expert_id][start:start +
+                                                        intermediate_size, :],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight,
+                                             max_w13_scales[expert_id])
+
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert
+                layer.w13_weight_scale = Parameter(max_w13_scales,
+                                                   requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data,
+                                                   requires_grad=False)
+
+        if hasattr(layer,
+                   "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
+                                              requires_grad=False)
+        # Input scales must be equal for each expert in fp8 MoE layers.
+        if hasattr(layer,
+                   "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
+                                              requires_grad=False)
+        if hasattr(layer,
+                   "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
+                                             requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
+
+        # Expert selection
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+        from vllm.model_executor.layers.fused_moe.fused_moe import (
+            fused_experts)
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )
+
+
 class ModelOptNvFp4Config(QuantizationConfig):
     """Config class for ModelOpt FP4."""

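The core of `process_weights_after_loading` above is the w13 scale fold: the fused FP8 MoE kernel expects a single per-expert scale for the concatenated w1/w3 weight, so each shard is dequantized with its own scale and requantized with the per-expert maximum. Below is a standalone sketch of that idea in plain torch, approximating vLLM's `per_tensor_dequantize` / `scaled_fp8_quant` helpers (448 is the float8_e4m3fn max); shapes and names are assumptions for illustration:

    import torch

    def combine_w13_scales(w13_fp8: torch.Tensor,
                           w13_scale: torch.Tensor):
        """w13_fp8: (E, 2*I, H) float8_e4m3fn, w13_scale: (E, 2) float32."""
        num_experts, two_i, _ = w13_fp8.shape
        inter = two_i // 2
        # One combined scale per expert: the max over the w1 and w3 scales.
        max_scale = w13_scale.max(dim=1).values          # (E,)
        out = torch.empty_like(w13_fp8)
        for e in range(num_experts):
            start = 0
            for shard in range(2):                        # w1 then w3
                block = w13_fp8[e, start:start + inter].to(torch.float32)
                dq = block * w13_scale[e, shard]          # dequant with shard scale
                requant = (dq / max_scale[e]).clamp(-448.0, 448.0)
                out[e, start:start + inter] = requant.to(torch.float8_e4m3fn)
                start += inter
        return out, max_scale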
@@ -274,7 +524,7 @@ def __init__(self, quant_config: Union[ModelOptFp8Config,
 class ModelOptNvFp4LinearMethod(LinearMethodBase):
     """Linear method for Model Optimizer NVFP4.
     Supports loading NVFP4 checkpoints with the following structure:
-
+
     input_scale: torch.float32, scalar ,
     weight: NVFP4(represented as byte) Shape: [1, X, y/2]
     weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
@@ -455,7 +705,7 @@ def apply(
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
-    Args:
+    Args:
         quant_config: NVFP4 Quant Config
     """

@@ -472,6 +722,12 @@ def __init__(self, quant_config: ModelOptNvFp4Config):
                              " quantization. Please use Blackwell and"
                              " above.")

+    def uses_weight_scale_2_pattern(self) -> bool:
+        """
+        FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
+        """
+        return True
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
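A possible consumer of the new hook, shown only as a hypothetical sketch (the helper below is not vLLM API): weight-loading code can ask the quant method which per-tensor scale suffix a checkpoint uses.

    # Hypothetical helper, illustrating the intent of the flag only.
    def per_tensor_scale_suffix(quant_method) -> str:
        # NVFP4-style checkpoints store the per-tensor weight scale under
        # "weight_scale_2"; FP8-style checkpoints use "weight_scale".
        if quant_method.uses_weight_scale_2_pattern():
            return "weight_scale_2"
        return "weight_scale"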