#!/usr/bin/env python3

import argparse
import json
import struct

import torch
from safetensors import safe_open
# MXFP4 conversion constants
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]
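# Entries 0-7 are the positive magnitudes and entries 8-15 mirror them with
# the sign flipped, i.e. the high bit of each 4-bit code is the sign bit.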

def convert_mxfp4_to_bf16(blocks_tensor, scales_tensor):
    """Dequantize MXFP4 (packed FP4 blocks plus per-block exponents) to BF16"""
    # Scales are stored as biased uint8 exponents; remove the bias of 127
    scales = scales_tensor.to(torch.int32) - 127

    lut = torch.tensor(FP4_VALUES, dtype=torch.float32, device=blocks_tensor.device)

    # Each byte packs two 4-bit codes: split into low and high nibbles
    idx_lo = (blocks_tensor & 0x0F).to(torch.long)  # [..., G, B]
    idx_hi = (blocks_tensor >> 4).to(torch.long)    # [..., G, B]

    # Look up the FP32 value for each code
    vals_lo = lut[idx_lo]  # [..., G, B]
    vals_hi = lut[idx_hi]  # [..., G, B]

    # Interleave along the last dimension → [..., G, 2B]
    interleaved = torch.stack((vals_lo, vals_hi), dim=-1).reshape(*vals_lo.shape[:-1], -1)

    # Apply the per-block exponent scaling: x * 2**exp
    result = torch.ldexp(interleaved, scales.unsqueeze(-1))

    return result.to(torch.bfloat16)
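
# A quick sanity check of the decoding above, using made-up values: the packed
# byte 0x21 holds low nibble 0x1 (+0.5) and high nibble 0x2 (+1.0), and a
# scale byte of 128 encodes exponent 128 - 127 = 1, i.e. a factor of 2.
#
#   blocks = torch.tensor([[0x21]], dtype=torch.uint8)
#   scales = torch.tensor([128], dtype=torch.uint8)
#   convert_mxfp4_to_bf16(blocks, scales)  # -> [[1.0, 2.0]] as bfloat16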

def get_target_dtype(weight_name):
    """Determine the target dtype based on the weight name"""
    if weight_name in ["embedding.weight", "unembedding.weight"]:
        return "BF16"
    elif "norm.scale" in weight_name and not weight_name.startswith("block."):
        return "BF16"  # final layer norm
    elif any(x in weight_name for x in ["mlp1_weight.blocks", "mlp2_weight.blocks", "mlp1_weight.scales", "mlp2_weight.scales"]):
        return "BF16"
    elif any(x in weight_name for x in ["attn.norm.scale", "attn.qkv", "attn.sinks", "attn.out", "mlp.norm.scale", "mlp.gate", "mlp1_bias", "mlp2_bias"]):
        return "BF16"
    else:
        print(f"Unrecognized weight name: {weight_name}")
        return "BF16"  # default

def read_model_header(model_path):
    """Read the header JSON stored as a footer at the end of a model.bin file"""
    with open(model_path, 'rb') as f:
        # Read the header size from the last 8 bytes
        f.seek(-8, 2)
        header_size = struct.unpack('<Q', f.read(8))[0]

        # Read the header JSON that precedes the size field
        f.seek(-(8 + header_size), 2)
        header_json = f.read(header_size).decode('utf-8')
        return json.loads(header_json)
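
# On-disk layout implied by the reader above (a sketch of this script's own
# format, not an external spec):
#
#   [tensor 0 bytes][tensor 1 bytes]...[header JSON][header size, uint64 LE]
#
# i.e. a safetensors-like index, stored as a footer so tensor data can be
# streamed out before the offsets are known.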

def read_tensor_by_name(model_path, tensor_name, header=None):
    """Read a specific tensor from a model.bin file"""
    if header is None:
        header = read_model_header(model_path)

    tensor_info = header[tensor_name]
    start_offset, end_offset = tensor_info["data_offsets"]
    shape = tensor_info["shape"]
    dtype = tensor_info["dtype"]

    with open(model_path, 'rb') as f:
        f.seek(start_offset)
        # Copy into a bytearray so torch.frombuffer gets a writable buffer
        tensor_bytes = bytearray(f.read(end_offset - start_offset))

    # Convert the bytes to a tensor
    if dtype == "FP32":
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.float32).view(shape)
    elif dtype == "BF16":
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.bfloat16).view(shape)
    else:
        # Unknown dtype tag: fall back to interpreting the bytes as FP32
        tensor = torch.frombuffer(tensor_bytes, dtype=torch.float32).view(shape)

    return tensor
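
# Example round-trip usage (hypothetical path; tensor name as written by the
# converter below):
#
#   header = read_model_header("model.bin")
#   emb = read_tensor_by_name("model.bin", "embedding.weight", header)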

def get_dtype_size_multiplier(original_dtype, target_dtype):
    """Get the expected size multiplier for a dtype conversion"""
    dtype_sizes = {"FP32": 4, "BF16": 2, "FP16": 2, "FP4": 0.5}  # bytes per element

    orig_size = dtype_sizes.get(original_dtype, 4)
    target_size = dtype_sizes.get(target_dtype, 4)

    return target_size / orig_size

def validate_data_offset_size(original_size, output_size, original_dtype, target_dtype, tensor_name):
    """Validate that the written size matches the expected conversion ratio"""
    expected_multiplier = get_dtype_size_multiplier(original_dtype, target_dtype)
    expected_size = int(original_size * expected_multiplier)

    if output_size != expected_size:
        raise ValueError(f"Data offset size mismatch for {tensor_name}: "
                         f"expected {expected_size} bytes (original {original_size} * {expected_multiplier}), "
                         f"got {output_size} bytes")

def convert_safetensors_to_modelbin(input_path, output_path):
    """Convert safetensors to the custom model.bin format with streaming writes"""

    header = {}
    current_offset = 0

    with open(output_path, 'wb') as out_f:
        with safe_open(input_path, framework="pt", device="cpu") as f:
            tensor_names = list(f.keys())

            # Sort tensor names: embedding, unembedding, norm.scale, then blocks in ascending order
            def sort_key(name):
                if name == "embedding.weight":
                    return (0, 0, name)
                elif name == "unembedding.weight":
                    return (1, 0, name)
                elif name == "norm.scale":
                    return (2, 0, name)
                elif name.startswith("block."):
                    # Extract the block number for proper numeric sorting
                    block_num = int(name.split('.')[1])
                    return (3, block_num, name)
                else:
                    return (4, 0, name)

            tensor_names.sort(key=sort_key)
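            # Resulting order (illustrative): embedding.weight,
            # unembedding.weight, norm.scale, block.0.*, block.1.*, ...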

            # Process each tensor and write it immediately
            for name in tensor_names:
                target_dtype = get_target_dtype(name)
                original_tensor = f.get_tensor(name)
                original_dtype = str(original_tensor.dtype).upper().replace("TORCH.", "")
                if original_dtype == "BFLOAT16":
                    original_dtype = "BF16"
                elif original_dtype == "FLOAT32":
                    original_dtype = "FP32"

                # Calculate the original size in bytes
                original_size = original_tensor.numel() * original_tensor.element_size()

                # Handle MXFP4 weights: dequantize blocks + scales to BF16
                if name.endswith(".mlp1_weight.blocks") or name.endswith(".mlp2_weight.blocks"):
                    scales_name = name.replace(".blocks", ".scales")

                    if scales_name in tensor_names:
                        blocks = f.get_tensor(name)
                        scales = f.get_tensor(scales_name)
                        tensor = convert_mxfp4_to_bf16(blocks, scales)
                        dtype_str = "BF16"
                        original_dtype = "FP4"  # Override: the blocks are packed FP4
                    else:
                        continue
                elif name.endswith(".scales"):
                    continue  # Skip scales; they are handled together with their blocks
                else:
                    tensor = original_tensor

                # Convert to the target dtype
                if target_dtype == "FP32":
                    tensor = tensor.float()
                    dtype_str = "FP32"
                elif target_dtype == "BF16":
                    tensor = tensor.to(torch.bfloat16)
                    dtype_str = "BF16"
                else:
                    dtype_str = original_dtype  # Keep the original dtype

                # Write the tensor data immediately; numpy has no bfloat16,
                # so reinterpret the raw BF16 bits as uint16 first
                if tensor.dtype == torch.bfloat16:
                    tensor_bytes = tensor.view(torch.uint16).numpy().tobytes()
                else:
                    tensor_bytes = tensor.numpy().tobytes()

                out_f.write(tensor_bytes)

                # Validate the data offset size
                output_size = len(tensor_bytes)
                validate_data_offset_size(original_size, output_size, original_dtype, dtype_str, name)

                # Update the header with offset info
                header[name] = {
                    "dtype": dtype_str,
                    "shape": list(tensor.shape),
                    "data_offsets": [current_offset, current_offset + len(tensor_bytes)]
                }
                current_offset += len(tensor_bytes)
                print(f"Wrote tensor: {name} {header[name]}")

                # Free memory before the next tensor
                del tensor
                if 'blocks' in locals():
                    del blocks
                if 'scales' in locals():
                    del scales

        # Write the header JSON and its size as a footer
        header_json = json.dumps(header, separators=(',', ':')).encode('utf-8')
        header_size = len(header_json)
        out_f.write(header_json)
        out_f.write(struct.pack('<Q', header_size))  # Footer: header size at the end
        print("Wrote header as footer")

    print(f"Converted {len(header)} tensors from {input_path} to {output_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert GPT-OSS safetensors to model.bin format")
    parser.add_argument("input_path", help="Path to input safetensors file")
    parser.add_argument("output_path", help="Path to output model.bin file")

    args = parser.parse_args()

    convert_safetensors_to_modelbin(args.input_path, args.output_path)
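
# Example invocation (hypothetical script and file names):
#   python convert.py gpt-oss.safetensors model.bin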