14
14
15
15
import logging
16
16
from pathlib import Path
17
- from typing import Any , Dict , Generator , Optional , Tuple , Union
17
+ from typing import Any , Dict , Generator , Tuple , Union
18
18
19
19
import torch
20
20
from compressed_tensors .compressors .base import BaseCompressor
21
- from compressed_tensors .quantization import QuantizationArgs , QuantizationStrategy
21
+ from compressed_tensors .quantization import QuantizationScheme , QuantizationStrategy
22
22
from compressed_tensors .utils import (
23
23
get_nested_mappings_from_state_dict ,
24
24
get_nested_weight_mappings ,
25
25
merge_names ,
26
+ remove_suffix ,
26
27
)
27
28
from safetensors import safe_open
28
29
from torch import Tensor
def compress(
    self,
    model_state: Dict[str, Tensor],
    names_to_scheme: Dict[str, QuantizationScheme],
    **kwargs,
) -> Dict[str, Tensor]:
    """
    Compresses a dense state dict.

    Weights with an associated ``weight_scale`` are compressed via
    :meth:`compress_weight`; all other parameters are passed through, except
    zero points that can be omitted (symmetric or packed quantization) and
    uninitialized ``g_idx`` tensors.

    :param model_state: state dict of uncompressed model
    :param names_to_scheme: quantization scheme for each quantized weight,
        needed to look up the weight quantization args for each module
    :return: compressed state dict (all tensors moved to CPU)
    """
    compressed_dict = {}
    save_device = "cpu"

    # iterate over a snapshot of the keys so model_state lookups stay valid
    uncompressed_names = list(model_state.keys())
    for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
        value = model_state[name]

        # compress weights
        if name.endswith("weight"):
            prefix = remove_suffix(name, "weight")

            # gather qparams
            scale = model_state.get(prefix + "weight_scale", None)
            g_idx = model_state.get(prefix + "weight_g_idx", None)
            zp = model_state.get(prefix + "weight_zero_point", None)

            # if scale does not exist, then weight cannot be compressed
            if scale is None:
                compressed_dict[name] = value.to(save_device)
                continue

            # compress values on cpu (memory movement too expensive)
            module_path = prefix[:-1] if prefix.endswith(".") else prefix
            quant_args = names_to_scheme[module_path].weights
            compressed_values = self.compress_weight(
                weight=value,
                scale=scale,
                zero_point=zp,
                g_idx=g_idx,
                quantization_args=quant_args,
                device="cpu",
            )

            # update state dict
            # NOTE: use a distinct loop variable so the outer `value` is not
            # shadowed
            for key, compressed_value in compressed_values.items():
                compressed_dict[prefix + key] = compressed_value.to(save_device)

        else:
            # omit saving zero points for symmetric or packed quantization
            if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
                continue

            # omit saving for g_idx if uninitialized
            # TODO: does this case actually occur?
            elif name.endswith("g_idx") and torch.any(value <= -1):
                continue

            compressed_dict[name] = value.to(save_device)

    return compressed_dict
149
134
150
def _skip_zp(
    self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
    """
    Check whether a zero-point parameter can be omitted from the compressed
    state dict.

    A zero point is skipped when quantization is symmetric (zp is implicitly
    zero) or when this compressor packs zero points into the weights
    (PackedQuantizationCompressor with a group/channel strategy).

    :param name: full parameter name ending in ``zero_point``
    :param names_to_scheme: quantization scheme for each quantized module
    :return: True if the zero point does not need to be saved
    :raises ValueError: if the parameter name has an unrecognized suffix
    """
    # local import to avoid a circular dependency at module load time
    from compressed_tensors.compressors import PackedQuantizationCompressor

    module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
    scheme = names_to_scheme[module_name]

    # use an if/elif/else chain so an unexpected suffix raises a clear error
    # instead of an UnboundLocalError on `args` below
    if zp_name == "weight_zero_point":
        args = scheme.weights
    elif zp_name == "input_zero_point":
        args = scheme.input_activations
    elif zp_name == "output_zero_point":
        args = scheme.output_activations
    else:
        raise ValueError(f"Unrecognized zero point parameter {name}")

    symmetric = args.symmetric
    packable_strategies = [
        QuantizationStrategy.GROUP.value,
        QuantizationStrategy.CHANNEL.value,
    ]
    packed = (
        isinstance(self, PackedQuantizationCompressor)
        and args.strategy in packable_strategies
    )

    return symmetric or packed
160
161
161
162
def decompress (
162
163
self ,
163
164
path_to_model_or_tensors : Union [str , Path , Dict [str , Any ]],
164
- names_to_scheme : Dict [str , QuantizationArgs ],
165
+ names_to_scheme : Dict [str , QuantizationScheme ],
165
166
device : str = "cpu" ,
166
167
) -> Generator [Tuple [str , Tensor ], None , None ]:
167
168
"""
@@ -170,8 +171,9 @@ def decompress(
170
171
dense state dict
171
172
:param path_to_model_or_tensors: path to compressed safetensors model (directory
172
173
with one or more safetensors files) or compressed tensors file
173
- :param names_to_scheme: quantization args for each quantized weight
174
- :param device: optional device to load intermediate weights into
174
+ :param names_to_scheme: quantization scheme for each quantized weight
175
+ :param device: optional device to load intermediate weights into (must be `str`,
176
+ not `torch.device`)
175
177
:return: compressed state dict
176
178
"""
177
179
if isinstance (path_to_model_or_tensors , (str , Path )):
@@ -184,7 +186,12 @@ def decompress(
184
186
path_to_model_or_tensors , names_to_scheme
185
187
)
186
188
187
- def _decompress_from_path (self , path_to_model , names_to_scheme , device ):
189
+ def _decompress_from_path (
190
+ self ,
191
+ path_to_model : Union [str , Path , Dict [str , Any ]],
192
+ names_to_scheme : Dict [str , QuantizationScheme ],
193
+ device : str ,
194
+ ):
188
195
weight_mappings = get_nested_weight_mappings (
189
196
path_to_model , self .compression_param_names
190
197
)
@@ -195,7 +202,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
195
202
with safe_open (safe_path , framework = "pt" , device = device ) as f :
196
203
weight_data [param_name ] = f .get_tensor (full_name )
197
204
if "weight_scale" in weight_data :
198
- quant_args = names_to_scheme [weight_name ]
205
+ quant_args = names_to_scheme [weight_name ]. weights
199
206
decompressed = self .decompress_weight (
200
207
compressed_data = weight_data , quantization_args = quant_args
201
208
)
0 commit comments