1
+ import argparse
2
+ import os
3
+ import json
4
+ import torch
5
+ import safetensors .torch
6
+
7
+ def per_tensor_quantize (tensor ):
8
+ """Quantize a tensor to FP8 using per-tensor static scaling factor."""
9
+ finfo = torch .finfo (torch .float8_e4m3fn )
10
+ if tensor .numel () == 0 :
11
+ min_val , max_val = torch .tensor (- 16.0 , dtype = tensor .dtype ), torch .tensor (16.0 , dtype = tensor .dtype )
12
+ else :
13
+ min_val , max_val = tensor .aminmax ()
14
+ amax = torch .maximum (min_val .abs (), max_val .abs ())
15
+ scale = finfo .max / amax .clamp (min = 1e-12 )
16
+ qweight = (tensor * scale ).clamp (min = finfo .min , max = finfo .max ).to (torch .float8_e4m3fn )
17
+ scale = scale .float ().reciprocal ()
18
+ return qweight , scale
19
+
20
def is_quantizable(name):
    """Return True when ``name`` is a per-layer weight this script quantizes.

    A name qualifies when it starts with ``layers.`` and ends with one of the
    ``wk``/``wo``/``wq``/``wv``/``w1``/``w2``/``w3`` ``.weight`` suffixes.
    """
    if not name.startswith('layers.'):
        return False

    quantizable_suffixes = (
        '.wk.weight',
        '.wo.weight',
        '.wq.weight',
        '.wv.weight',
        '.w1.weight',
        '.w2.weight',
        '.w3.weight',
    )
    return name.endswith(quantizable_suffixes)
23
+
24
def process_safetensors_file(file_path):
    """Quantize the eligible weights of one safetensors file to FP8, in place.

    Every tensor whose name passes ``is_quantizable`` is replaced by its FP8
    version, and its dequantization scale is stored next to it under the same
    name with the trailing ``weight`` replaced by ``qscale_weight``. All other
    tensors are written back unchanged.

    Args:
        file_path: Path of the safetensors file; it is overwritten.
    """
    print(f"Processing {file_path}")
    tensors = safetensors.torch.load_file(file_path)

    modified_tensors = {}
    for name, tensor in tensors.items():
        # A weight that is already FP8 comes from an earlier (possibly
        # interrupted) run of this in-place script. Quantizing it again would
        # clobber its stored scale, so it is passed through untouched.
        if is_quantizable(name) and tensor.dtype != torch.float8_e4m3fn:
            print("Quantizing", name)
            qweight, scale = per_tensor_quantize(tensor)
            modified_tensors[name] = qweight
            # Built by concatenation rather than a nested-quote f-string,
            # which only parses on Python 3.12+.
            scale_name = name[:-len("weight")] + "qscale_weight"
            modified_tensors[scale_name] = scale
        else:
            modified_tensors[name] = tensor

    safetensors.torch.save_file(modified_tensors, file_path)
    print(f"Updated {file_path} with quantized tensors")
41
+
42
def update_index_file(index_file_path):
    """Update the safetensors index file for the quantized model, in place.

    For every quantizable tensor in ``weight_map`` a ``qscale_weight`` entry
    pointing at the same shard file is added, and ``metadata.total_size`` is
    recomputed as the summed on-disk size of all referenced shard files.

    Args:
        index_file_path: Path of the index JSON file. The shard files it names
            are looked up in the same directory.

    Raises:
        OSError: If the index file or a referenced shard file cannot be read.
        KeyError: If the index has no ``weight_map`` entry.
    """
    print(f"Updating index file: {index_file_path}")
    with open(index_file_path, 'r') as f:
        index = json.load(f)

    new_weight_map = {}
    for tensor_name, file_name in index['weight_map'].items():
        new_weight_map[tensor_name] = file_name
        if is_quantizable(tensor_name):
            # Built by concatenation rather than a nested-quote f-string,
            # which only parses on Python 3.12+.
            scale_name = tensor_name[:-len("weight")] + "qscale_weight"
            new_weight_map[scale_name] = file_name

    index['weight_map'] = new_weight_map

    # Recalculate total_size from the shard files currently on disk; each
    # shard is counted once even though many tensors map to it.
    model_dir = os.path.dirname(index_file_path)
    total_size = sum(
        os.path.getsize(os.path.join(model_dir, shard_name))
        for shard_name in set(new_weight_map.values())
    )
    # Tolerate an index without a metadata object instead of raising KeyError.
    index.setdefault('metadata', {})['total_size'] = total_size

    with open(index_file_path, 'w') as f:
        json.dump(index, f, indent=2)
    print(f"Updated index file {index_file_path}")
64
+
65
def update_config(config_file_path):
    """Write the FP8 quantization settings into the params.json file, in place.

    The file is read as JSON, its ``quantization`` key is set (replacing any
    existing value) to a compressed-tensors description of static per-tensor
    FP8 weights with dynamic per-token FP8 activations, and the result is
    written back to the same path.
    """
    print(f"Updating config file: {config_file_path}")
    with open(config_file_path, 'r') as src:
        params = json.load(src)

    # Activations: scales are computed at runtime, one per token.
    activation_scheme = {
        "dynamic": True,
        "num_bits": 8,
        "observer": None,
        "strategy": "token",
        "symmetric": True,
        "type": "float",
    }
    # Weights: one static scale per tensor.
    weight_scheme = {
        "dynamic": False,
        "num_bits": 8,
        "observer": "minmax",
        "strategy": "tensor",
        "symmetric": True,
        "type": "float",
    }
    linear_group = {
        "input_activations": activation_scheme,
        "targets": ["Linear"],
        "weights": weight_scheme,
    }

    params["quantization"] = {
        "config_groups": {"group_0": linear_group},
        "format": "float-quantized",
        "ignore": ["lm_head", "output"],
        "quant_method": "compressed-tensors",
        "quantization_status": "compressed",
    }

    with open(config_file_path, 'w') as dst:
        json.dump(params, dst, indent=2)
    print(f"Updated config file {config_file_path}")
101
+
102
def process_directory(directory):
    """Convert every recognized model file in ``directory`` to FP8, in place.

    ``*.safetensors`` files are quantized, then
    ``consolidated.safetensors.index.json`` and ``params.json`` are updated.
    Any other file is reported and skipped.

    Args:
        directory: Path of the directory holding the model files.
    """
    index_file_path = None
    config_file_path = None

    # os.listdir() returns entries in arbitrary order; sort for reproducible
    # processing and log output.
    for filename in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, filename)
        if filename.endswith('.safetensors'):
            process_safetensors_file(file_path)
        elif filename == 'consolidated.safetensors.index.json':
            index_file_path = file_path
        elif filename == 'params.json':
            config_file_path = file_path
        else:
            print(f"Skipping unrecognized file: {filename}")

    # The index update recomputes total_size from the shard files on disk, so
    # it must run only after every shard has been rewritten. Otherwise the
    # recorded size would depend on directory listing order.
    if index_file_path is not None:
        update_index_file(index_file_path)
    if config_file_path is not None:
        update_config(config_file_path)
114
+
115
if __name__ == '__main__':
    # Command-line entry point: takes the model directory as the single
    # positional argument and converts it in place.
    cli = argparse.ArgumentParser(
        description='Convert mistral safetensors model to FP8 in-place.',
    )
    cli.add_argument(
        'directory',
        type=str,
        help='The directory containing the safetensors files and index file.',
    )

    process_directory(cli.parse_args().directory)
0 commit comments