18
18
19
19
import torch
20
20
from compressed_tensors .compressors .base import BaseCompressor
21
- from compressed_tensors .quantization import QuantizationArgs
21
+ from compressed_tensors .quantization import QuantizationScheme
22
22
from compressed_tensors .utils import (
23
23
get_nested_mappings_from_state_dict ,
24
24
get_nested_weight_mappings ,
25
25
merge_names ,
26
+ remove_suffix ,
26
27
)
27
28
from safetensors import safe_open
28
29
from torch import Tensor
@@ -69,95 +70,82 @@ class BaseQuantizationCompressor(BaseCompressor):
69
70
def compress(
    self,
    model_state: Dict[str, Tensor],
    names_to_scheme: Dict[str, QuantizationScheme],
    **kwargs,
) -> Dict[str, Tensor]:
    """
    Compresses a dense state dict in place

    Weights that have an accompanying ``weight_scale`` entry are replaced by the
    entries returned from ``self.compress_weight``. Zero points belonging to
    symmetric quantization and ``g_idx`` tensors containing values <= -1 are
    dropped. Every other tensor is moved to the save device.

    :param model_state: state dict of uncompressed model, consumed during
        compression (entries are replaced/deleted and the same dict is returned)
    :param names_to_scheme: quantization scheme for each quantized module, keyed
        by module name (the weight name without its ``.weight`` suffix); the
        ``weights`` args of the scheme are passed to ``compress_weight``
    :return: compressed state dict (the same object as ``model_state``)
    :raises KeyError: if a weight has a scale but its module name is missing
        from ``names_to_scheme``
    """
    save_device = "cpu"
    weight_suffix = ".weight"

    # iterate over a snapshot of the keys: the dict is mutated inside the loop
    uncompressed_names = list(model_state.keys())
    for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
        value = model_state[name]

        if not name.endswith(weight_suffix):
            if name.endswith("zero_point") and _is_symmetric(name, names_to_scheme):
                # omit saving zero points for symmetric quantization
                del model_state[name]
            elif name.endswith("g_idx") and torch.any(value <= -1):
                # omit saving for g_idx if uninitialized
                # TODO: does this case actually occur?
                del model_state[name]
            else:
                model_state[name] = value.to(save_device)
            continue

        # compress weights
        prefix = remove_suffix(name, weight_suffix)

        # gather qparams
        scale = model_state.get(merge_names(prefix, "weight_scale"), None)
        g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
        zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)

        # if scale does not exist, then weight cannot be compressed
        if scale is None:
            model_state[name] = value.to(save_device)
            continue

        # compress values on cpu. TODO: experiment with different devices
        quant_args = names_to_scheme[prefix].weights
        compressed_values = self.compress_weight(
            weight=value,
            scale=scale,
            zero_point=zp,
            g_idx=g_idx,
            quantization_args=quant_args,
            device="cpu",
        )

        # update state dict: drop the dense weight before inserting the compressed
        # entries, since one of them may reuse the same key
        del model_state[name]
        for key, compressed_value in compressed_values.items():
            model_state[merge_names(prefix, key)] = compressed_value.to(save_device)

    return model_state
147
135
148
136
def decompress (
149
137
self ,
150
138
path_to_model_or_tensors : Union [str , Path , Dict [str , Any ]],
151
- names_to_scheme : Dict [str , QuantizationArgs ],
152
- device : str = "cpu" ,
139
+ names_to_scheme : Dict [str , QuantizationScheme ],
140
+ device : torch . device = "cpu" ,
153
141
) -> Generator [Tuple [str , Tensor ], None , None ]:
154
142
"""
155
143
Reads a compressed state dict located at path_to_model_or_tensors
156
144
and returns a generator for sequentially decompressing back to a
157
145
dense state dict
158
146
:param path_to_model_or_tensors: path to compressed safetensors model (directory
159
147
with one or more safetensors files) or compressed tensors file
160
- :param names_to_scheme: quantization args for each quantized weight
148
+ :param names_to_scheme: quantization scheme for each quantized weight
161
149
:param device: optional device to load intermediate weights into
162
150
:return: generator of (name, tensor) pairs for the decompressed weights
163
151
"""
@@ -171,7 +159,12 @@ def decompress(
171
159
path_to_model_or_tensors , names_to_scheme
172
160
)
173
161
174
- def _decompress_from_path (self , path_to_model , names_to_scheme , device ):
162
+ def _decompress_from_path (
163
+ self ,
164
+ path_to_model : Union [str , Path , Dict [str , Any ]],
165
+ names_to_scheme : Dict [str , QuantizationScheme ],
166
+ device : torch .device ,
167
+ ):
175
168
weight_mappings = get_nested_weight_mappings (
176
169
path_to_model , self .compression_param_names
177
170
)
@@ -182,13 +175,17 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
182
175
with safe_open (safe_path , framework = "pt" , device = device ) as f :
183
176
weight_data [param_name ] = f .get_tensor (full_name )
184
177
if "weight_scale" in weight_data :
185
- quant_args = names_to_scheme [weight_name ]
178
+ quant_args = names_to_scheme [weight_name ]. weights
186
179
decompressed = self .decompress_weight (
187
180
compressed_data = weight_data , quantization_args = quant_args
188
181
)
189
182
yield merge_names (weight_name , "weight" ), decompressed
190
183
191
- def _decompress_from_state_dict (self , state_dict , names_to_scheme ):
184
+ def _decompress_from_state_dict (
185
+ self ,
186
+ state_dict : Dict [str , torch .Tensor ],
187
+ names_to_scheme : Dict [str , QuantizationScheme ],
188
+ ):
192
189
weight_mappings = get_nested_mappings_from_state_dict (
193
190
state_dict , self .compression_param_names
194
191
)
@@ -198,8 +195,23 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
198
195
weight_data [param_name ] = param_value
199
196
200
197
if "weight_scale" in weight_data :
201
- quant_args = names_to_scheme [weight_name ]
198
+ quant_args = names_to_scheme [weight_name ]. weights
202
199
decompressed = self .decompress_weight (
203
200
compressed_data = weight_data , quantization_args = quant_args
204
201
)
205
202
yield merge_names (weight_name , "weight" ), decompressed
203
+
204
+
205
+ def _is_symmetric (name : str , names_to_scheme : Dict [str , QuantizationScheme ]) -> bool :
206
+ weight_name , zp_name = name .rsplit ("." , 1 )
207
+ scheme = names_to_scheme [weight_name ]
208
+
209
+ if zp_name == "weight_zero_point" :
210
+ quant_args = scheme .weights
211
+ if zp_name == "input_zero_point" :
212
+ quant_args = scheme .input_activations
213
+ if zp_name == "output_zero_point" :
214
+ quant_args = scheme .output_activations
215
+
216
+ assert quant_args is not None
217
+ return quant_args .symmetric
0 commit comments