
Commit f870395

[Model] Qwen3 support (#3218)
This PR adds support for the Qwen3 and Qwen3MoE models. FP8 support is still in progress. A bug in topk-softmax is also fixed in this PR.
1 parent ff9467e commit f870395

File tree

13 files changed: +1195 -71 lines


python/mlc_llm/interface/compiler_flags.py

Lines changed: 0 additions & 4 deletions
@@ -5,8 +5,6 @@
 from io import StringIO
 from typing import Optional
 
-import tvm
-
 from mlc_llm.support import argparse, logging
 from mlc_llm.support.config import ConfigOverrideBase
 
@@ -91,8 +89,6 @@ def _flashinfer(target) -> bool:
         return False
     if target.kind.name != "cuda":
         return False
-    if tvm.get_global_func("support.GetLibInfo")()["USE_FLASHINFER"] != "ON":
-        return False
     arch_list = detect_cuda_arch_list(target)
     for arch in arch_list:
         if arch < 80:
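With the USE_FLASHINFER library probe removed, FlashInfer eligibility now depends only on the target kind and the detected CUDA architectures. Below is a minimal standalone sketch of that gating logic, not the exact `_flashinfer` body; it assumes `detect_cuda_arch_list` yields integer compute capabilities such as 75, 80, or 89.

```python
# Hedged sketch of the FlashInfer gate after this change (standalone helper,
# not the real _flashinfer implementation in compiler_flags.py).
def flashinfer_supported(target_kind: str, arch_list: list) -> bool:
    if target_kind != "cuda":
        return False
    for arch in arch_list:
        if arch < 80:  # pre-Ampere GPUs (e.g. sm_75) cannot use FlashInfer
            return False
    return True


print(flashinfer_supported("cuda", [80, 89]))  # True
print(flashinfer_supported("vulkan", [80]))    # False
print(flashinfer_supported("cuda", [75]))      # False
```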

python/mlc_llm/model/model.py

Lines changed: 30 additions & 0 deletions
@@ -39,6 +39,8 @@
 from .qwen import qwen_loader, qwen_model, qwen_quantization
 from .qwen2 import qwen2_loader, qwen2_model, qwen2_quantization
 from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
+from .qwen3 import qwen3_loader, qwen3_model, qwen3_quantization
+from .qwen3_moe import qwen3_moe_loader, qwen3_moe_model, qwen3_moe_quantization
 from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
 from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
 from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization

@@ -326,6 +328,34 @@ class Model:
             "ft-quant": qwen2_moe_quantization.ft_quant,
         },
     ),
+    "qwen3": Model(
+        name="qwen3",
+        model=qwen3_model.Qwen3LMHeadModel,
+        config=qwen3_model.Qwen3Config,
+        source={
+            "huggingface-torch": qwen3_loader.huggingface,
+            "huggingface-safetensor": qwen3_loader.huggingface,
+        },
+        quantize={
+            "no-quant": qwen3_quantization.no_quant,
+            "group-quant": qwen3_quantization.group_quant,
+            "ft-quant": qwen3_quantization.ft_quant,
+        },
+    ),
+    "qwen3_moe": Model(
+        name="qwen3_moe",
+        model=qwen3_moe_model.Qwen3MoeForCausalLM,
+        config=qwen3_moe_model.Qwen3MoeConfig,
+        source={
+            "huggingface-torch": qwen3_moe_loader.huggingface,
+            "huggingface-safetensor": qwen3_moe_loader.huggingface,
+        },
+        quantize={
+            "no-quant": qwen3_moe_quantization.no_quant,
+            "group-quant": qwen3_moe_quantization.group_quant,
+            "ft-quant": qwen3_moe_quantization.ft_quant,
+        },
+    ),
     "deepseek_v2": Model(
         name="deepseek_v2",
         model=deepseek_v2_model.DeepseekV2ForCausalLM,
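Once registered, the two new architectures resolve through the same table as every other model type. The following is a hedged sketch of looking up the entries added above; it assumes the `MODELS` dict in `python/mlc_llm/model/model.py` is the table edited by this hunk and that an mlc_llm package install is available.

```python
# Hedged sketch: resolving the newly registered specs by their "model_type" keys,
# assuming MODELS is the registry dict edited in the hunk above.
from mlc_llm.model.model import MODELS

for key in ("qwen3", "qwen3_moe"):
    spec = MODELS[key]
    print(key, spec.model.__name__, spec.config.__name__)
    # Each entry also exposes the loader and quantization hooks wired up above:
    loader_fn = spec.source["huggingface-safetensor"]  # qwen3_loader.huggingface / qwen3_moe_loader.huggingface
    quant_fn = spec.quantize["group-quant"]            # group quantization entry point
```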

python/mlc_llm/model/qwen3/__init__.py

Whitespace-only changes.

python/mlc_llm/model/qwen3/qwen3_loader.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
"""
This file specifies how MLC's Qwen3 parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .qwen3_model import Qwen3Config, Qwen3LMHeadModel


def huggingface(model_config: Qwen3Config, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : Qwen3Config
        The configuration of the Qwen3 model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = Qwen3LMHeadModel(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # map attention weight: fuse q/k/v projections into the single c_attn parameter
        attn = f"model.layers.{i}.self_attn"
        weight_names = ["weight", "bias"] if model_config.attention_bias else ["weight"]
        for weight_type in weight_names:
            mlc_name = f"{attn}.c_attn.{weight_type}"
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.{weight_type}",
                    f"{attn}.k_proj.{weight_type}",
                    f"{attn}.v_proj.{weight_type}",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
        # map mlp weight: fuse gate_proj and up_proj into gate_up_proj
        mlp = f"model.layers.{i}.mlp"
        mlc_name = f"{mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.gate_proj.weight",
                f"{mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
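Both fused mappings above do the same thing: concatenate separate HuggingFace projection matrices along the output (row) dimension and cast to the MLC parameter dtype, producing the `c_attn` and `gate_up_proj` weights. A self-contained sketch of that transform with illustrative shapes (the sizes below are made up, not taken from a real Qwen3 checkpoint):

```python
import numpy as np

# Illustrative sizes only: hidden_size=8, q_proj rows=8, k/v_proj rows=4 each.
hidden_size, q_rows, kv_rows = 8, 8, 4
q = np.random.randn(q_rows, hidden_size).astype("float32")   # q_proj.weight
k = np.random.randn(kv_rows, hidden_size).astype("float32")  # k_proj.weight
v = np.random.randn(kv_rows, hidden_size).astype("float32")  # v_proj.weight

# Same transform as the c_attn lambda: stack along axis 0, then cast to the
# MLC parameter dtype (float16 here, standing in for mlc_param.dtype).
c_attn = np.concatenate([q, k, v], axis=0).astype("float16")
assert c_attn.shape == (q_rows + 2 * kv_rows, hidden_size)

# gate_up_proj is fused the same way from gate_proj and up_proj.
gate = np.random.randn(16, hidden_size).astype("float32")
up = np.random.randn(16, hidden_size).astype("float32")
gate_up = np.concatenate([gate, up], axis=0).astype("float16")
assert gate_up.shape == (32, hidden_size)
```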
