# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Dict, Optional, Tuple
+from typing import Dict, Literal, Optional, Tuple, Union

import numpy as np
import torch
    BaseQuantizationCompressor,
)
from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.utils import can_quantize
from torch import Tensor
@@ -65,10 +65,26 @@ def compression_param_info(
        """
        pack_factor = 32 // quantization_args.num_bits
        packed_size = math.ceil(weight_shape[1] / pack_factor)
-        return {
+        packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
+        output = {
            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
            "weight_shape": (torch.Size((2,)), torch.int32),
        }
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            zp_factor = (
+                quantization_args.group_size
+                if quantization_args.strategy == QuantizationStrategy.GROUP.value
+                else weight_shape[-1]
+            )
+
+            output["weight_zero_point"] = (
+                torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
+                torch.int32,
+            )
+        return output

    def compress_weight(
        self,
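To make the new shape bookkeeping concrete, here is a rough, self-contained sketch of the arithmetic for a hypothetical 4-bit, asymmetric, group-quantized layer (the sizes and configuration below are made up for illustration and are not taken from the library):

```python
import math

import torch

# Hypothetical example: 4-bit asymmetric GROUP quantization of a (4096, 11008)
# weight with group_size=128 (all values illustrative).
num_bits, group_size = 4, 128
weight_shape = torch.Size((4096, 11008))

pack_factor = 32 // num_bits                                # 8 values per int32
packed_size = math.ceil(weight_shape[1] / pack_factor)      # packed columns of the weight
packed_size_zp = math.ceil(weight_shape[0] / pack_factor)   # packed rows of the zero point

# weight_packed keeps its rows and packs the input dimension into int32s
print(torch.Size((weight_shape[0], packed_size)))                     # torch.Size([4096, 1376])
# weight_zero_point has one entry per (row, group) and is packed along dim 0
print(torch.Size((packed_size_zp, weight_shape[-1] // group_size)))   # torch.Size([512, 86])
```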
@@ -104,6 +120,7 @@ def compress_weight(
            quantized_weight = weight

        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+
        weight_shape = torch.tensor(weight.shape)
        if device is not None:
            packed_weight = packed_weight.to(device)
@@ -112,6 +129,15 @@ def compress_weight(
        compressed_dict["weight_shape"] = weight_shape
        compressed_dict["weight_packed"] = packed_weight

+        # Zero points are typically not compressed, except when the packed compressor stores group/channel zero points
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            packed_zp = pack_to_int32(
+                zero_point, quantization_args.num_bits, packed_dim=0
+            )
+            compressed_dict["weight_zero_point"] = packed_zp
        return compressed_dict

    def decompress_weight(
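As a self-contained illustration (toy shapes and values, not the library call itself), packing a small zero-point tensor along dim 0 looks roughly like this:

```python
import numpy as np
import torch

# Toy stand-in for an asymmetric GROUP zero point: one 4-bit value per (row, group).
# Eight rows are used so the pack factor of 8 divides evenly and no padding is needed.
num_bits = 4
pack_factor = 32 // num_bits
zero_point = torch.randint(0, 2**num_bits, (8, 3), dtype=torch.int32)

# Minimal re-implementation of packing along the row dimension (packed_dim=0).
zp = zero_point.numpy().astype(np.uint32)
packed = np.zeros((zp.shape[0] // pack_factor, zp.shape[1]), dtype=np.uint32)
for i in range(pack_factor):
    packed |= zp[i::pack_factor, :] << (num_bits * i)
packed_zp = torch.from_numpy(packed.view(np.int32))

print(packed_zp.shape)  # torch.Size([1, 3]); eight rows collapse into one int32 row
```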
@@ -133,14 +159,33 @@ def decompress_weight(
        original_shape = torch.Size(compressed_data["weight_shape"])
        num_bits = quantization_args.num_bits
        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+
+        # NOTE: this will fail decompression as we don't currently handle packed zp on decompression
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            raise ValueError(
+                "Decompression of packed zero points is currently not supported"
+            )
+            assert zero_point is not None
+            original_zp_shape = (original_shape[0], scale.shape[-1])
+            zero_point = unpack_from_int32(
+                zero_point, num_bits, original_zp_shape, packed_dim=0
+            )
+
        decompressed_weight = dequantize(
            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
        )

        return decompressed_weight


-def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
+def pack_to_int32(
+    value: torch.Tensor,
+    num_bits: int,
+    packed_dim: Union[Literal[0], Literal[1]] = 1,
+) -> torch.Tensor:
    """
    Packs a tensor of quantized weights stored in int8 into int32s with padding

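For context, `dequantize` above is the library's helper; conceptually, asymmetric dequantization of a single group reduces to the sketch below (toy numbers, not library internals):

```python
import torch

# Asymmetric dequantization of one group of one row: w ≈ (q - zero_point) * scale.
q = torch.tensor([3, 0, 15, 7], dtype=torch.int32)   # 4-bit quantized values
scale = torch.tensor(0.02)
zero_point = torch.tensor(8, dtype=torch.int32)

w_approx = (q - zero_point).to(torch.float32) * scale
print(w_approx)  # tensor([-0.1000, -0.1600,  0.1400, -0.0200])
```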
@@ -176,22 +221,30 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    pack_factor = 32 // num_bits

    # pad input tensor and initialize packed output
-    packed_size = math.ceil(value.shape[1] / pack_factor)
-    padding = packed_size * pack_factor - value.shape[1]
+    packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
+    padding = packed_size * pack_factor - value.shape[packed_dim]
    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

    # pack values
-    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-    for i in range(pack_factor):
-        packed |= value[:, i::pack_factor] << num_bits * i
+    if packed_dim == 1:
+        packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+        for i in range(pack_factor):
+            packed |= value[:, i::pack_factor] << num_bits * i
+    else:
+        packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
+        for i in range(pack_factor):
+            packed |= value[i::pack_factor, :] << num_bits * i

    # convert back to signed and torch
    packed = np.ascontiguousarray(packed).view(np.int32)
    return torch.from_numpy(packed)


def unpack_from_int32(
-    value: torch.Tensor, num_bits: int, shape: torch.Size
+    value: torch.Tensor,
+    num_bits: int,
+    shape: torch.Size,
+    packed_dim: Union[Literal[0], Literal[1]] = 1,
) -> torch.Tensor:
    """
    Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
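The shift-and-OR loops above are easiest to see on a tiny example; the following self-contained sketch (toy 8-bit values, not the library function) shows how the same pattern packs along either dimension:

```python
import numpy as np

# Toy illustration of packing 8-bit values (pack_factor = 32 // 8 = 4) along
# either dimension with the same shift-and-OR loop.
num_bits = 8
pack_factor = 32 // num_bits
value = np.arange(16, dtype=np.uint32).reshape(4, 4)

# packed_dim == 1: four column values of each row fold into one int32
packed_cols = np.zeros((4, 1), dtype=np.uint32)
for i in range(pack_factor):
    packed_cols |= value[:, i::pack_factor] << (num_bits * i)

# packed_dim == 0: four row values of each column fold into one int32
packed_rows = np.zeros((1, 4), dtype=np.uint32)
for i in range(pack_factor):
    packed_rows |= value[i::pack_factor, :] << (num_bits * i)

print(hex(int(packed_cols[0, 0])))  # 0x3020100: row 0 holds 0, 1, 2, 3, packed low-to-high
print(hex(int(packed_rows[0, 0])))  # 0xc080400: column 0 holds 0, 4, 8, 12, packed low-to-high
```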
@@ -216,17 +269,31 @@ def unpack_from_int32(

    # unpack
    mask = (1 << num_bits) - 1
-    unpacked = torch.zeros(
-        (value.shape[0], value.shape[1] * pack_factor),
-        device=value.device,
-        dtype=torch.int32,
-    )
-    for i in range(pack_factor):
-        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
-
-    # remove padding
-    original_row_size = int(shape[1])
-    unpacked = unpacked[:, :original_row_size]
+
+    if packed_dim == 1:
+        unpacked = torch.zeros(
+            (value.shape[0], value.shape[1] * pack_factor),
+            device=value.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+        # remove padding
+        original_row_size = int(shape[1])
+        unpacked = unpacked[:, :original_row_size]
+    else:
+        unpacked = torch.zeros(
+            (value.shape[0] * pack_factor, value.shape[1]),
+            device=value.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked[i::pack_factor, :] = (value >> (num_bits * i)) & mask
+
+        # remove padding
+        original_row_size = int(shape[0])
+        unpacked = unpacked[:original_row_size, :]

    # bits are packed in unsigned format, reformat to signed
    # update the value range from unsigned to signed
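Finally, a self-contained round-trip sketch (toy data, mirroring but not calling the functions above) showing that unpacking along dim 0 recovers what was packed along dim 0:

```python
import numpy as np
import torch

# Eight unsigned 4-bit values per column fold into one int32 row and are recovered.
num_bits = 4
pack_factor = 32 // num_bits
mask = (1 << num_bits) - 1

original = np.random.randint(0, 2**num_bits, size=(8, 5)).astype(np.uint32)

# pack along dim 0 (as in the packed_dim == 0 branch of pack_to_int32)
packed = np.zeros((1, 5), dtype=np.uint32)
for i in range(pack_factor):
    packed |= original[i::pack_factor, :] << (num_bits * i)
packed = torch.from_numpy(packed.view(np.int32))

# unpack along dim 0 (as in the packed_dim == 0 branch of unpack_from_int32)
unpacked = torch.zeros((pack_factor, 5), dtype=torch.int32)
for i in range(pack_factor):
    unpacked[i::pack_factor, :] = (packed >> (num_bits * i)) & mask

assert torch.equal(unpacked, torch.from_numpy(original.astype(np.int32)))
```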