diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index a96f3ee5c30..98b8aeb9d19 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -588,7 +588,70 @@ def _create_weights_8bit(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        raise NotImplementedError
+        from bitsandbytes.nn import Int8Params
+
+        # Fused gate_up_proj (column parallel)
+        # For 8-bit, the data is not packed into uint8 from a lower bitwidth:
+        # it is stored directly as int8, so pack_factor is 1.
+        # The shape is the dequantized shape, but the dtype is int8.
+        w13_qweight = Int8Params(
+            data=torch.empty(
+                num_experts,
+                intermediate_size_per_partition * 2,
+                hidden_size,
+                dtype=torch.int8,
+            ),
+            has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        set_weight_attrs(
+            w13_qweight,
+            {
+                "num_experts": num_experts,
+                "input_dim": hidden_size,
+                "output_dim": 2 * intermediate_size_per_partition,
+                "experts_shape": (
+                    num_experts,
+                    intermediate_size_per_partition * 2,
+                    hidden_size,
+                ),
+                "pack_factor": 1,  # 8-bit is one value per byte, no packing
+                "use_bitsandbytes_8bit": True,
+                "generation": 0,  # used for 8-bit matmul state tracking
+            },
+        )
+
+        # down_proj (row parallel)
+        w2_qweight = Int8Params(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
+            has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        set_weight_attrs(
+            w2_qweight,
+            {
+                "num_experts": num_experts,
+                "input_dim": intermediate_size_per_partition,
+                "output_dim": hidden_size,
+                "experts_shape": (
+                    num_experts,
+                    hidden_size,
+                    intermediate_size_per_partition,
+                ),
+                "pack_factor": 1,  # 8-bit is one value per byte, no packing
+                "use_bitsandbytes_8bit": True,
+                "generation": 0,  # used for 8-bit matmul state tracking
+            },
+        )
 
     def _apply_4bit_dequnt(
             self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
@@ -607,4 +670,21 @@ def _apply_4bit_dequnt(
 
     def _apply_8bit_dequant(
             self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
+        # Int8Params keeps the int8 tensor in .data; the loader is assumed to
+        # attach the row-wise absmax scales (SCB) via bnb_quant_state.
+        # LLM.int8() quantizes per output row, so the original weights are
+        # recovered as int8_data * SCB / 127.
+        w13_qweight = layer.w13_weight
+        w13_scb = w13_qweight.bnb_quant_state.SCB.unsqueeze(-1)
+        w13 = w13_qweight.data.to(w13_scb.dtype) * w13_scb / 127.0
+
+        # Dequantize w2_weight in the same way.
+        w2_qweight = layer.w2_weight
+        w2_scb = w2_qweight.bnb_quant_state.SCB.unsqueeze(-1)
+        w2 = w2_qweight.data.to(w2_scb.dtype) * w2_scb / 127.0
+
+        # Reshape to the expected per-expert shapes.
+        w13 = w13.reshape(w13_qweight.experts_shape)
+        w2 = w2.reshape(w2_qweight.experts_shape)
+
+        return w13, w2
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 907bc3c1361..c53ec23a9cf 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -429,15 +429,6 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
                     f"MoE Model {type(model).__name__} does not support "
                     "BitsAndBytes quantization yet. Ensure this model has "
                     "'get_expert_mapping' method.")
-            # TODO: support FusedMoE with prequant and 8bit.
-            if self.pre_quant:
-                raise ValueError(
-                    "Prequant BitsAndBytes models with FusedMoE is not "
-                    "supported yet.")
-            if self.load_8bit:
-                raise ValueError(
-                    "BitsAndBytes 8bit quantization with FusedMoE is not "
-                    "supported yet.")
             # Get the corresponding weight name using module name and
             # get_expert_mapping.
             expert_mapping = model.get_expert_mapping()
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index fab1c163ac2..b2672fdedbd 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -509,6 +509,22 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """
+        Returns the mapping from MoE expert checkpoint weight names to the
+        fused expert parameters. This is used by BitsAndBytesModelLoader
+        to correctly identify and fuse pre-quantized expert weights.
+        """
+        # FusedMoE refers to the expert projections as w1 (gate_proj),
+        # w2 (down_proj) and w3 (up_proj); the checkpoint-side names given
+        # below are remapped onto the fused w13_weight / w2_weight
+        # parameters created by FusedMoE.
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_local_experts)
 
 
 class Llama4ForCausalLM(LlamaForCausalLM):
@@ -580,3 +596,9 @@ def permute(w: torch.Tensor, n_heads: int):
                                        self.config.num_attention_heads)
 
         return name, loaded_weight
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """
+        Delegates the call to the underlying Llama4Model for expert mapping.
+        """
+        return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index dea85d320ad..b834eeda58a 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -1069,3 +1069,14 @@ def load_weights(self, weights: Iterable[tuple[str,
                                              stacked_params_mapping))
 
         return updated_params
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        """
+        Delegates the call to the underlying language model for expert mapping.
+        """
+        if not hasattr(self.language_model, "get_expert_mapping"):
+            # This should not happen if llama4.py is correctly updated.
+            raise AttributeError(
+                "Llama4ForCausalLM does not have 'get_expert_mapping' method. "
+ "Ensure vllm/model_executor/models/llama4.py is up-to-date.") + return self.language_model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 0f749b3e38f..4b3ad0418c6 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -122,7 +122,7 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, bias=False, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.gate") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: