 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -70,7 +69,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -82,87 +81,87 @@ def compress(
         :return: compressed state dict
         """
         compressed_dict = {}
-        save_device = "cpu"
-
-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
-            value = model_state[name]
-
-            # compress weights
-            if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
-
-                # gather qparams
-                scale = model_state.get(prefix + "weight_scale", None)
-                g_idx = model_state.get(prefix + "weight_g_idx", None)
-                zp = model_state.get(prefix + "weight_zero_point", None)
-
-                # if scale does not exist, then weight cannot be compressed
-                if scale is None:
-                    compressed_dict[name] = value.to(save_device)
-                    continue
-
-                # compress values on cpu (memory movement too expensive)
-                module_path = prefix[:-1] if prefix.endswith(".") else prefix
-                quant_args = names_to_scheme[module_path].weights
-                compressed_values = self.compress_weight(
-                    weight=value,
-                    scale=scale,
-                    zero_point=zp,
-                    g_idx=g_idx,
-                    quantization_args=quant_args,
-                    device="cpu",
-                )
-
-                # update state dict
-                for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+        weight_suffix = ".weight"
+        input_zp_suffix = ".input_zero_point"
+        weight_zp_suffix = ".weight_zero_point"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
 
+        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
+            # check if the parameter we're compressing is the weight zp
+            # or the input zp
+            is_weight_zp = name.endswith(weight_zp_suffix)
+            is_input_zp = name.endswith(input_zp_suffix)
+
+            # if we're saving the weight zp, fetch the weight quant args
+            if is_weight_zp:
+                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
+                if isinstance(quant_args_zp, tuple):
+                    # if tuple, first value is weight args, second is input args
+                    quant_args_zp = quant_args_zp[0]
+
+            # if we're saving the input zp, fetch the input quant args
+            if is_input_zp:
+                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
+                if isinstance(input_args_zp, tuple):
+                    # if tuple, first value is weight args, second is input args
+                    input_args_zp = input_args_zp[-1]
+
+            if name.endswith(weight_suffix):
+                prefix = name[: -(len(weight_suffix))]
+                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                if scale is not None:
+                    # weight is quantized, compress it
+                    if isinstance(names_to_scheme[prefix], tuple):
+                        quant_args = names_to_scheme[prefix][0]
+                    else:
+                        quant_args = names_to_scheme[prefix]
+
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        g_idx=g_idx,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            # only save the weight zp if asymmetric and not handled by a packed compressor
+            elif is_weight_zp and (
+                quant_args_zp.symmetric
+                or self._check_if_zp_pack_quantized(quant_args_zp)
+            ):
+                continue
+            # only save the input zp if asymmetric
+            elif is_input_zp and input_args_zp.symmetric:
+                continue
+            elif name.endswith("g_idx") and torch.any(value <= -1):
+                continue
             else:
-                # omit saving zero points for symmetric or packed quantization
-                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
-                    continue
-
-                # omit saving g_idx if uninitialized
-                # TODO: does this case actually occur?
-                elif name.endswith("g_idx") and torch.any(value <= -1):
-                    continue
-
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to("cpu")
 
         return compressed_dict
 
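For reference, a minimal sketch of how a caller might drive `compress()` once calibration has attached scales and zero points. The module path, tensor shapes, args values, and the `"int-quantized"` registry name are illustrative assumptions, not taken from this diff:

```python
import torch
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs

# hypothetical single-linear state dict; names and shapes are assumptions
model_state = {
    "decoder.0.proj.weight": torch.randn(64, 64),
    "decoder.0.proj.weight_scale": torch.rand(1),
    "decoder.0.proj.weight_zero_point": torch.zeros(1, dtype=torch.int8),
}
# keys are module paths, i.e. the weight name minus the ".weight" suffix
names_to_scheme = {"decoder.0.proj": QuantizationArgs(num_bits=8, symmetric=True)}

compressor = BaseCompressor.load_from_registry("int-quantized")  # assumed format name
compressed = compressor.compress(model_state, names_to_scheme=names_to_scheme)
```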
-    def _skip_zp(
-        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-    ) -> bool:
+    def _check_if_zp_pack_quantized(self, quant_args):
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-        scheme = names_to_scheme[module_name]
-
-        if zp_name == "weight_zero_point":
-            args = scheme.weights
-        if zp_name == "input_zero_point":
-            args = scheme.input_activations
-        if zp_name == "output_zero_point":
-            args = scheme.output_activations
-
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
+        if isinstance(self, PackedQuantizationCompressor):
+            if not quant_args.symmetric and quant_args.strategy in [
+                QuantizationStrategy.GROUP.value,
+                QuantizationStrategy.CHANNEL.value,
+            ]:
+                return True
+        return False
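A hedged illustration of the rule above, with assumed args values: for an asymmetric group- or channel-quantized weight handled by `PackedQuantizationCompressor`, the helper returns `True` and the standalone zero point is skipped, since the packed format carries it:

```python
from compressed_tensors.quantization import QuantizationArgs

# assumed args: asymmetric 4-bit group quantization
asym_group = QuantizationArgs(num_bits=4, symmetric=False, strategy="group", group_size=128)
# under PackedQuantizationCompressor: _check_if_zp_pack_quantized(asym_group) -> True,
# so "*.weight_zero_point" is not written to the compressed state dict

# symmetric args skip the zp regardless, since it is all zeros
sym_channel = QuantizationArgs(num_bits=8, symmetric=True, strategy="channel")
```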
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -171,9 +170,8 @@ def decompress(
             dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization scheme for each quantized weight
-        :param device: optional device to load intermediate weights into (must be `str`,
-            not `torch.device`)
+        :param names_to_scheme: quantization args for each quantized weight
+        :param device: optional device to load intermediate weights into
         :return: generator of the dense state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -186,12 +184,7 @@ def decompress(
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(
-        self,
-        path_to_model: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
-        device: str,
-    ):
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -202,7 +195,7 @@ def _decompress_from_path(
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name].weights
+                quant_args = names_to_scheme[weight_name]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
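Since `decompress()` yields `(name, tensor)` pairs lazily, a caller might materialize the dense weights as below; the checkpoint path and the `compressor`/`names_to_scheme` objects are assumptions carried over from the sketch above:

```python
# iterate the generator returned by decompress(); tensors are loaded one weight at a time
dense_state = {}
for name, tensor in compressor.decompress("ckpt_dir", names_to_scheme=names_to_scheme):
    dense_state[name] = tensor
```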