Skip to content

Commit 803ff1d

Browse files
committed
Add Qwen3 0.6B
1 parent 2837867 commit 803ff1d

File tree

5 files changed

+133
-4
lines changed

5 files changed

+133
-4
lines changed

examples/models/llama/attention.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,17 +243,17 @@ def forward(
243243
k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
244244
v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
245245

246+
if self.use_qk_norm:
247+
q = self.q_norm_fn(q)
248+
k = self.k_norm_fn(k)
249+
246250
# RoPE relative positional embeddings
247251
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
248252

249253
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
250254
k = k.transpose(1, 2)
251255
v = v.transpose(1, 2)
252256

253-
if self.use_qk_norm:
254-
q = self.q_norm_fn(q)
255-
k = self.k_norm_fn(k)
256-
257257
if self.use_kv_cache:
258258
assert input_pos is not None
259259
k, v = self.kv_cache.update(input_pos, k, v)

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
"llama3_2",
101101
"static_llama",
102102
"qwen2_5",
103+
"qwen3",
103104
"phi_4_mini",
104105
"smollm2",
105106
]
@@ -108,6 +109,7 @@
108109
"qwen2_5": "Qwen/Qwen2.5-1.5B",
109110
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
110111
"smollm2": "HuggingFaceTB/SmolLM-135M",
112+
"qwen3": "Qwen/Qwen3-0.6B",
111113
}
112114

113115

@@ -544,6 +546,10 @@ def export_llama(args) -> str:
544546
from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21]
545547
convert_weights,
546548
)
549+
elif args.model == "qwen3":
550+
from executorch.examples.models.qwen3 import ( # pyre-ignore[21]
551+
convert_weights,
552+
)
547553
elif args.model == "phi_4_mini":
548554
from executorch.examples.models.phi_4_mini import ( # pyre-ignore[21]
549555
convert_weights,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"dim": 1024,
3+
"ffn_dim_multiplier": 1,
4+
"hidden_dim": 3072,
5+
"n_heads": 16,
6+
"head_dim": 128,
7+
"n_kv_heads": 8,
8+
"n_layers": 28,
9+
"norm_eps": 1e-06,
10+
"rope_theta": 1000000.0,
11+
"use_scaled_rope": false,
12+
"vocab_size": 151936,
13+
"use_hf_rope": true,
14+
"attention_qkv_bias": false,
15+
"use_qk_norm": true
16+
}

examples/models/qwen3/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# This source code is licensed under the BSD-style license found in the
2+
# LICENSE file in the root directory of this source tree.
3+
4+
from executorch.examples.models.llama.model import Llama2Model
5+
from executorch.examples.models.qwen3.convert_weights import convert_weights
6+
7+
8+
class Qwen3Model(Llama2Model):
9+
def __init__(self, **kwargs):
10+
super().__init__(**kwargs)
11+
12+
13+
__all__ = [
14+
"Qwen3Model",
15+
"convert_weights",
16+
]
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import argparse
2+
from typing import Dict
3+
4+
import os
5+
from safetensors import safe_open
6+
import torch
7+
8+
from torchtune.models.convert_weights import get_mapped_key
9+
10+
# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
11+
_QWEN_3_FROM_META = {
12+
"tok_embeddings.weight": "model.embed_tokens.weight",
13+
"norm.weight": "model.norm.weight",
14+
"layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
15+
"layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
16+
"layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
17+
"layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_norm.weight",
18+
"layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
19+
"layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
20+
"layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
21+
"layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
22+
# Note: gate_proj and up_proj are reversed, usually w1 is the up_proj,
23+
# w2 is the gate_proj, and activation is applied on the up_proj, but since
24+
# Qwen3 applies activation on the gate_proj, we just swap the gate_proj
25+
# and up_proj in the checkpoint itself as a hack.
26+
"layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
27+
"layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
28+
"layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
29+
}
30+
31+
32+
def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
33+
"""
34+
Convert a state dict from torchtune's format to Meta's format. This function
35+
doesn't handle any sharding or splitting of state dicts. It follows the
36+
state_dict IN -> state_dict OUT pattern.
37+
38+
Args:
39+
state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
40+
41+
Returns:
42+
Dict[str, torch.Tensor]: State dict in Meta's format.
43+
"""
44+
converted_state_dict = {}
45+
inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}
46+
47+
for key, value in state_dict.items():
48+
# Tied embeddings for 0.6b and 4b models.
49+
if key == "lm_head.weight":
50+
continue
51+
new_key = get_mapped_key(key, inverted_mapping_dict)
52+
converted_state_dict[new_key] = value
53+
54+
converted_state_dict["output.weight"] = converted_state_dict[
55+
"tok_embeddings.weight"
56+
]
57+
58+
return converted_state_dict
59+
60+
61+
def convert_weights(input_dir: str, output_file: str) -> None:
62+
print("Loading checkpoint...")
63+
sd = {}
64+
with safe_open(os.path.join(input_dir, "model.safetensors"), framework="pt", device="cpu") as f:
65+
for key in f.keys():
66+
sd[key] = f.get_tensor(key)
67+
68+
print("Converting checkpoint...")
69+
sd = qwen_3_tune_to_meta(sd)
70+
print("Saving checkpoint...")
71+
torch.save(sd, output_file)
72+
print("Done.")
73+
74+
75+
def main():
76+
parser = argparse.ArgumentParser(
77+
description="Convert Qwen3 weights to Meta format."
78+
)
79+
parser.add_argument(
80+
"input_dir",
81+
type=str,
82+
help="Path to directory containing checkpoint files",
83+
)
84+
parser.add_argument("output", type=str, help="Path to the output checkpoint")
85+
86+
args = parser.parse_args()
87+
convert_weights(args.input_dir, args.output)
88+
89+
90+
if __name__ == "__main__":
91+
main()

0 commit comments

Comments
 (0)