
Commit 4aed2af

Add convert model script
1 parent 64bbf3b commit 4aed2af

1 file changed, +219 -0 lines changed


convert_model.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
#!/usr/bin/env python3

import json
import struct
import torch
import argparse
from safetensors import safe_open

# MXFP4 conversion constants
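# 16-entry lookup table for the 4-bit (E2M1) codes: indices 0-7 are the positive values, 8-15 the negatives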
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]

def convert_mxfp4_to_bf16(blocks_tensor, scales_tensor):
    """Convert MXFP4 format to BF16"""
    blocks = blocks_tensor
    scales = scales_tensor.to(torch.int32) - 127
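    # per-block scales are shared power-of-two exponents stored with a +127 bias; ldexp below applies them as 2**exp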

    lut = torch.tensor(FP4_VALUES, dtype=torch.float32, device=blocks.device)

    # Split into low and high nibbles
    idx_lo = (blocks_tensor & 0x0F).to(torch.long)  # [..., G, B]
    idx_hi = (blocks_tensor >> 4).to(torch.long)    # [..., G, B]

    # Convert to FP32 values
    vals_lo = lut[idx_lo]  # [..., G, B]
    vals_hi = lut[idx_hi]  # [..., G, B]

    # Interleave along last dimension → [..., G, 2B]
    interleaved = torch.stack((vals_lo, vals_hi), dim=-1).reshape(*vals_lo.shape[:-1], -1)

    # Apply exponent scaling: x * 2**exp
    result = torch.ldexp(interleaved, scales.unsqueeze(-1))

    return result.to(torch.bfloat16)

def get_target_dtype(weight_name):
    """Determine target dtype based on weight name"""
    if weight_name in ["embedding.weight", "unembedding.weight"]:
        return "BF16"
    elif "norm.scale" in weight_name and not weight_name.startswith("block."):
        return "BF16"  # final layer norm
    elif any(x in weight_name for x in ["mlp1_weight.blocks", "mlp2_weight.blocks", "mlp1_weight.scales", "mlp2_weight.scales"]):
        return "BF16"
    elif any(x in weight_name for x in ["attn.norm.scale", "attn.qkv", "attn.sinks", "attn.out", "mlp.norm.scale", "mlp.gate", "mlp1_bias", "mlp2_bias"]):
        return "BF16"
    else:
        print(f"Unrecognized weight name {weight_name}")
        return "BF16"  # default

def read_model_header(model_path):
    """Read header from footer of model.bin file"""
    with open(model_path, 'rb') as f:
        # Read header size from last 8 bytes
        f.seek(-8, 2)
        header_size = struct.unpack('<Q', f.read(8))[0]

        # Read header JSON
        f.seek(-(8 + header_size), 2)
        header_json = f.read(header_size).decode('utf-8')
        return json.loads(header_json)

def read_tensor_by_name(model_path, tensor_name, header=None):
    """Read specific tensor from model.bin file"""
    if header is None:
        header = read_model_header(model_path)

    tensor_info = header[tensor_name]
    start_offset, end_offset = tensor_info["data_offsets"]
    shape = tensor_info["shape"]
    dtype = tensor_info["dtype"]

    with open(model_path, 'rb') as f:
        f.seek(start_offset)
        tensor_bytes = f.read(end_offset - start_offset)

    # Convert bytes to tensor
    if dtype == "FP32":
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.float32).view(shape)
    elif dtype == "BF16":
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.bfloat16).view(shape)
    else:
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.float32).view(shape)

    return tensor

def get_dtype_size_multiplier(original_dtype, target_dtype):
    """Get expected size multiplier for dtype conversion"""
    dtype_sizes = {"FP32": 4, "BF16": 2, "FP16": 2, "FP4": 0.5}

    orig_size = dtype_sizes.get(original_dtype, 4)
    target_size = dtype_sizes.get(target_dtype, 4)

    return target_size / orig_size

def validate_data_offset_size(original_size, output_size, original_dtype, target_dtype, tensor_name):
    """Validate output data offset size matches expected conversion ratio"""
    expected_multiplier = get_dtype_size_multiplier(original_dtype, target_dtype)
    expected_size = int(original_size * expected_multiplier)

    if output_size != expected_size:
        raise ValueError(f"Data offset size mismatch for {tensor_name}: "
                         f"expected {expected_size} bytes (original {original_size} * {expected_multiplier}), "
                         f"got {output_size} bytes")

def convert_safetensors_to_modelbin(input_path, output_path):
    """Convert safetensors to custom model.bin format with streaming writes"""

    header = {}
    current_offset = 0

    with open(output_path, 'wb') as out_f:
        with safe_open(input_path, framework="pt", device="cpu") as f:
            tensor_names = list(f.keys())

            # Sort tensor names: embedding, unembedding, norm.scale, then blocks in ascending order
            def sort_key(name):
                if name == "embedding.weight":
                    return (0, 0, name)
                elif name == "unembedding.weight":
                    return (1, 0, name)
                elif name == "norm.scale":
                    return (2, 0, name)
                elif name.startswith("block."):
                    # Extract block number for proper numeric sorting
                    block_num = int(name.split('.')[1])
                    return (3, block_num, name)
                else:
                    return (4, 0, name)

            tensor_names.sort(key=sort_key)

            # Process each tensor and write immediately
            for name in tensor_names:
                target_dtype = get_target_dtype(name)
                original_tensor = f.get_tensor(name)
                original_dtype = str(original_tensor.dtype).upper().replace("TORCH.", "")
                if original_dtype == "BFLOAT16":
                    original_dtype = "BF16"
                elif original_dtype == "FLOAT32":
                    original_dtype = "FP32"

                # Calculate original size
                original_size = original_tensor.numel() * original_tensor.element_size()

                # Handle MXFP4 weights
                if name.endswith(".mlp1_weight.blocks") or name.endswith(".mlp2_weight.blocks"):
                    base_name = name.replace(".blocks", "")
                    scales_name = name.replace(".blocks", ".scales")

                    if scales_name in tensor_names:
                        blocks = f.get_tensor(name)
                        scales = f.get_tensor(scales_name)
                        tensor = convert_mxfp4_to_bf16(blocks, scales)
                        dtype_str = "BF16"
                        original_dtype = "FP4"  # Override for FP4 blocks
                    else:
                        continue
                elif name.endswith(".scales"):
                    continue  # Skip scales, handled with blocks
                else:
                    tensor = original_tensor

                # Convert to target dtype
                if target_dtype == "FP32":
                    tensor = tensor.float()
                    dtype_str = "FP32"
                elif target_dtype == "BF16":
                    tensor = tensor.to(torch.bfloat16)
                    dtype_str = "BF16"
                else:
                    dtype_str = original_dtype  # Keep original

                # Write tensor data immediately
                if tensor.dtype == torch.bfloat16:
                    tensor_bytes = tensor.view(torch.uint16).numpy().tobytes()
                else:
                    tensor_bytes = tensor.numpy().tobytes()

                out_f.write(tensor_bytes)

                # Validate data offset size
                output_size = len(tensor_bytes)
                validate_data_offset_size(original_size, output_size, original_dtype, dtype_str, name)

                # Update header with offset info
                header[name] = {
                    "dtype": dtype_str,
                    "shape": list(tensor.shape),
                    "data_offsets": [current_offset, current_offset + len(tensor_bytes)]
                }
                current_offset += len(tensor_bytes)
                print(f"Wrote tensor: {name} {header[name]}")

                # Free memory
                del tensor
                if 'blocks' in locals():
                    del blocks
                if 'scales' in locals():
                    del scales

        # Write footer with header info
        header_json = json.dumps(header, separators=(',', ':')).encode('utf-8')
        header_size = len(header_json)
        out_f.write(header_json)
        out_f.write(struct.pack('<Q', header_size))  # Footer: header size at end
        print("Wrote header as footer")

    print(f"Converted {len(header)} tensors from {input_path} to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert GPT-OSS safetensors to model.bin format")
    parser.add_argument("input_path", help="Path to input safetensors file")
    parser.add_argument("output_path", help="Path to output model.bin file")

    args = parser.parse_args()

    convert_safetensors_to_modelbin(args.input_path, args.output_path)
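
Usage note (not part of the commit): the script is invoked as python convert_model.py input.safetensors model.bin. A minimal round-trip sanity check might look like the sketch below; it assumes the commit's convert_model.py is importable, and the file names and the tensor name "embedding.weight" are hypothetical examples, not values taken from this commit.

# Sketch: read one tensor back from the produced model.bin and compare it to the safetensors source.
import torch
from safetensors import safe_open
from convert_model import read_model_header, read_tensor_by_name

header = read_model_header("model.bin")                     # parse the JSON footer once
name = "embedding.weight"                                   # example tensor name
converted = read_tensor_by_name("model.bin", name, header=header)

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    original = f.get_tensor(name)

print(converted.shape, converted.dtype)                     # expect the original shape, torch.bfloat16
print(torch.allclose(converted.float(), original.float()))  # True if the round trip preserved the values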
