|
1 | 1 | import argparse
|
2 |
| -from typing import Dict |
3 | 2 |
|
| 3 | +import json |
4 | 4 | import os
|
5 |
| -from safetensors import safe_open |
| 5 | +from typing import Dict |
| 6 | + |
6 | 7 | import torch
|
| 8 | +from safetensors.torch import load_file |
7 | 9 |
|
8 | 10 | from torchtune.models.convert_weights import get_mapped_key
|
9 | 11 |
|
@@ -58,13 +60,35 @@ def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
|
58 | 60 | return converted_state_dict
|
59 | 61 |
|
60 | 62 |
|
def load_checkpoint(input_dir: str) -> Dict:
    """Load a safetensors checkpoint from ``input_dir`` into one state dict.

    If ``model.safetensors.index.json`` exists in ``input_dir``, the checkpoint
    is treated as sharded: every shard file named in the index's
    ``weight_map`` is loaded, and each weight is taken from the shard the index
    assigns it to. Otherwise ``model.safetensors`` is loaded directly.

    Args:
        input_dir: Directory containing the checkpoint file(s).

    Returns:
        Mapping from weight name to the value produced by
        ``safetensors.torch.load_file`` for that weight.
    """
    index_file = os.path.join(input_dir, "model.safetensors.index.json")

    if not os.path.exists(index_file):
        # Single checkpoint: no index file, so everything lives in one file.
        return load_file(os.path.join(input_dir, "model.safetensors"))

    # Sharded checkpoint: the index maps weight name -> shard file name.
    with open(index_file, "r") as fh:
        name_to_shard = json.load(fh)["weight_map"]

    # Load each distinct shard exactly once, in sorted file-name order.
    # All shards are held in memory at the same time.
    loaded_shards = {
        shard_name: load_file(os.path.join(input_dir, shard_name))
        for shard_name in sorted(set(name_to_shard.values()))
    }

    # Consolidate: pull every weight out of the shard the index points to.
    return {
        weight: loaded_shards[shard_name][weight]
        for weight, shard_name in name_to_shard.items()
    }
| 87 | + |
| 88 | + |
61 | 89 | def convert_weights(input_dir: str, output_file: str) -> None:
|
62 | 90 | print("Loading checkpoint...")
|
63 |
| - sd = {} |
64 |
| - with safe_open(os.path.join(input_dir, "model.safetensors"), framework="pt", device="cpu") as f: |
65 |
| - for key in f.keys(): |
66 |
| - sd[key] = f.get_tensor(key) |
67 |
| - |
| 91 | + sd = load_checkpoint(input_dir) |
68 | 92 | print("Converting checkpoint...")
|
69 | 93 | sd = qwen_3_tune_to_meta(sd)
|
70 | 94 | print("Saving checkpoint...")
|
|
0 commit comments