@@ -14,16 +14,15 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -70,7 +69,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -82,87 +81,87 @@ def compress(
         :return: compressed state dict
         """
         compressed_dict = {}
-        save_device = "cpu"
-
-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
-            value = model_state[name]
-
-            # compress weights
-            if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
-
-                # gather qparams
-                scale = model_state.get(prefix + "weight_scale", None)
-                g_idx = model_state.get(prefix + "weight_g_idx", None)
-                zp = model_state.get(prefix + "weight_zero_point", None)
-
-                # if scale does not exist, then weight cannot be compressed
-                if scale is None:
-                    compressed_dict[name] = value.to(save_device)
-                    continue
-
-                # compress values on cpu (memory movement too expensive)
-                module_path = prefix[:-1] if prefix.endswith(".") else prefix
-                quant_args = names_to_scheme[module_path].weights
-                compressed_values = self.compress_weight(
-                    weight=value,
-                    scale=scale,
-                    zero_point=zp,
-                    g_idx=g_idx,
-                    quantization_args=quant_args,
-                    device="cpu",
-                )
-
-                # update state dict
-                for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+        weight_suffix = ".weight"
+        input_zp_suffix = ".input_zero_point"
+        weight_zp_suffix = ".weight_zero_point"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
 
+        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
+            # check if the parameter we're compressing is the weight zp
+            # or the input zp
+            is_weight_zp = name.endswith(weight_zp_suffix)
+            is_input_zp = name.endswith(input_zp_suffix)
+
+            # if we're saving the weight zp, fetch weight quant args
+            if is_weight_zp:
+                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
+                if isinstance(quant_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    quant_args_zp = quant_args_zp[0]
+
+            # if we're saving the input zp, fetch input quant args
+            if is_input_zp:
+                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
+                if isinstance(input_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    input_args_zp = input_args_zp[-1]
+
+            if name.endswith(weight_suffix):
+                prefix = name[: -(len(weight_suffix))]
+                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                if scale is not None:
+                    # weight is quantized, compress it
+                    if isinstance(names_to_scheme[prefix], tuple):
+                        quant_args = names_to_scheme[prefix][0]
+                    else:
+                        quant_args = names_to_scheme[prefix]
+
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        g_idx=g_idx,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            # only save zp if asym and not packed zp
+            elif is_weight_zp and (
+                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
+            ):
+                continue
+            # only save input zp if asym
+            elif is_input_zp and input_args_zp.symmetric:
+                continue
+            elif name.endswith("g_idx") and torch.any(value <= -1):
+                continue
             else:
-                # omit saving zero points for symmetric or packed quantization
-                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
-                    continue
-
-                # omit saving for g_idx if uninitialized
-                # TODO: does this case actually occur?
-                elif name.endswith("g_idx") and torch.any(value <= -1):
-                    continue
-
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to("cpu")
 
         return compressed_dict
 
-    def _skip_zp(
-        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-    ) -> bool:
+    def _check_if_zp_pack_quantized(self, quant_args):
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-        scheme = names_to_scheme[module_name]
-
-        if zp_name == "weight_zero_point":
-            args = scheme.weights
-        if zp_name == "input_zero_point":
-            args = scheme.input_activations
-        if zp_name == "output_zero_point":
-            args = scheme.output_activations
-
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
+        if isinstance(self, PackedQuantizationCompressor):
+            if not quant_args.symmetric and quant_args.strategy in [
+                QuantizationStrategy.GROUP.value,
+                QuantizationStrategy.CHANNEL.value,
+            ]:
+                return True
+        return False
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -171,9 +170,8 @@ def decompress(
         dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization scheme for each quantized weight
-        :param device: optional device to load intermediate weights into (must be `str`,
-            not `torch.device`)
+        :param names_to_scheme: quantization args for each quantized weight
+        :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -186,12 +184,7 @@ def decompress(
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(
-        self,
-        path_to_model: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
-        device: str,
-    ):
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -202,7 +195,7 @@ def _decompress_from_path(
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name].weights
+                quant_args = names_to_scheme[weight_name]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
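
For orientation (this note is not part of the commit): after this change, `names_to_scheme` carries a `QuantizationArgs` per module path, or a `(weight_args, input_args)` tuple for modules with activation quantization, instead of a full `QuantizationScheme`. Below is a minimal usage sketch under that assumption; the registry name, module path, and checkpoint path are illustrative, not taken from this diff.

# Sketch only: driving the API as typed in this revision, where
# names_to_scheme maps module paths to QuantizationArgs (or a
# (weight_args, input_args) tuple), not QuantizationScheme objects.
import torch
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs

# hypothetical module path -> weight quantization args
names_to_scheme = {
    "model.layers.0.self_attn.q_proj": QuantizationArgs(num_bits=8, symmetric=True),
}

# load a concrete subclass, e.g. the int-quantized compressor (assumed name)
compressor = BaseCompressor.load_from_registry("int-quantized")

# compress(): weights with a matching "<prefix>.weight_scale" are packed;
# symmetric / pack-quantized zero points and uninitialized g_idx are skipped
model = torch.nn.Linear(8, 8)  # stand-in; unquantized params pass through as-is
compressed_state = compressor.compress(
    model_state=model.state_dict(),
    names_to_scheme=names_to_scheme,
)

# decompress(): lazily yields (parameter name, dense tensor) pairs from a
# previously saved compressed checkpoint (path is illustrative)
for name, tensor in compressor.decompress(
    "path/to/compressed/model", names_to_scheme=names_to_scheme, device="cpu"
):
    print(name, tuple(tensor.shape))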