diff --git a/vllm/config.py b/vllm/config.py index 508e09174cc..c065bdb4158 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -887,6 +887,10 @@ def _parse_quant_hf_config(self): if quant_cfg is None: # compressed-tensors uses a "compression_config" key quant_cfg = getattr(self.hf_config, "compression_config", None) + if quant_cfg is not None: + if quant_cfg["producer"]["name"].lower() == "modelopt": + if "quant_algo" in quant_cfg.keys() and quant_cfg["quant_algo"].lower() == "fp8": + quant_cfg = {"quant_method": "modelopt"} return quant_cfg def _verify_quantization(self) -> None: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4..2129de083a7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -81,6 +81,16 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError + def uses_weight_scale_2_pattern(self) -> bool: + """ + Returns True if this quantization method uses 'weight_scale_2' pattern + for per-tensor weight scales (e.g., FP4 variants), False otherwise. + + This method should be overridden by subclasses that use the + 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. + """ + return False + def init_prepare_finalize(self, moe: FusedMoEConfig, quant_config: Optional[QuantizationConfig]): all2all_manager = get_ep_group().device_communicator.all2all_manager @@ -1049,12 +1059,23 @@ def weight_loader(self, # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: - if ('weight_scale_2' in weight_name - or 'input_scale' in weight_name): - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + # Determine per-tensor weight scale patterns based on variant + # Use the dedicated method instead of brittle string matching + uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( + ) + + # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale" + per_tensor_conditions = ( + "weight_scale_2" in weight_name if uses_weight_scale_2 else + "weight_scale" in weight_name) or "input_scale" in weight_name + + if per_tensor_conditions: + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) elif "weight" in weight_name: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, @@ -1526,3 +1547,7 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, dispatch_key=current_platform.dispatch_key, tags=(torch.Tag.needs_fixed_stride_order, ), ) + +# Mark the FusedMoE weight_loader as supporting MoE-specific parameters +# to avoid expensive runtime reflection in model loading code +FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 47eca80609e..67083c3b4b5 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -190,6 +190,8 @@ def quantize_and_call_weight_loader(param: torch.nn.Parameter, weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) + # Mark as supporting MoE-specific loading to avoid expensive reflection + quantize_and_call_weight_loader.supports_moe_loading = 
True # type: ignore[attr-defined] return quantize_and_call_weight_loader diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2295c0e5fe9..6cdcf3f781b 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig): def __init__( self, is_checkpoint_fp8_serialized: bool = False, + kv_cache_quant_method: Optional[str] = None, + exclude_modules: Optional[list[str]] = None, ) -> None: super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.kv_cache_quant_method = kv_cache_quant_method + self.exclude_modules = exclude_modules if is_checkpoint_fp8_serialized: logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" " the format is experimental and could change.") @@ -67,8 +71,18 @@ def get_config_filenames(cls) -> list[str]: @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": - quant_config = cls.get_from_keys(config, ["quantization"]) - quant_method = quant_config["quant_algo"] + try: + quant_method = cls.get_from_keys(config, ["quant_algo"]) + kv_cache_quant_method = cls.get_from_keys(config, ["kv_cache_scheme"]) + exclude_modules = cls.get_from_keys(config, ["ignore"]) + except: + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + kv_cache_quant_method = cls.get_from_keys( + config, ["quantization"]).get("kv_cache_quant_algo") + exclude_modules = cls.get_from_keys( + config, ["quantization"]).get("exclude_modules") + if quant_method not in QUANT_ALGOS: raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" " quantizations in vLLM. Please check the " @@ -76,27 +90,51 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": "quant configuration.") is_checkpoint_fp8_serialized = ("FP8" in quant_method) - return cls(is_checkpoint_fp8_serialized) + return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, + exclude_modules) + + def is_layer_excluded(self, prefix: str) -> bool: + """ + Check if a layer should be excluded from quantization. + + This method handles both regular models and multimodal models that use + the language_model prefix. For multimodal models, it checks if the + module name (without the language_model prefix) is in the exclude list. + """ + if self.exclude_modules is None: + return False + + # Check if any excluded module matches the prefix + for module in self.exclude_modules: + if (module in prefix + or (prefix.startswith("language_model.") + and module in prefix.removeprefix("language_model."))): + return True + return False def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): + if self.is_layer_excluded(prefix): + return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return ModelOptFp8MoEMethod(self) return None class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and - activation scale. Future support might be added for dynamic + activation scale. Future support might be added for dynamic scales. Limitations: 1. 
Only support per-tensor quantization due to torch._scaled_mm support. - 2. Only support float8_e4m3fn datatype + 2. Only support float8_e4m3fn datatype Args: quant_config: The ModelOpt quantization config. """ @@ -171,6 +209,223 @@ def apply( bias=bias) +class ModelOptFp8MoEMethod(FusedMoEMethodBase): + """MoE method for ModelOpt FP8. + Supports loading FP8 checkpoints with static weight scale and + activation scale. + Args: + quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_fp8_supported) + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + # Use FP8 dtype if checkpoint is serialized + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight_loader = extra_weight_attrs.get("weight_loader") + + w13_weight = ModelWeightParameter( + data=torch.empty(num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + w2_weight = ModelWeightParameter( + data=torch.empty(num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALES - Per-tensor scaling for ModelOpts + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = PerTensorScaleParameter( + data=torch.full( + (num_experts, 2), + 1.0, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + w2_weight_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Set weight loader attributes for scales + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + # INPUT SCALES - Per-tensor scaling for ModelOpt + w13_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + w2_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Process FP8 MoE weights after loading from serialized checkpoint. + Only supports pre-quantized checkpoints with FP8 weights and scales. 
+ """ + + layer.w13_weight = Parameter(layer.w13_weight.data, + requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + from vllm._custom_ops import scaled_fp8_quant + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + per_tensor_dequantize) + + # Handle scale parameters + if hasattr(layer, + "w13_weight_scale") and layer.w13_weight_scale is not None: + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max of the w1 and w3 scales + # then dequant and requant each expert. + if layer.w13_weight_scale.dim() == 2: + + # Get the maximum scale across w1 and w3 for each expert + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + + # Requantize each expert's weights using the combined scale + # w13_weight (num_experts, 2 * intermediate_size, hidden_size) + # where the first intermediate_size rows are w1, the next are w3 + intermediate_size = layer.w13_weight.shape[1] // 2 + for expert_id in range(layer.w13_weight.shape[0]): + start = 0 + for shard_id in range(2): # w1 and w3 + # Dequantize using the original scale for this shard + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + intermediate_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + # Requantize using the combined max scale + + ( + layer.w13_weight[expert_id][start:start + + intermediate_size, :], + _, + ) = scaled_fp8_quant(dq_weight, + max_w13_scales[expert_id]) + + start += intermediate_size + + # Update the scale parameter to be per-expert + layer.w13_weight_scale = Parameter(max_w13_scales, + requires_grad=False) + else: + layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, + requires_grad=False) + + if hasattr(layer, + "w2_weight_scale") and layer.w2_weight_scale is not None: + layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, + requires_grad=False) + # Input scales must be equal for each expert in fp8 MoE layers. 
+ if hasattr(layer, + "w13_input_scale") and layer.w13_input_scale is not None: + layer.w13_input_scale = Parameter(layer.w13_input_scale.max(), + requires_grad=False) + if hasattr(layer, + "w2_input_scale") and layer.w2_input_scale is not None: + layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptFp8MoEMethod` yet.") + + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + class ModelOptNvFp4Config(QuantizationConfig): """Config class for ModelOpt FP4.""" @@ -273,7 +528,7 @@ def __init__(self, quant_config: Union[ModelOptFp8Config, class ModelOptNvFp4LinearMethod(LinearMethodBase): """Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure: - + input_scale: torch.float32, scalar , weight: NVFP4(represented as byte) Shape: [1, X, y/2] weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, @@ -454,7 +709,7 @@ def apply( class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. - Args: + Args: quant_config: NVFP4 Quant Config """ @@ -471,6 +726,12 @@ def __init__(self, quant_config: ModelOptNvFp4Config): " quantization. Please use Blackwell and" " above.") + def uses_weight_scale_2_pattern(self) -> bool: + """ + FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. 
+ """ + return True + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index c5055a02fa3..f03c7b3d501 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -454,4 +454,6 @@ def moe_wna16_weight_loader(param: torch.nn.Parameter, weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) + # Mark as supporting MoE-specific loading to avoid expensive reflection + moe_wna16_weight_loader.supports_moe_loading = True # type: ignore[attr-defined] return moe_wna16_weight_loader diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 857f4bca682..a70c89f2d82 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -758,6 +758,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: modelopt_scale_names = [ ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" ] + # Also support qkv_proj scale parameters (from stacked parameter processing) + qkv_proj_scale_names = [ + ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale" + ] for scale_name in possible_scale_names: if name.endswith(scale_name): if any(mo_scale_name in name @@ -765,6 +769,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: remapped_name = name.replace( f".self_attn.{scale_name[1]}_proj{scale_name}", f".self_attn.attn{scale_name}") + elif any(qkv_scale_name in name + for qkv_scale_name in qkv_proj_scale_names): + # Handle qkv_proj scale parameters + remapped_name = name.replace( + f".self_attn.qkv_proj{scale_name}", + f".self_attn.attn{scale_name}") else: remapped_name = name.replace(scale_name, f".attn{scale_name}") if remapped_name not in params_dict: diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0c9baab1f2e..e740e00c3cd 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -35,7 +35,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, @@ -432,12 +433,23 @@ def load_weights(self, weights: Iterable[tuple[str, for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name or "experts" in name: continue - name = name.replace(weight_name, param_name) + if not (name.endswith( + (".k_scale", ".v_scale")) and "self_attn" in name): + name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): continue + if name.endswith("scale") and "expert" not in name: + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break else: @@ -452,6 +464,49 @@ def load_weights(self, weights: Iterable[tuple[str, if not moe_loaded: if is_pp_missing_parameter(name, self): continue + + # Handle flat expert scale parameters that + # don't match per-expert patterns + if ("experts." in name and ("w13_input_scale" in name + or "w13_weight_scale" in name + or "w2_input_scale" in name + or "w2_weight_scale" in name)): + # These are flat expert scales that apply to all experts + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + # Check if this is a MoE-specific weight loader that + # needs extra arguments + if hasattr(param, 'weight_loader'): + # Check for MoE-specific loading support via + # attribute instead of expensive runtime reflection + supports_moe = getattr(weight_loader, + 'supports_moe_loading', + False) + + if supports_moe: + # This is a MoE weight loader + if "w13_" in name: + shard_id = "w1" + elif "w2_" in name: + shard_id = "w2" + else: + shard_id = "w1" + + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=0) + else: + # Regular weight loader + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 1276d626a7c..b7d0a1ddafd 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -902,32 +903,106 @@ def _consolidate_qkv_weights( qkv_weight = torch.cat(weight, dim=0) yield key, qkv_weight - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def _rename_weight_for_checkpoint(self, name: str) -> str: + """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM + format.""" + if name.startswith("model."): + # Handle expert scale parameters with flat naming + if "feed_forward.experts." in name and ("_input_scale" in name or + "_weight_scale" in name): + renamed = name.replace("model.", "language_model.model.", 1) + # Map checkpoint naming to vLLM's expected naming + if "down_proj_input_scale" in renamed: + return renamed.replace("down_proj_input_scale", + "w2_input_scale") + elif "down_proj_weight_scale" in renamed: + return renamed.replace("down_proj_weight_scale", + "w2_weight_scale") + elif "gate_up_proj_input_scale" in renamed: + return renamed.replace("gate_up_proj_input_scale", + "w13_input_scale") + elif "gate_up_proj_weight_scale" in renamed: + return renamed.replace("gate_up_proj_weight_scale", + "w13_weight_scale") + return renamed + + # Handle attention scale parameters + elif "self_attn." 
in name and (".k_scale" in name + or ".v_scale" in name): + renamed = name.replace("model.", "language_model.model.", 1) + if ".k_proj.k_scale" in renamed: + return renamed.replace(".k_proj.k_scale", ".attn.k_scale") + elif ".v_proj.v_scale" in renamed: + return renamed.replace(".v_proj.v_scale", ".attn.v_scale") + return renamed + + # Standard model.* to language_model.model.* renaming + return name.replace("model.", "language_model.model.", 1) + + elif name.startswith("lm_head.weight"): + return name.replace("lm_head.weight", + "language_model.lm_head.weight") + + return name + + def _separate_and_rename_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]: + """Rename weights and separate them into language_model and other + weights.""" + language_model_weights = [] + other_weights = [] - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), - (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), - (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() + for name, weight in weights: + renamed = self._rename_weight_for_checkpoint(name) - # language_model is an Llama4ForCausalLM instance. We load it's - # using llama4's load_weights routine. - language_model_weights, other_weights = self.separate_weights( - weights, prefix="language_model.") - loader = AutoWeightsLoader(self) - loaded_language_model_params = loader.load_weights( - language_model_weights) - assert loaded_language_model_params is not None - updated_params.update(loaded_language_model_params) + if renamed.startswith("language_model."): + language_model_weights.append((renamed, weight)) + else: + other_weights.append((renamed, weight)) + + return language_model_weights, other_weights + + def _handle_expert_scale_broadcasting( + self, weights: list[tuple[str, torch.Tensor]], params_dict: dict + ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]: + """Handle expert scale parameters that need broadcasting.""" + regular_weights = [] + expert_scale_weights = [] + updated_params = set() + + for name, weight in weights: + # Check if this is an expert scale parameter that needs broadcasting + if ("feed_forward.experts." 
in name and "scale" in name + and ".shared_expert" not in name): + if name in params_dict: + param = params_dict[name] + if (hasattr(param, 'data') and param.data.numel() > 1 + and weight.numel() == 1): + # Broadcast single value to all experts + param.data.fill_(weight.item()) + updated_params.add(name) + continue + + expert_scale_weights.append((name, weight)) + else: + regular_weights.append((name, weight)) + + return regular_weights, expert_scale_weights, updated_params + + def _load_other_weights(self, other_weights: Iterable[tuple[str, + torch.Tensor]], + params_dict: dict, + stacked_params_mapping: list) -> set[str]: + """Load non-language-model weights with stacking support.""" + updated_params = set() if self.use_data_parallel: other_weights = self._consolidate_qkv_weights(other_weights) for name, loaded_weight in other_weights: + # Try stacked parameter mapping first + mapped = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name or self.use_data_parallel: continue @@ -936,12 +1011,60 @@ def load_weights(self, weights: Iterable[tuple[str, updated_params.add(name) weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + mapped = True break - else: + + if not mapped: + # Use regular weight loading param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) updated_params.add(name) + + return updated_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + # Shared expert gate_up_proj stacking + (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0), + (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1), + # Feed forward gate_up_proj stacking (for non-MoE layers if any) + (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0), + (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params: set[str] = set() + + # Separate and rename weights + language_model_weights, other_weights = ( + self._separate_and_rename_weights(weights)) + + # Handle expert scale parameters + regular_weights, expert_scale_weights, updated_params_from_experts = ( + self._handle_expert_scale_broadcasting(language_model_weights, + params_dict)) + updated_params.update(updated_params_from_experts) + + loader = AutoWeightsLoader(self) + loaded_language_model_params = loader.load_weights(regular_weights) + assert loaded_language_model_params is not None + updated_params.update(loaded_language_model_params) + + if expert_scale_weights: + loaded_expert_scale_params = loader.load_weights( + expert_scale_weights) + if loaded_expert_scale_params: + updated_params.update(loaded_expert_scale_params) + + updated_params.update( + self._load_other_weights(other_weights, params_dict, + stacked_params_mapping)) + return updated_params diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4..1186d65425f 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -400,9 +400,17 @@ def load_weights(self, weights: Iterable[tuple[str, continue if is_pp_missing_parameter(name, self): continue + if name.endswith("scale"): + name = 
maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ff182aadf73..09b32d16038 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -46,7 +46,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -406,6 +408,10 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + if name.endswith("scale"): + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue # We have mlp.experts[0].gate_proj in the checkpoint. # Since we handle the experts below in expert_params_mapping, # we need to skip here BEFORE we update the name, otherwise @@ -427,8 +433,12 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) break else: for mapping in expert_params_mapping:
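
The `_parse_quant_hf_config` hunk above collapses a ModelOpt-produced FP8 quantization config into `{"quant_method": "modelopt"}` so the rest of quantization verification treats it like any other method. A minimal standalone sketch of that detection; the helper name and the example config fields are illustrative assumptions, not vLLM code:

# Standalone sketch of the ModelOpt FP8 detection added to
# ModelConfig._parse_quant_hf_config; names and field values are
# illustrative only.
from typing import Any, Optional


def normalize_modelopt_quant_cfg(
        quant_cfg: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
    """Collapse a ModelOpt-produced FP8 config to {"quant_method": "modelopt"}."""
    if quant_cfg is None:
        return None
    producer = quant_cfg.get("producer", {}).get("name", "")
    if producer.lower() == "modelopt":
        if str(quant_cfg.get("quant_algo", "")).lower() == "fp8":
            return {"quant_method": "modelopt"}
    return quant_cfg


# Example: roughly the shape such a config takes (assumed values).
example_cfg = {
    "producer": {"name": "modelopt", "version": "0.0.0"},
    "quant_algo": "FP8",
    "kv_cache_scheme": "FP8",
}
assert normalize_modelopt_quant_cfg(example_cfg) == {"quant_method": "modelopt"}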
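
The FusedMoE.weight_loader change replaces string matching on the quant method's class name with `uses_weight_scale_2_pattern()`. The per-tensor condition it computes reduces to a small predicate; this is a sketch of that predicate only, not the full loader:

# Sketch of the per-tensor scale dispatch in FusedMoE.weight_loader for
# ModelOpt checkpoints: FP4 variants store "weight_scale_2" per tensor,
# while FP8 stores "weight_scale". The quant method advertises which
# pattern it uses via uses_weight_scale_2_pattern().
def is_per_tensor_scale(weight_name: str, uses_weight_scale_2: bool) -> bool:
    scale_key = "weight_scale_2" if uses_weight_scale_2 else "weight_scale"
    return scale_key in weight_name or "input_scale" in weight_name


# FP8 (uses_weight_scale_2=False): plain weight scales are per-tensor.
assert is_per_tensor_scale("experts.w13_weight_scale", False)
# FP4 (uses_weight_scale_2=True): only the *_scale_2 entries are per-tensor;
# "weight_scale" is the per-block scale and is loaded elsewhere.
assert not is_per_tensor_scale("experts.w13_weight_scale", True)
assert is_per_tensor_scale("experts.w13_weight_scale_2", True)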
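
Several hunks attach a `supports_moe_loading` attribute to wrapped weight loaders so callers can branch on a cheap `getattr` instead of runtime reflection. A hedged sketch of both sides of that contract; `call_weight_loader` and `make_wrapped_loader` are hypothetical names used only for illustration:

# Sketch of the loader-side check that the supports_moe_loading flag enables.
# Previously this kind of dispatch required signature inspection per weight;
# a one-time attribute on the wrapper is cheap to test.
from typing import Callable

import torch


def call_weight_loader(weight_loader: Callable, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, name: str, shard_id: str,
                       expert_id: int) -> None:
    if getattr(weight_loader, "supports_moe_loading", False):
        # MoE-aware loader: pass the MoE-specific arguments through.
        weight_loader(param, loaded_weight, name,
                      shard_id=shard_id, expert_id=expert_id)
    else:
        # Plain loader (e.g. default_weight_loader) only takes the tensor.
        weight_loader(param, loaded_weight)


# Marking a wrapper the way experts_int8.py and moe_wna16.py now do:
def make_wrapped_loader(inner: Callable) -> Callable:
    def wrapped(param, loaded_weight, name, shard_id, expert_id):
        inner(param, loaded_weight, name, shard_id, expert_id)

    wrapped.supports_moe_loading = True  # type: ignore[attr-defined]
    return wrapped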
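
`ModelOptFp8MoEMethod.process_weights_after_loading` folds the separate w1/w3 scales into one per-expert scale by dequantizing each shard with its original scale and requantizing with the per-expert maximum. A torch-only simplification of that math, standing in for vLLM's `per_tensor_dequantize` and `scaled_fp8_quant` helpers:

# Simplified, torch-only illustration of the w13 rescaling performed in
# ModelOptFp8MoEMethod.process_weights_after_loading; the math mirrors the
# dequant/requant loop in the diff above.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def requant_w13(w13: torch.Tensor,
                w13_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """w13: (E, 2*I, H) fp8 weights, w13_scale: (E, 2) per-shard scales.

    Returns w13 requantized with a single per-expert scale of shape (E,)."""
    num_experts, two_i, _ = w13.shape
    inter = two_i // 2
    max_scales = w13_scale.max(dim=1).values  # (E,)
    out = torch.empty_like(w13)
    for e in range(num_experts):
        for shard in range(2):  # w1 then w3
            rows = slice(shard * inter, (shard + 1) * inter)
            # Dequantize with the shard's original scale ...
            dq = w13[e][rows].to(torch.float32) * w13_scale[e][shard]
            # ... and requantize with the combined per-expert max scale.
            q = (dq / max_scales[e]).clamp(-FP8_MAX, FP8_MAX)
            out[e][rows] = q.to(torch.float8_e4m3fn)
    return out, max_scales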
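
The `maybe_remap_kv_scale_name` extension adds `qkv_proj`-style scale names (produced by stacked-parameter processing) to the existing per-projection remapping. A stripped-down, self-contained version of just that remapping; the real helper also handles other checkpoint formats and verifies the remapped name exists in `params_dict`:

# Illustrative, stripped-down version of the qkv_proj handling added to
# maybe_remap_kv_scale_name.
def remap_modelopt_kv_scale(name: str) -> str:
    for key in ("k_scale", "v_scale"):
        suffix = f".{key}"
        if not name.endswith(suffix):
            continue
        for proj in (f".self_attn.{key[0]}_proj", ".self_attn.qkv_proj"):
            if f"{proj}{suffix}" in name:
                return name.replace(f"{proj}{suffix}",
                                    f".self_attn.attn{suffix}")
    return name


# Per-projection and stacked-projection names both land on the attn module:
assert (remap_modelopt_kv_scale("model.layers.0.self_attn.k_proj.k_scale")
        == "model.layers.0.self_attn.attn.k_scale")
assert (remap_modelopt_kv_scale("model.layers.0.self_attn.qkv_proj.v_scale")
        == "model.layers.0.self_attn.attn.v_scale")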
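
`Llama4ForConditionalGeneration._rename_weight_for_checkpoint` maps ModelOpt llama4 FP8 checkpoint names onto vLLM's module layout. The renames can be summarized as a small mapping; the function below is a condensed illustration under an assumed name, not the method itself:

# Condensed view of the renames _rename_weight_for_checkpoint applies to a
# ModelOpt llama4 FP8 checkpoint; the helper name here is hypothetical.
_EXPERT_SCALE_RENAMES = {
    "down_proj_input_scale": "w2_input_scale",
    "down_proj_weight_scale": "w2_weight_scale",
    "gate_up_proj_input_scale": "w13_input_scale",
    "gate_up_proj_weight_scale": "w13_weight_scale",
}


def rename_llama4_fp8_weight(name: str) -> str:
    if name.startswith("lm_head.weight"):
        return name.replace("lm_head.weight", "language_model.lm_head.weight")
    if not name.startswith("model."):
        return name  # e.g. vision tower / projector weights
    name = name.replace("model.", "language_model.model.", 1)
    if "feed_forward.experts." in name:
        for src, dst in _EXPERT_SCALE_RENAMES.items():
            if src in name:
                return name.replace(src, dst)
    if ".k_proj.k_scale" in name:
        return name.replace(".k_proj.k_scale", ".attn.k_scale")
    if ".v_proj.v_scale" in name:
        return name.replace(".v_proj.v_scale", ".attn.v_scale")
    return name


assert (rename_llama4_fp8_weight(
    "model.layers.3.feed_forward.experts.gate_up_proj_weight_scale")
        == "language_model.model.layers.3.feed_forward.experts.w13_weight_scale")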
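
`ModelOptFp8Config.is_layer_excluded` matches the checkpoint's exclude list against layer prefixes, additionally stripping the multimodal `language_model.` prefix so entries written against the bare text model still match. A standalone sketch, assuming the exclude list is passed in directly instead of read from the config object:

# Sketch of the exclusion check used to return UnquantizedLinearMethod for
# modules listed in the checkpoint's "ignore"/"exclude_modules" section.
from typing import Optional


def is_layer_excluded(prefix: str,
                      exclude_modules: Optional[list[str]]) -> bool:
    if not exclude_modules:
        return False
    stripped = prefix.removeprefix("language_model.")
    return any(
        module in prefix
        or (prefix.startswith("language_model.") and module in stripped)
        for module in exclude_modules)


assert is_layer_excluded("language_model.model.layers.0.mlp.gate_proj",
                         ["model.layers.0.mlp"])
assert not is_layer_excluded("model.layers.1.mlp.gate_proj",
                             ["model.layers.0.mlp"])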
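
`_handle_expert_scale_broadcasting` covers checkpoints that store a single scalar scale per MoE layer while vLLM's parameter holds one entry per expert: the scalar is filled into every slot. A minimal illustration with assumed sizes:

# Minimal illustration of the broadcast in _handle_expert_scale_broadcasting;
# the expert count and scale value below are assumptions for the example.
import torch

num_experts = 16
w13_input_scale = torch.nn.Parameter(torch.full((num_experts,), 1.0),
                                     requires_grad=False)
loaded_scale = torch.tensor(0.0625)  # single value from the checkpoint

if w13_input_scale.data.numel() > 1 and loaded_scale.numel() == 1:
    # One checkpoint scalar applies to all experts of this layer.
    w13_input_scale.data.fill_(loaded_scale.item())

assert torch.all(w13_input_scale.data == 0.0625)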