
Commit 7b86bf1

Add Qwen3 0.6B, 1.7B, and 4B (#10539)
Add ExecuTorch support for Qwen3 0.6B, 1.7B, and 4B.

### Qwen3 0.6B

Export with xnnpack + 8da4w quantization:
```
python -m examples.models.llama.export_llama --model qwen3-0_6b --params examples/models/qwen3/0_6b_config.json -kv --use_sdpa_with_kv_cache -X --xnnpack-extended-ops -d fp32 --output_name="qwen3-0_6b_x_8da4w.pte" --verbose -qmode 8da4w
```

Run with pybindings:
```
python -m examples.models.llama.runner.native --model qwen3-0_6b --pte qwen3-0_6b_x_8da4w.pte --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer_config.json --prompt "Who is the president of the US?" --params examples/models/qwen3/0_6b_config.json --max_len 128 -kv --temperature 0.6

>> Okay, let's see. The user is asking about the president of the US, but they wrote "And who is the president of the US?" and "And who is the president of the US?" So maybe they are using the same question but in a different way. They might be referring to the same president. Let me check.
...

# Some rough stats
Prefill time: 0.24 s
Generation tok/s: 17.15
Memory: 826.68 MB
```

### Qwen3 1.7B

Export with xnnpack + 8da4w quantization:
```
python -m examples.models.llama.export_llama --model qwen3-1_7b --params examples/models/qwen3/1_7b_config.json -kv --use_sdpa_with_kv_cache -X --xnnpack-extended-ops -d fp32 --output_name="qwen3-1_7b_x_8da4w.pte" --verbose -qmode 8da4w
```

Run with pybindings:
```
python -m examples.models.llama.runner.native --model qwen3-1_7b --pte qwen3-1_7b_x_8da4w.pte --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer_config.json --prompt "Who is the president of the US?" --params examples/models/qwen3/1_7b_config.json --max_len 128 -kv --temperature 0.6

# Some rough stats
Prefill time: 0.25 s
Generation tok/s: 16.87
Memory: 1.02 GB
```

### Qwen3 4B

Export with xnnpack + 8da4w quantization:
```
python -m examples.models.llama.export_llama --model qwen3-4b --params examples/models/qwen3/4b_config.json -kv --use_sdpa_with_kv_cache -X --xnnpack-extended-ops -d fp32 --output_name="qwen3-4b_x_8da4w.pte" --verbose -qmode 8da4w
```

Run with pybindings:
```
python -m examples.models.llama.runner.native --model qwen3-4b --pte qwen3-4b_x_8da4w.pte --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer_config.json --prompt "Who is the president of the US?" --params examples/models/qwen3/4b_config.json --max_len 128 -kv --temperature 0.6

# Some rough stats
Prefill time: 0.44 s
Generation tok/s: 12.12
Memory: 2.5 GB
```

bypass-github-export-checks
1 parent 32dffbc commit 7b86bf1

10 files changed: +285 −2 lines

README.md (+1 −1)

@@ -51,7 +51,7 @@ To get started you can:
 
 - Visit the [Step by Step Tutorial](https://pytorch.org/executorch/stable/getting-started.html) to get things running locally and deploy a model to a device
 - Use this [Colab Notebook](https://colab.research.google.com/drive/1qpxrXC3YdJQzly3mRg-4ayYiOjC6rue3?usp=sharing) to start playing around right away
-- Jump straight into LLM use cases by following specific instructions for [Llama](examples/models/llama/README.md) and [Llava](examples/models/llava/README.md)
+- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), and [Llava](examples/models/llava/README.md)
 
 ## Feedback and Engagement

examples/models/llama/attention.py (+6 −1)

@@ -178,6 +178,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.dim = args.dim
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
+        self.qk_norm_before_rope = args.qk_norm_before_rope
 
         if self.use_qk_norm:
             q_norm_dim = self.head_dim
@@ -243,14 +244,18 @@ def forward(
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
+        if self.use_qk_norm and self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
-        if self.use_qk_norm:
+        if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
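The change above makes the position of the query/key norm configurable: Qwen 3 normalizes q and k per head before the rotary embedding, whereas the previous behavior normalized them after. A minimal standalone sketch of that ordering is shown below; it is illustrative only (the real module lives in examples/models/llama/attention.py with its own norm and RoPE implementations), and it uses torch.nn.RMSNorm (PyTorch 2.4+) as a stand-in for the example's norm function.

```python
import torch.nn as nn


class QKNormOrderingSketch(nn.Module):
    """Toy module showing where the q/k RMSNorm sits relative to RoPE."""

    def __init__(self, head_dim: int, use_qk_norm: bool, qk_norm_before_rope: bool):
        super().__init__()
        self.use_qk_norm = use_qk_norm
        self.qk_norm_before_rope = qk_norm_before_rope
        # Per-head RMSNorm over the last (head_dim) dimension.
        self.q_norm_fn = nn.RMSNorm(head_dim, eps=1e-6)
        self.k_norm_fn = nn.RMSNorm(head_dim, eps=1e-6)

    def forward(self, q, k, apply_rope):
        # q, k: (bsz, seqlen, n_heads, head_dim); apply_rope is any callable
        # that applies rotary position embeddings to (q, k).
        if self.use_qk_norm and self.qk_norm_before_rope:
            q, k = self.q_norm_fn(q), self.k_norm_fn(k)  # Qwen 3 path: norm before RoPE
        q, k = apply_rope(q, k)
        if self.use_qk_norm and not self.qk_norm_before_rope:
            q, k = self.q_norm_fn(q), self.k_norm_fn(k)  # previous path: norm after RoPE
        return q, k
```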

examples/models/llama/export_llama_lib.py (+10)

@@ -100,6 +100,9 @@
     "llama3_2",
     "static_llama",
     "qwen2_5",
+    "qwen3-0_6b",
+    "qwen3-1_7b",
+    "qwen3-4b",
     "phi_4_mini",
     "smollm2",
 ]
@@ -108,6 +111,9 @@
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
     "smollm2": "HuggingFaceTB/SmolLM-135M",
+    "qwen3-0_6b": "Qwen/Qwen3-0.6B",
+    "qwen3-1_7b": "Qwen/Qwen3-1.7B",
+    "qwen3-4b": "Qwen/Qwen3-4B",
 }
 
 
@@ -544,6 +550,10 @@ def export_llama(args) -> str:
            from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                convert_weights,
            )
+        elif args.model.startswith("qwen3"):
+            from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
+                convert_weights,
+            )
         elif args.model == "phi_4_mini":
            from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                convert_weights,

examples/models/llama/model_args.py (+1)

@@ -38,6 +38,7 @@ class ModelArgs:
     apply_embedding: bool = True  # Use embedding inside the transformer
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
+    qk_norm_before_rope: bool = False  # when True, apply the qk norm before RoPE instead of after
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (

examples/models/qwen3/0_6b_config.json (new file, +17)

{
  "dim": 1024,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 3072,
  "n_heads": 16,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 28,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}

examples/models/qwen3/1_7b_config.json (new file, +17)

{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 6144,
  "n_heads": 16,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 28,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}

examples/models/qwen3/4b_config.json (new file, +17)

{
  "dim": 2560,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 9728,
  "n_heads": 32,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 36,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}
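For orientation, the params files above broadly mirror the fields of each model's Hugging Face config.json. The mapping below is an assumption added for illustration (it is not part of this commit); verify against the actual config.json of the model you are exporting.

```python
# Assumed correspondence between Hugging Face config.json fields (keys) and the
# params JSON fields consumed by export_llama (values).
HF_TO_EXPORT_PARAMS = {
    "hidden_size": "dim",
    "intermediate_size": "hidden_dim",
    "num_attention_heads": "n_heads",
    "head_dim": "head_dim",
    "num_key_value_heads": "n_kv_heads",
    "num_hidden_layers": "n_layers",
    "rms_norm_eps": "norm_eps",
    "rope_theta": "rope_theta",
    "vocab_size": "vocab_size",
}
```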

examples/models/qwen3/README.md (new file, +85)

## Summary
Qwen 3 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. The edge-sized Qwen 3 variants (0.6B, 1.7B, and 4B) are currently supported.

## Instructions

Qwen 3 uses the same example code as our optimized Llama model, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.

All commands for exporting and running Llama on various backends should also be applicable to Qwen 3 by swapping out the following args:
```
--model [qwen3-0_6b,qwen3-1_7b,qwen3-4b]
--params [examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json]
```

### Example export
Here is a basic example for exporting Qwen 3; please refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.

Export 0.6B to XNNPack, quantized with 8da4w:
```
python -m examples.models.llama.export_llama \
  --model qwen3-0_6b \
  --params examples/models/qwen3/0_6b_config.json \
  -kv \
  --use_sdpa_with_kv_cache \
  -d fp32 \
  -X \
  --xnnpack-extended-ops \
  -qmode 8da4w \
  --output_name="qwen3-0_6b.pte" \
  --verbose
```

Export 1.7B to XNNPack, quantized with 8da4w:
```
python -m examples.models.llama.export_llama \
  --model qwen3-1_7b \
  --params examples/models/qwen3/1_7b_config.json \
  -kv \
  --use_sdpa_with_kv_cache \
  -d fp32 \
  -X \
  --xnnpack-extended-ops \
  -qmode 8da4w \
  --output_name="qwen3-1_7b.pte" \
  --verbose
```

Export 4B to XNNPack, quantized with 8da4w:
```
python -m examples.models.llama.export_llama \
  --model qwen3-4b \
  --params examples/models/qwen3/4b_config.json \
  -kv \
  --use_sdpa_with_kv_cache \
  -d fp32 \
  -X \
  --xnnpack-extended-ops \
  -qmode 8da4w \
  --output_name="qwen3-4b.pte" \
  --verbose
```

### Example run
With ExecuTorch pybindings:
```
python -m examples.models.llama.runner.native \
  --model qwen3-0_6b \
  --pte qwen3-0_6b.pte \
  --tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
  --tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer_config.json \
  --prompt "Who is the president of the US?" \
  --params examples/models/qwen3/0_6b_config.json \
  --max_len 128 \
  -kv \
  --temperature 0.6
```

With ExecuTorch's sample C++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
```
cmake-out/examples/models/llama/llama_main \
  --model_path qwen3-0_6b.pte \
  --tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
  --prompt="Who is the president of the US?"
```

To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.
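The run commands above point at tokenizer files inside the Hugging Face cache, using one specific snapshot hash. One way (an assumption added here, not part of this commit) to materialize those files locally and obtain the path for your own machine is huggingface_hub's snapshot_download:

```python
# Download only the tokenizer files for Qwen3-0.6B and print the local snapshot
# directory; pass <path>/tokenizer.json and <path>/tokenizer_config.json to the runner.
from huggingface_hub import snapshot_download

path = snapshot_download(
    "Qwen/Qwen3-0.6B",
    allow_patterns=["tokenizer.json", "tokenizer_config.json"],
)
print(path)
```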

examples/models/qwen3/__init__.py (new file, +16)

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.qwen3.convert_weights import convert_weights


class Qwen3Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "Qwen3Model",
    "convert_weights",
]

examples/models/qwen3/convert_weights.py (new file, +115)

import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_QWEN_3_FROM_META = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_norm.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    # Note: gate_proj and up_proj are reversed, usually w1 is the up_proj,
    # w2 is the gate_proj, and activation is applied on the up_proj, but since
    # Qwen3 applies activation on the gate_proj, we just swap the gate_proj
    # and up_proj in the checkpoint itself as a hack.
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
}


def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

    for key, value in state_dict.items():
        # Tied embeddings for 0.6b and 4b models.
        if key == "lm_head.weight":
            continue
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict


def load_checkpoint(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
        return state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = qwen_3_tune_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert Qwen3 weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()
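As a usage sketch (an assumed workflow, not part of the diff): given a locally downloaded Qwen3 checkpoint directory, the converter can be called directly from Python, and the resulting Meta-format .pth file is what export_llama then consumes as the model checkpoint. The paths below are placeholders.

```python
# Hypothetical example: convert a downloaded Qwen3-0.6B checkpoint directory
# (containing model.safetensors or a sharded index) into a Meta-format .pth file.
import os

from executorch.examples.models.qwen3 import convert_weights

input_dir = os.path.expanduser(
    "~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/<snapshot>"  # placeholder path
)
convert_weights(input_dir, "qwen3_0_6b_meta.pth")
```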
