1414
1515import logging
1616from pathlib import Path
17- from typing import Any , Dict , Generator , Optional , Tuple , Union
17+ from typing import Any , Dict , Generator , Tuple , Union
1818
1919import torch
2020from compressed_tensors .compressors .base import BaseCompressor
21- from compressed_tensors .quantization import QuantizationArgs , QuantizationStrategy
21+ from compressed_tensors .quantization import QuantizationScheme , QuantizationStrategy
2222from compressed_tensors .utils import (
2323 get_nested_mappings_from_state_dict ,
2424 get_nested_weight_mappings ,
2525 merge_names ,
26+ remove_suffix ,
2627)
2728from safetensors import safe_open
2829from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
6970 def compress (
7071 self ,
7172 model_state : Dict [str , Tensor ],
72- names_to_scheme : Dict [str , QuantizationArgs ],
73+ names_to_scheme : Dict [str , QuantizationScheme ],
7374 ** kwargs ,
7475 ) -> Dict [str , Tensor ]:
7576 """
@@ -81,87 +82,87 @@ def compress(
8182 :return: compressed state dict
8283 """
8384 compressed_dict = {}
84- weight_suffix = ".weight"
85- input_zp_suffix = ".input_zero_point"
86- weight_zp_suffix = ".weight_zero_point"
87- _LOGGER .debug (
88- f"Compressing model with { len (model_state )} parameterized layers..."
89- )
85+ save_device = "cpu"
86+
87+ uncompressed_names = list (model_state .keys ())
88+ for name in tqdm (uncompressed_names , desc = "Compressing with quantization" ):
89+ value = model_state [name ]
90+
91+ # compress weights
92+ if name .endswith ("weight" ):
93+ prefix = remove_suffix (name , "weight" )
94+
95+ # gather qparams
96+ scale = model_state .get (prefix + "weight_scale" , None )
97+ g_idx = model_state .get (prefix + "weight_g_idx" , None )
98+ zp = model_state .get (prefix + "weight_zero_point" , None )
99+
100+ # if scale does not exist, then weight cannot be compressed
101+ if scale is None :
102+ compressed_dict [name ] = value .to (save_device )
103+ continue
104+
105+ # compress values on cpu (memory movement too expensive)
106+ module_path = prefix [:- 1 ] if prefix .endswith ("." ) else prefix
107+ quant_args = names_to_scheme [module_path ].weights
108+ compressed_values = self .compress_weight (
109+ weight = value ,
110+ scale = scale ,
111+ zero_point = zp ,
112+ g_idx = g_idx ,
113+ quantization_args = quant_args ,
114+ device = "cpu" ,
115+ )
116+
117+ # update state dict
118+ for key , value in compressed_values .items ():
119+ compressed_dict [prefix + key ] = value .to (save_device )
90120
91- for name , value in tqdm (model_state .items (), desc = "Quantized Compression" ):
92- # check if the parameter we're compressing is the weight zp
93- # or the input zp
94- is_weight_zp = name .endswith (weight_zp_suffix )
95- is_input_zp = name .endswith (input_zp_suffix )
96-
97- # if we're saving the weight zp, fetch weight quant args
98- if is_weight_zp :
99- quant_args_zp = names_to_scheme .get (name [: - (len (weight_zp_suffix ))])
100- if isinstance (quant_args_zp , tuple ):
101- # If tuple, first value is weight args, second is input args
102- quant_args_zp = quant_args_zp [0 ]
103-
104- # if we're saving the input zp, fetch input quant args
105- if is_input_zp :
106- input_args_zp = names_to_scheme .get (name [: - (len (input_zp_suffix ))])
107- if isinstance (input_args_zp , tuple ):
108- # If tuple, first value is weight args, second is input args
109- input_args_zp = input_args_zp [- 1 ]
110-
111- if name .endswith (weight_suffix ):
112- prefix = name [: - (len (weight_suffix ))]
113- scale = model_state .get (merge_names (prefix , "weight_scale" ), None )
114- zp = model_state .get (merge_names (prefix , "weight_zero_point" ), None )
115- g_idx = model_state .get (merge_names (prefix , "weight_g_idx" ), None )
116- if scale is not None :
117- # weight is quantized, compress it
118- if isinstance (names_to_scheme [prefix ], tuple ):
119- quant_args = names_to_scheme [prefix ][0 ]
120- else :
121- quant_args = names_to_scheme [prefix ]
122-
123- compressed_data = self .compress_weight (
124- weight = value ,
125- scale = scale ,
126- zero_point = zp ,
127- g_idx = g_idx ,
128- quantization_args = quant_args ,
129- device = "cpu" ,
130- )
131- for key , value in compressed_data .items ():
132- compressed_dict [merge_names (prefix , key )] = value
133- else :
134- compressed_dict [name ] = value .to ("cpu" )
135- # only save zp if asym and not packed zp
136- elif is_weight_zp and (
137- quant_args_zp .symmetric or self ._check_if_zp_pack_quantized (quant_args )
138- ):
139- continue
140- # only save if asym
141- elif is_input_zp and input_args_zp .symmetric :
142- continue
143- elif name .endswith ("g_idx" ) and torch .any (value <= - 1 ):
144- continue
145121 else :
146- compressed_dict [name ] = value .to ("cpu" )
122+ # omit saving zero points for symmetric or packed quantization
123+ if name .endswith ("zero_point" ) and self ._skip_zp (name , names_to_scheme ):
124+ continue
125+
126+ # omit saving for g_idx if uninitialized
127+ # TODO: does this case actually occur?
128+ elif name .endswith ("g_idx" ) and torch .any (value <= - 1 ):
129+ continue
130+
131+ compressed_dict [name ] = value .to (save_device )
147132
148133 return compressed_dict
149134
150- def _check_if_zp_pack_quantized (self , quant_args ):
135+ def _skip_zp (
136+ self , name : str , names_to_scheme : Dict [str , QuantizationScheme ]
137+ ) -> bool :
151138 from compressed_tensors .compressors import PackedQuantizationCompressor
152139
153- if isinstance (self , PackedQuantizationCompressor ):
154- if not quant_args .symmetric and quant_args .strategy in [
155- QuantizationStrategy .GROUP .value ,
156- QuantizationStrategy .CHANNEL .value ,
157- ]:
158- return True
159- return False
140+ module_name , zp_name = name .rsplit ("." , 1 ) if "." in name else ("" , name )
141+ scheme = names_to_scheme [module_name ]
142+
143+ if zp_name == "weight_zero_point" :
144+ args = scheme .weights
145+ if zp_name == "input_zero_point" :
146+ args = scheme .input_activations
147+ if zp_name == "output_zero_point" :
148+ args = scheme .output_activations
149+
150+ symmetric = args .symmetric
151+ packable_strategies = [
152+ QuantizationStrategy .GROUP .value ,
153+ QuantizationStrategy .CHANNEL .value ,
154+ ]
155+ packed = (
156+ isinstance (self , PackedQuantizationCompressor )
157+ and args .strategy in packable_strategies
158+ )
159+
160+ return symmetric or packed
160161
161162 def decompress (
162163 self ,
163164 path_to_model_or_tensors : Union [str , Path , Dict [str , Any ]],
164- names_to_scheme : Dict [str , QuantizationArgs ],
165+ names_to_scheme : Dict [str , QuantizationScheme ],
165166 device : str = "cpu" ,
166167 ) -> Generator [Tuple [str , Tensor ], None , None ]:
167168 """
@@ -170,8 +171,9 @@ def decompress(
170171 dense state dict
171172 :param path_to_model_or_tensors: path to compressed safetensors model (directory
172173 with one or more safetensors files) or compressed tensors file
173- :param names_to_scheme: quantization args for each quantized weight
174- :param device: optional device to load intermediate weights into
174+ :param names_to_scheme: quantization scheme for each quantized weight
175+ :param device: optional device to load intermediate weights into (must be `str`,
176+ not `torch.device`)
175177 :return: generator of (name, decompressed tensor) pairs
176178 """
177179 if isinstance (path_to_model_or_tensors , (str , Path )):
@@ -184,7 +186,12 @@ def decompress(
184186 path_to_model_or_tensors , names_to_scheme
185187 )
186188
187- def _decompress_from_path (self , path_to_model , names_to_scheme , device ):
189+ def _decompress_from_path (
190+ self ,
191+ path_to_model : Union [str , Path , Dict [str , Any ]],
192+ names_to_scheme : Dict [str , QuantizationScheme ],
193+ device : str ,
194+ ):
188195 weight_mappings = get_nested_weight_mappings (
189196 path_to_model , self .compression_param_names
190197 )
@@ -195,7 +202,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
195202 with safe_open (safe_path , framework = "pt" , device = device ) as f :
196203 weight_data [param_name ] = f .get_tensor (full_name )
197204 if "weight_scale" in weight_data :
198- quant_args = names_to_scheme [weight_name ]
205+ quant_args = names_to_scheme [weight_name ]. weights
199206 decompressed = self .decompress_weight (
200207 compressed_data = weight_data , quantization_args = quant_args
201208 )
0 commit comments