
Commit fd8e84a

[Model] Qwen3 FP8 support (#3219)
This PR enables FP8 block-scale quantization support for the Qwen3 model family.
1 parent f870395 · commit fd8e84a

7 files changed: +194 −43 lines changed

python/mlc_llm/model/model.py

Lines changed: 2 additions & 0 deletions
@@ -340,6 +340,7 @@ class Model:
             "no-quant": qwen3_quantization.no_quant,
             "group-quant": qwen3_quantization.group_quant,
             "ft-quant": qwen3_quantization.ft_quant,
+            "block-scale-quant": qwen3_quantization.block_scale_quant,
         },
     ),
     "qwen3_moe": Model(
@@ -354,6 +355,7 @@ class Model:
             "no-quant": qwen3_moe_quantization.no_quant,
             "group-quant": qwen3_moe_quantization.group_quant,
             "ft-quant": qwen3_moe_quantization.ft_quant,
+            "block-scale-quant": qwen3_moe_quantization.block_scale_quant,
         },
     ),
     "deepseek_v2": Model(

python/mlc_llm/model/qwen3/qwen3_loader.py

Lines changed: 62 additions & 16 deletions
@@ -4,11 +4,12 @@
 """
 
 import functools
+from typing import Callable, List
 
 import numpy as np
 
-from mlc_llm.loader import ExternMapping
-from mlc_llm.quantization import Quantization
+from mlc_llm.loader import ExternMapping, QuantizeMapping
+from mlc_llm.quantization import BlockScaleQuantize, Quantization
 
 from .qwen3_model import Qwen3Config, Qwen3LMHeadModel
 
@@ -33,6 +34,15 @@ def huggingface(model_config: Qwen3Config, quantization: Quantization) -> Extern
     model = Qwen3LMHeadModel(model_config)
     if quantization is not None:
         model.to(quantization.model_dtype)
+        if isinstance(quantization, BlockScaleQuantize):
+            # Convert the model to block-scale quantized model before loading parameters
+            model = quantization.quantize_model(model, QuantizeMapping({}, {}), "")
+            if model_config.weight_block_size is None:
+                raise ValueError(
+                    "The input Qwen3 model is not fp8 block quantized. "
+                    "Thus BlockScaleQuantize is not supported."
+                )
+
     _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
         spec=model.get_default_spec(),
         allow_extern=True,
@@ -41,19 +51,60 @@ def huggingface(model_config: Qwen3Config, quantization: Quantization) -> Extern
 
     mapping = ExternMapping()
 
+    if (
+        not isinstance(quantization, BlockScaleQuantize)
+        and model_config.weight_block_size is not None
+    ):
+        raise ValueError(
+            "The input Qwen3 model is fp8 block quantized. "
+            "Please use BlockScaleQuantize for the model."
+        )
+
+    # Helper function to add both weight and scale mappings
+    def add_weight_and_scale_mapping(
+        weight_mlc_name: str,
+        weight_hf_names: List[str],
+        weight_transform_func: Callable,
+    ):
+        mlc_param = named_parameters[weight_mlc_name]
+        mapping.add_mapping(
+            weight_mlc_name,
+            weight_hf_names,
+            functools.partial(weight_transform_func, dtype=mlc_param.dtype),
+        )
+
+        if isinstance(quantization, BlockScaleQuantize):
+            scale_mlc_name = f"{weight_mlc_name}_scale_inv"
+            if scale_mlc_name in named_parameters:
+                scale_hf_names = [f"{name}_scale_inv" for name in weight_hf_names]
+                scale_param = named_parameters[scale_mlc_name]
+                mapping.add_mapping(
+                    scale_mlc_name,
+                    scale_hf_names,
+                    functools.partial(weight_transform_func, dtype=scale_param.dtype),
+                )
+
     for i in range(model_config.num_hidden_layers):
         # map attention weight
         attn = f"model.layers.{i}.self_attn"
-        weight_names = ["weight", "bias"] if model_config.attention_bias else ["weight"]
-        for weight_type in weight_names:
-            mlc_name = f"{attn}.c_attn.{weight_type}"
+        add_weight_and_scale_mapping(
+            f"{attn}.c_attn.weight",
+            [
+                f"{attn}.q_proj.weight",
+                f"{attn}.k_proj.weight",
+                f"{attn}.v_proj.weight",
+            ],
+            lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+        )
+        if model_config.attention_bias:
+            mlc_name = f"{attn}.c_attn.bias"
             mlc_param = named_parameters[mlc_name]
             mapping.add_mapping(
                 mlc_name,
                 [
-                    f"{attn}.q_proj.{weight_type}",
-                    f"{attn}.k_proj.{weight_type}",
-                    f"{attn}.v_proj.{weight_type}",
+                    f"{attn}.q_proj.bias",
+                    f"{attn}.k_proj.bias",
+                    f"{attn}.v_proj.bias",
                 ],
                 functools.partial(
                     lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
@@ -62,18 +113,13 @@ def huggingface(model_config: Qwen3Config, quantization: Quantization) -> Extern
             )
         # map mlp weight
         mlp = f"model.layers.{i}.mlp"
-        mlc_name = f"{mlp}.gate_up_proj.weight"
-        mlc_param = named_parameters[mlc_name]
-        mapping.add_mapping(
-            mlc_name,
+        add_weight_and_scale_mapping(
+            f"{mlp}.gate_up_proj.weight",
             [
                 f"{mlp}.gate_proj.weight",
                 f"{mlp}.up_proj.weight",
            ],
-            functools.partial(
-                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
-                dtype=mlc_param.dtype,
-            ),
+            lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
         )
 
     for mlc_name, mlc_param in named_parameters.items():
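When BlockScaleQuantize is active, add_weight_and_scale_mapping registers two external mappings per fused parameter: one for the FP8 weight and one for its `_scale_inv` block-scale tensor, both reusing the same transform. A self-contained NumPy sketch of the c_attn case (shapes are illustrative; float16 stands in for the FP8 storage dtype, and 128x128 blocks are assumed):

import numpy as np

def fuse_qkv(q, k, v, dtype):
    # Same transform the loader registers: concatenate q/k/v along the output dimension.
    return np.concatenate([q, k, v], axis=0).astype(dtype)

hidden, q_out, kv_out, block = 1024, 1024, 256, 128
q_w = np.zeros((q_out, hidden), dtype="float16")   # q_proj.weight (FP8 in the real checkpoint)
k_w = np.zeros((kv_out, hidden), dtype="float16")  # k_proj.weight
v_w = np.zeros((kv_out, hidden), dtype="float16")  # v_proj.weight
# Matching block-scale tensors: one float32 scale per 128x128 weight block.
q_s = np.ones((q_out // block, hidden // block), dtype="float32")
k_s = np.ones((kv_out // block, hidden // block), dtype="float32")
v_s = np.ones((kv_out // block, hidden // block), dtype="float32")

c_attn_weight = fuse_qkv(q_w, k_w, v_w, dtype="float16")       # (1536, 1024)
c_attn_scale_inv = fuse_qkv(q_s, k_s, v_s, dtype="float32")    # (12, 8)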

python/mlc_llm/model/qwen3/qwen3_model.py

Lines changed: 27 additions & 1 deletion
@@ -4,7 +4,7 @@
 
 import dataclasses
 from functools import partial
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
@@ -41,9 +41,34 @@ class Qwen3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     head_dim: int = 0
     dtype: str = "float32"
     max_batch_size: int = 1
+    weight_block_size: Optional[Tuple[int, int]] = None
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
+        if "quantization_config" in self.kwargs:
+            quantization_config = self.kwargs.get("quantization_config")
+            if (
+                isinstance(quantization_config, dict)
+                and quantization_config.get("activation_scheme", "") == "dynamic"
+                and quantization_config.get("fmt", "") == "e4m3"
+                and quantization_config.get("quant_method", "") == "fp8"
+                and "weight_block_size" in quantization_config
+            ):
+                self.weight_block_size = quantization_config.get("weight_block_size")
+                if (
+                    not isinstance(self.weight_block_size, (tuple, list))
+                    or len(self.weight_block_size) != 2
+                ):
+                    raise ValueError(
+                        "Invalid DeepSeek model quantization config: "
+                        "weight_block_size must be a tuple of two integers, "
+                        f"got {self.weight_block_size} of type {type(self.weight_block_size)}"
+                    )
+            else:
+                raise ValueError(
+                    "Invalid DeepSeek model quantization config: unrecognized quantization config: "
+                    f"{quantization_config}"
+                )
         if self.context_window_size == 0:
             for name in ["max_position_embeddings", "max_sequence_length"]:
                 if name in self.kwargs:
@@ -247,6 +272,7 @@ def __init__(self, config: Qwen3Config):
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.head_dim = config.head_dim
+        self.weight_block_size = config.weight_block_size
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
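The new __post_init__ logic (evidently adapted from the DeepSeek config handling, hence the error-message wording) recognizes exactly one quantization_config shape in the HuggingFace config.json and copies its weight_block_size onto the config; any other quantization_config raises a ValueError. An illustrative fragment that would be accepted (field names come from the check above; the values here are examples, not taken from the PR):

# config.json excerpt for an FP8 block-quantized Qwen3 checkpoint (illustrative values)
quantization_config = {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "weight_block_size": [128, 128],  # must be a 2-element list/tuple, one entry per weight axis
}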

python/mlc_llm/model/qwen3/qwen3_quantization.py

Lines changed: 18 additions & 1 deletion
@@ -6,7 +6,12 @@
 from tvm.relax.frontend import nn
 
 from mlc_llm.loader import QuantizeMapping
-from mlc_llm.quantization import FTQuantize, GroupQuantize, NoQuantize
+from mlc_llm.quantization import (
+    BlockScaleQuantize,
+    FTQuantize,
+    GroupQuantize,
+    NoQuantize,
+)
 
 from .qwen3_model import Qwen3Config, Qwen3LMHeadModel
 
@@ -53,3 +58,15 @@ def no_quant(
     model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
     return model, quant_map
+
+
+def block_scale_quant(
+    model_config: Qwen3Config,
+    quantization: BlockScaleQuantize,
+) -> Tuple[nn.Module, QuantizeMapping]:
+    """Quantize a Qwen3 model using block-scale quantization."""
+    model: nn.Module = Qwen3LMHeadModel(model_config)
+    model.to(quantization.model_dtype)
+    quant_map = QuantizeMapping({}, {})
+    model = quantization.quantize_model(model, quant_map, "")
+    return model, quant_map
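For reference, block-scale FP8 quantization keeps the weights in e4m3 and stores one float32 scale per weight block; these are the `_scale_inv` tensors the loaders above map. A minimal NumPy sketch of the corresponding dequantization, assuming [128, 128] blocks and applying the stored scale multiplicatively as in common FP8 block-quant checkpoints (float16 stands in for e4m3, since NumPy has no FP8 dtype):

import numpy as np

def block_scale_dequantize(w_q, scale_inv, block=(128, 128)):
    # Broadcast each per-block scale over its 128x128 tile, then rescale the weight.
    rows = np.repeat(scale_inv, block[0], axis=0)[: w_q.shape[0]]
    tiles = np.repeat(rows, block[1], axis=1)[:, : w_q.shape[1]]
    return w_q.astype("float32") * tiles

w_q = np.random.randn(256, 384).astype("float16")    # stand-in for the e4m3 weight
scale_inv = np.random.rand(2, 3).astype("float32")   # one scale per 128x128 block
w = block_scale_dequantize(w_q, scale_inv)            # (256, 384) float32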

python/mlc_llm/model/qwen3_moe/qwen3_moe_loader.py

Lines changed: 66 additions & 24 deletions
@@ -4,11 +4,12 @@
 """
 
 import functools
+from typing import Callable, List
 
 import numpy as np
 
-from mlc_llm.loader import ExternMapping
-from mlc_llm.quantization import Quantization
+from mlc_llm.loader import ExternMapping, QuantizeMapping
+from mlc_llm.quantization import BlockScaleQuantize, Quantization
 
 from .qwen3_moe_model import Qwen3MoeConfig, Qwen3MoeForCausalLM
 
@@ -33,6 +34,15 @@ def huggingface(model_config: Qwen3MoeConfig, quantization: Quantization) -> Ext
     model = Qwen3MoeForCausalLM(model_config)
     if quantization is not None:
         model.to(quantization.model_dtype)
+        if isinstance(quantization, BlockScaleQuantize):
+            # Convert the model to block-scale quantized model before loading parameters
+            model = quantization.quantize_model(model, QuantizeMapping({}, {}), "")
+            if model_config.weight_block_size is None:
+                raise ValueError(
+                    "The input Qwen3 model is not fp8 block quantized. "
+                    "Thus BlockScaleQuantize is not supported."
+                )
+
     _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
         spec=model.get_default_spec(),
         allow_extern=True,
@@ -41,19 +51,60 @@ def huggingface(model_config: Qwen3MoeConfig, quantization: Quantization) -> Ext
 
     mapping = ExternMapping()
 
+    if (
+        not isinstance(quantization, BlockScaleQuantize)
+        and model_config.weight_block_size is not None
+    ):
+        raise ValueError(
+            "The input Qwen3 model is fp8 block quantized. "
+            "Please use BlockScaleQuantize for the model."
+        )
+
+    # Helper function to add both weight and scale mappings
+    def add_weight_and_scale_mapping(
+        weight_mlc_name: str,
+        weight_hf_names: List[str],
+        weight_transform_func: Callable,
+    ):
+        mlc_param = named_parameters[weight_mlc_name]
+        mapping.add_mapping(
+            weight_mlc_name,
+            weight_hf_names,
+            functools.partial(weight_transform_func, dtype=mlc_param.dtype),
+        )
+
+        if isinstance(quantization, BlockScaleQuantize):
+            scale_mlc_name = f"{weight_mlc_name}_scale_inv"
+            if scale_mlc_name in named_parameters:
+                scale_hf_names = [f"{name}_scale_inv" for name in weight_hf_names]
+                scale_param = named_parameters[scale_mlc_name]
+                mapping.add_mapping(
+                    scale_mlc_name,
+                    scale_hf_names,
+                    functools.partial(weight_transform_func, dtype=scale_param.dtype),
+                )
+
     for i in range(model_config.num_hidden_layers):
         # map attention weight
         attn = f"model.layers.{i}.self_attn"
-        weight_names = ["weight", "bias"] if model_config.attention_bias else ["weight"]
-        for weight_type in weight_names:
-            mlc_name = f"{attn}.c_attn.{weight_type}"
+        add_weight_and_scale_mapping(
+            f"{attn}.c_attn.weight",
+            [
+                f"{attn}.q_proj.weight",
+                f"{attn}.k_proj.weight",
+                f"{attn}.v_proj.weight",
+            ],
+            lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+        )
+        if model_config.attention_bias:
+            mlc_name = f"{attn}.c_attn.bias"
             mlc_param = named_parameters[mlc_name]
             mapping.add_mapping(
                 mlc_name,
                 [
-                    f"{attn}.q_proj.{weight_type}",
-                    f"{attn}.k_proj.{weight_type}",
-                    f"{attn}.v_proj.{weight_type}",
+                    f"{attn}.q_proj.bias",
+                    f"{attn}.k_proj.bias",
+                    f"{attn}.v_proj.bias",
                 ],
                 functools.partial(
                     lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
@@ -62,16 +113,15 @@ def huggingface(model_config: Qwen3MoeConfig, quantization: Quantization) -> Ext
             )
         # map mlp moe gate and up weight
         mlp = f"model.layers.{i}.mlp"
-        mlc_name = f"{mlp}.moe_gate_up_proj.weight"
 
         def combine_expert_gate_up(*hf_params, dtype):
            stack = []
            for i in range(0, len(hf_params), 2):
                stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))
            return np.stack(stack, axis=0).astype(dtype)
 
-        mapping.add_mapping(
-            mlc_name,
+        add_weight_and_scale_mapping(
+            f"{mlp}.moe_gate_up_proj.weight",
             functools.reduce(
                 lambda a, b: a + b,
                 [
@@ -82,25 +132,17 @@ def combine_expert_gate_up(*hf_params, dtype):
                     for expert_id in range(model_config.num_experts)
                 ],
             ),
-            functools.partial(
-                combine_expert_gate_up,
-                dtype=mlc_param.dtype,
-            ),
+            combine_expert_gate_up,
         )
 
-        # map mlp moe gate and up weight
-        mlc_name = f"{mlp}.moe_down_proj.weight"
-        mlc_param = named_parameters[mlc_name]
-        mapping.add_mapping(
-            mlc_name,
+        # map mlp moe down projection weight
+        add_weight_and_scale_mapping(
+            f"{mlp}.moe_down_proj.weight",
             [
                 f"{mlp}.experts.{expert_id}.down_proj.weight"
                 for expert_id in range(model_config.num_experts)
            ],
-            functools.partial(
-                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
-                dtype=mlc_param.dtype,
-            ),
+            lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),
        )
 
     for mlc_name, mlc_param in named_parameters.items():
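combine_expert_gate_up receives the per-expert gate and up projection weights interleaved (gate_0, up_0, gate_1, up_1, ...), fuses each pair along the output dimension, and stacks the results into one [num_experts, 2 * ffn, hidden] tensor; with BlockScaleQuantize the same transform is reused for the `_scale_inv` tensors. A quick standalone check with illustrative shapes:

import numpy as np

def combine_expert_gate_up(*hf_params, dtype):
    # Pairs arrive as (gate_0, up_0, gate_1, up_1, ...): fuse each pair, then stack experts.
    stack = []
    for i in range(0, len(hf_params), 2):
        stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))
    return np.stack(stack, axis=0).astype(dtype)

num_experts, ffn, hidden = 4, 256, 128
params = []
for _ in range(num_experts):
    params.append(np.zeros((ffn, hidden), dtype="float16"))  # gate_proj.weight
    params.append(np.zeros((ffn, hidden), dtype="float16"))  # up_proj.weight

fused = combine_expert_gate_up(*params, dtype="float16")
print(fused.shape)  # (4, 512, 128) -> [num_experts, 2 * ffn, hidden]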

python/mlc_llm/model/qwen3_moe/qwen3_moe_model.py

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ def __init__(self, config: Qwen3MoeConfig):
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.head_dim = config.head_dim
+        self.weight_block_size = config.weight_block_size
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
