Commit 507dbc0

1.7B and 4B
1 parent 803ff1d commit 507dbc0

File tree

examples/models/llama/export_llama_lib.py
examples/models/qwen3/convert_weights.py

2 files changed: +38 -10 lines

examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 3 deletions
@@ -100,7 +100,9 @@
     "llama3_2",
     "static_llama",
     "qwen2_5",
-    "qwen3",
+    "qwen3-0_6b",
+    "qwen3-1_7b",
+    "qwen3-4b",
     "phi_4_mini",
     "smollm2",
 ]
@@ -109,7 +111,9 @@
     "qwen2_5": "Qwen/Qwen2.5-1.5B",
     "phi_4_mini": "microsoft/Phi-4-mini-instruct",
     "smollm2": "HuggingFaceTB/SmolLM-135M",
-    "qwen3": "Qwen/Qwen3-0.6B",
+    "qwen3-0_6b": "Qwen/Qwen3-0.6B",
+    "qwen3-1_7b": "Qwen/Qwen3-1.7B",
+    "qwen3-4b": "Qwen/Qwen3-4B",
 }
 
 
@@ -546,7 +550,7 @@ def export_llama(args) -> str:
         from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
             convert_weights,
         )
-    elif args.model == "qwen3":
+    elif args.model.startswith("qwen3"):
        from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
             convert_weights,
         )
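
Splitting the single "qwen3" registry entry into three size-specific keys also changes the converter dispatch from an exact match to a prefix match: every qwen3-* variant shares the same executorch.examples.models.qwen3 conversion code and differs only in which Hugging Face repo the weights come from. A minimal sketch of that pattern, assuming a standalone dict and a hypothetical resolve_repo helper (neither is code from this commit):

# Illustrative only: a stand-in for the repo-id mapping edited above.
HUGGINGFACE_REPO_IDS = {
    "qwen3-0_6b": "Qwen/Qwen3-0.6B",
    "qwen3-1_7b": "Qwen/Qwen3-1.7B",
    "qwen3-4b": "Qwen/Qwen3-4B",
}


def resolve_repo(model: str) -> str:
    # Every qwen3-* variant routes to the same converter module;
    # only the checkpoint repo differs, so a prefix check is safe here.
    if model.startswith("qwen3"):
        return HUGGINGFACE_REPO_IDS[model]
    raise ValueError(f"unsupported model: {model}")


assert resolve_repo("qwen3-1_7b") == "Qwen/Qwen3-1.7B"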

examples/models/qwen3/convert_weights.py

Lines changed: 31 additions & 7 deletions
@@ -1,9 +1,11 @@
 import argparse
-from typing import Dict
 
+import json
 import os
-from safetensors import safe_open
+from typing import Dict
+
 import torch
+from safetensors.torch import load_file
 
 from torchtune.models.convert_weights import get_mapped_key
 
@@ -58,13 +60,35 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     return converted_state_dict
 
 
+def load_checkpoint(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
 def convert_weights(input_dir: str, output_file: str) -> None:
     print("Loading checkpoint...")
-    sd = {}
-    with safe_open(os.path.join(input_dir, "model.safetensors"), framework="pt", device="cpu") as f:
-        for key in f.keys():
-            sd[key] = f.get_tensor(key)
-
+    sd = load_checkpoint(input_dir)
     print("Converting checkpoint...")
     sd = qwen_3_tune_to_meta(sd)
     print("Saving checkpoint...")

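The new load_checkpoint follows the Hugging Face sharded-safetensors layout: when model.safetensors.index.json is present (as it is for the larger 1.7B and 4B downloads), its weight_map records which shard file owns each tensor; every shard is loaded once and the tensors are merged into a single state dict, while single-file checkpoints fall back to a plain load_file. Below is a self-contained sketch of that merge; the shard file names and tensors are made up for illustration.

# Build a fake two-shard checkpoint, then merge it the same way the
# load_checkpoint() added above does. Shard names/tensors are hypothetical.
import json
import os
import tempfile

import torch
from safetensors.torch import load_file, save_file

with tempfile.TemporaryDirectory() as input_dir:
    # Two shards, each owning part of the model.
    save_file(
        {"tok_embeddings.weight": torch.zeros(8, 4)},
        os.path.join(input_dir, "model-00001-of-00002.safetensors"),
    )
    save_file(
        {"output.weight": torch.ones(4, 8)},
        os.path.join(input_dir, "model-00002-of-00002.safetensors"),
    )

    # The index file maps each weight name to the shard that stores it.
    weight_map = {
        "tok_embeddings.weight": "model-00001-of-00002.safetensors",
        "output.weight": "model-00002-of-00002.safetensors",
    }
    with open(os.path.join(input_dir, "model.safetensors.index.json"), "w") as f:
        json.dump({"weight_map": weight_map}, f)

    # Same merge as load_checkpoint(): load every shard once, then pull
    # each tensor out of its owning shard into one consolidated dict.
    shard_to_weights = {
        shard: load_file(os.path.join(input_dir, shard))
        for shard in sorted(set(weight_map.values()))
    }
    merged = {name: shard_to_weights[shard][name] for name, shard in weight_map.items()}
    print(sorted(merged))  # ['output.weight', 'tok_embeddings.weight']

Loading every shard into memory up front trades peak memory for simplicity, which is a reasonable fit for an offline, one-shot conversion step even at the 4B scale this commit adds.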