18
18
19
19
import torch
20
20
from compressed_tensors .compressors .base import BaseCompressor
21
- from compressed_tensors .quantization import QuantizationArgs , QuantizationStrategy
21
+ from compressed_tensors .quantization import QuantizationScheme , QuantizationStrategy
22
22
from compressed_tensors .utils import (
23
23
get_nested_mappings_from_state_dict ,
24
24
get_nested_weight_mappings ,
25
25
merge_names ,
26
+ remove_suffix ,
26
27
)
27
28
from safetensors import safe_open
28
29
from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
69
70
def compress(
    self,
    model_state: Dict[str, Tensor],
    names_to_scheme: Dict[str, QuantizationScheme],
    **kwargs,
) -> Dict[str, Tensor]:
    """
    Compresses a dense state dict into a compressed state dict.

    Weights with an accompanying ``weight_scale`` are compressed via
    :meth:`compress_weight`; all other parameters are passed through,
    except zero points that need not be saved (symmetric or packed
    quantization) and uninitialized ``g_idx`` tensors.

    :param model_state: state dict of the uncompressed model
    :param names_to_scheme: quantization scheme for each quantized
        module, keyed by module path (e.g. ``"model.layers.0.self_attn.q_proj"``)
    :return: compressed state dict with all tensors on the save device (cpu)
    """
    compressed_dict = {}
    save_device = "cpu"

    uncompressed_names = list(model_state.keys())
    for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
        value = model_state[name]

        # compress weights
        if name.endswith("weight"):
            prefix = remove_suffix(name, "weight")

            # gather qparams
            scale = model_state.get(prefix + "weight_scale", None)
            g_idx = model_state.get(prefix + "weight_g_idx", None)
            zp = model_state.get(prefix + "weight_zero_point", None)

            # if scale does not exist, then weight cannot be compressed
            if scale is None:
                compressed_dict[name] = value.to(save_device)
                continue

            # compress values on cpu (memory movement too expensive)
            module_path = prefix[:-1] if prefix.endswith(".") else prefix
            quant_args = names_to_scheme[module_path].weights
            compressed_values = self.compress_weight(
                weight=value,
                scale=scale,
                zero_point=zp,
                g_idx=g_idx,
                quantization_args=quant_args,
                device="cpu",
            )

            # update state dict (distinct loop variable avoids shadowing
            # the outer `value`)
            for key, compressed_value in compressed_values.items():
                compressed_dict[prefix + key] = compressed_value.to(save_device)

        else:
            # omit saving zero points for symmetric quantization
            if name.endswith("zero_point") and not self._should_save_zp(
                name, names_to_scheme
            ):
                continue

            # omit saving for g_idx if uninitialized
            # TODO: does this case actually occur?
            elif name.endswith("g_idx") and torch.any(value <= -1):
                continue

            compressed_dict[name] = value.to(save_device)

    return compressed_dict
149
136
150
def _should_save_zp(
    self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
    """
    Decide whether a zero-point parameter should be kept in the
    compressed state dict.

    Zero points are omitted when quantization is symmetric, or when a
    packed compressor handles group/channel strategies (the zero points
    are folded into the packed representation).

    :param name: fully qualified parameter name ending in ``zero_point``
    :param names_to_scheme: quantization scheme for each quantized module
    :return: True if the zero point must be serialized
    :raises ValueError: if ``name`` is not a recognized weight/input/output
        zero point parameter
    """
    # local import — presumably avoids a circular import with the
    # compressors package; confirm before moving to module level
    from compressed_tensors.compressors import PackedQuantizationCompressor

    module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
    scheme = names_to_scheme[module_name]

    # select the quantization args matching this zero point's target
    if zp_name == "weight_zero_point":
        args = scheme.weights
    elif zp_name == "input_zero_point":
        args = scheme.input_activations
    elif zp_name == "output_zero_point":
        args = scheme.output_activations
    else:
        # previously `args` was left unbound here, surfacing later as an
        # opaque UnboundLocalError; fail loudly with a clear message
        raise ValueError(f"Unrecognized zero point parameter {name}")

    packable_strats = (
        QuantizationStrategy.GROUP.value,
        QuantizationStrategy.CHANNEL.value,
    )
    packed = (
        isinstance(self, PackedQuantizationCompressor)
        and args.strategy in packable_strats
    )

    return not args.symmetric and not packed
160
163
161
164
def decompress (
162
165
self ,
163
166
path_to_model_or_tensors : Union [str , Path , Dict [str , Any ]],
164
- names_to_scheme : Dict [str , QuantizationArgs ],
165
- device : str = "cpu" ,
167
+ names_to_scheme : Dict [str , QuantizationScheme ],
168
+ device : torch . device = "cpu" ,
166
169
) -> Generator [Tuple [str , Tensor ], None , None ]:
167
170
"""
168
171
Reads a compressed state dict located at path_to_model_or_tensors
169
172
and returns a generator for sequentially decompressing back to a
170
173
dense state dict
171
174
:param path_to_model_or_tensors: path to compressed safetensors model (directory
172
175
with one or more safetensors files) or compressed tensors file
173
- :param names_to_scheme: quantization args for each quantized weight
176
+ :param names_to_scheme: quantization scheme for each quantized weight
174
177
:param device: optional device to load intermediate weights into
175
178
:return: compressed state dict
176
179
"""
@@ -184,7 +187,12 @@ def decompress(
184
187
path_to_model_or_tensors , names_to_scheme
185
188
)
186
189
187
- def _decompress_from_path (self , path_to_model , names_to_scheme , device ):
190
+ def _decompress_from_path (
191
+ self ,
192
+ path_to_model : Union [str , Path , Dict [str , Any ]],
193
+ names_to_scheme : Dict [str , QuantizationScheme ],
194
+ device : torch .device ,
195
+ ):
188
196
weight_mappings = get_nested_weight_mappings (
189
197
path_to_model , self .compression_param_names
190
198
)
@@ -195,7 +203,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
195
203
with safe_open (safe_path , framework = "pt" , device = device ) as f :
196
204
weight_data [param_name ] = f .get_tensor (full_name )
197
205
if "weight_scale" in weight_data :
198
- quant_args = names_to_scheme [weight_name ]
206
+ quant_args = names_to_scheme [weight_name ]. weights
199
207
decompressed = self .decompress_weight (
200
208
compressed_data = weight_data , quantization_args = quant_args
201
209
)
0 commit comments