
Enable MoE pre quant #1

Open · wants to merge 3 commits into base: main
84 changes: 82 additions & 2 deletions vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -588,7 +588,70 @@ def _create_weights_8bit(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
from bitsandbytes.nn import Int8Params

# Fused gate_up_proj (column parallel)
# Unlike the 4-bit path, 8-bit data is not packed into uint8 from a lower
# bitwidth; it is stored directly as int8, so pack_factor is 1 (or may be
# omitted if unused). The tensor keeps the dequantized shape, but its dtype is int8.
w13_qweight = Int8Params(
data=torch.empty(
num_experts,
intermediate_size_per_partition * 2,
hidden_size,
dtype=torch.int8,
),
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
set_weight_attrs(
w13_qweight,
{
"num_experts": num_experts,
"input_dim": hidden_size,
"output_dim": 2 * intermediate_size_per_partition,
"experts_shape": (
num_experts,
intermediate_size_per_partition * 2,
hidden_size,
),
"pack_factor": 1, # 8-bit means 1 byte per value, no packing ratio
"use_bitsandbytes_8bit": True,
"generation": 0, # Added for 8bit matmul states tracking
},
)

# down_proj (row parallel)
w2_qweight = Int8Params(
data=torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=torch.int8,
),
has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
set_weight_attrs(
w2_qweight,
{
"num_experts": num_experts,
"input_dim": intermediate_size_per_partition,
"output_dim": hidden_size,
"experts_shape": (
num_experts,
hidden_size,
intermediate_size_per_partition,
),
"pack_factor": 1, # 8-bit means 1 byte per value, no packing ratio
"use_bitsandbytes_8bit": True,
"generation": 0, # Added for 8bit matmul states tracking
},
)
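
Note (illustrative, not part of the diff): the sketch below shows how a single expert weight of this kind behaves outside vLLM, under the assumption that bitsandbytes quantizes an `Int8Params` with `has_fp16_weights=False` when it is moved to CUDA and records per-row absmax scales in `.SCB`. Shapes and names are placeholders.

```python
# Standalone sketch of bitsandbytes' Int8Params behaviour (assumption:
# moving an Int8Params with has_fp16_weights=False to CUDA quantizes it
# and stores per-row absmax scales in .SCB). Sizes are placeholders.
import torch
from bitsandbytes.nn import Int8Params

intermediate_size, hidden_size = 128, 256  # hypothetical per-expert shape

# A checkpoint would provide fp16 weights for one expert's fused gate_up_proj.
w13_fp16 = torch.randn(intermediate_size * 2, hidden_size, dtype=torch.float16)
w13 = Int8Params(data=w13_fp16,
                 has_fp16_weights=False,
                 requires_grad=False)

if torch.cuda.is_available():
    w13 = w13.cuda()            # quantization is expected to happen here
    print(w13.data.dtype)       # torch.int8
    print(w13.SCB.shape)        # one absmax scale per output row
```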

def _apply_4bit_dequnt(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
@@ -607,4 +670,21 @@ def _apply_4bit_dequnt(

def _apply_8bit_dequant(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
from bitsandbytes.functional import dequantize_8bit

# Dequantize w13_weight
# Int8Params stores the actual int8 data in .data and the scales in .SCB or
# bnb_quant_state. dequantize_8bit expects the int8 data and its
# corresponding scale; here we assume bnb_quant_state holds the scale (SCB).
w13_qweight = layer.w13_weight
w13 = dequantize_8bit(w13_qweight.data, w13_qweight.bnb_quant_state.SCB)

# Dequantize w2_weight
w2_qweight = layer.w2_weight
w2 = dequantize_8bit(w2_qweight.data, w2_qweight.bnb_quant_state.SCB)

# Reshape to the expected expert shape
w13 = w13.reshape(w13_qweight.experts_shape)
w2 = w2.reshape(w2_qweight.experts_shape)

return w13, w2
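
For reference, the arithmetic assumed above (LLM.int8()-style row-wise dequantization: int8 value times per-row absmax scale, divided by 127) can be written out by hand. This is only a sketch of the intended math, not the code path used by the PR.

```python
# Minimal sketch of row-wise int8 dequantization (assumption: SCB holds
# one absmax scale per output row, as in bitsandbytes' LLM.int8() scheme).
import torch

def dequantize_int8_rowwise(qweight: torch.Tensor,
                            scb: torch.Tensor,
                            dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # qweight: (rows, cols) int8; scb: (rows,) floating-point scales.
    return qweight.to(dtype) * (scb.to(dtype) / 127.0).unsqueeze(-1)

q = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
scb = torch.rand(8) * 3.0          # placeholder per-row scales
w = dequantize_int8_rowwise(q, scb)
print(w.shape, w.dtype)            # torch.Size([8, 16]) torch.float16
```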
9 changes: 0 additions & 9 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -429,15 +429,6 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method.")
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant:
raise ValueError(
"Prequant BitsAndBytes models with FusedMoE is not "
"supported yet.")
if self.load_8bit:
raise ValueError(
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet.")
# Get the corresponding weight name using module name and
# get_expert_mapping.
expert_mapping = model.get_expert_mapping()
22 changes: 22 additions & 0 deletions vllm/model_executor/models/llama4.py
@@ -509,6 +509,22 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""
Returns a mapping for MoE expert weights to their corresponding
checkpoint names. This is used by BitsAndBytesModelLoader
to correctly identify and fuse pre-quantized expert weights.
"""
# Map the per-expert checkpoint projections onto the fused FusedMoE
# parameters: 'gate_proj' corresponds to 'w1', 'down_proj' to 'w2',
# and 'up_proj' to 'w3'.
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts)
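
As a quick sanity check, `FusedMoE.make_expert_params_mapping` returns `(param_name, checkpoint_weight_name, expert_id, shard_id)` tuples that the loader uses to route per-expert checkpoint tensors into the fused `w13`/`w2` parameters. The snippet below only inspects that structure; the exact prefix strings come from vLLM's `FusedMoE` implementation.

```python
# Illustrative only: print the expert mapping for a small expert count to see
# the (param_name, ckpt_weight_name, expert_id, shard_id) tuple layout that
# BitsAndBytesModelLoader relies on.
from vllm.model_executor.layers.fused_moe import FusedMoE

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=2)

for param_name, ckpt_name, expert_id, shard_id in mapping:
    # gate/up projections are expected to land in the fused w13 parameter,
    # down projections in w2.
    print(param_name, ckpt_name, expert_id, shard_id)
```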


class Llama4ForCausalLM(LlamaForCausalLM):
@@ -580,3 +596,9 @@ def permute(w: torch.Tensor, n_heads: int):
self.config.num_attention_heads)

return name, loaded_weight

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""
Delegates the call to the underlying Llama4Model for expert mapping.
"""
return self.model.get_expert_mapping()
11 changes: 11 additions & 0 deletions vllm/model_executor/models/mllama4.py
@@ -1069,3 +1069,14 @@ def load_weights(self, weights: Iterable[tuple[str,
stacked_params_mapping))

return updated_params

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""
Delegates the call to the underlying language model for expert mapping.
"""
if not hasattr(self.language_model, "get_expert_mapping"):
# This should ideally not happen if llama4.py is correctly updated
raise AttributeError(
"Llama4ForCausalLM does not have 'get_expert_mapping' method."
"Ensure vllm/model_executor/models/llama4.py is up-to-date.")
return self.language_model.get_expert_mapping()
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen3_moe.py
@@ -122,7 +122,7 @@ def __init__(
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
quant_config=quant_config,
prefix=f"{prefix}.gate")

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: