Commit ed3ac7c

[Compressor] Update packed compressor to support zp packing (#296)

* update packed compressor
* update
* fix packing conditions
* update condition
* update
* Delete src/compressed_tensors/load_weights.py
* clean-up condition; add error message for decompression
* update
* add test, fix condition
* fix dtype

1 parent 3cbe247 · commit ed3ac7c

File tree: 4 files changed, +164 −25 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 1 addition & 0 deletions

@@ -382,6 +382,7 @@ def compress(
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
             )
+
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
                     QuantizationStatus.COMPRESSED

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 16 additions & 3 deletions

@@ -18,7 +18,7 @@

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,

@@ -132,8 +132,10 @@ def compress(
                         compressed_dict[merge_names(prefix, key)] = value
                 else:
                     compressed_dict[name] = value.to("cpu")
-            # only save if asym
-            elif is_weight_zp and quant_args_zp.symmetric:
+            # only save zp if asym and not packed zp
+            elif is_weight_zp and (
+                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
+            ):
                 continue
             # only save if asym
             elif is_input_zp and input_args_zp.symmetric:

@@ -145,6 +147,17 @@ def compress(

         return compressed_dict

+    def _check_if_zp_pack_quantized(self, quant_args):
+        from compressed_tensors.compressors import PackedQuantizationCompressor
+
+        if isinstance(self, PackedQuantizationCompressor):
+            if not quant_args.symmetric and quant_args.strategy in [
+                QuantizationStrategy.GROUP.value,
+                QuantizationStrategy.CHANNEL.value,
+            ]:
+                return True
+        return False
+
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
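
The helper added above returns True exactly when the packed compressor emits a packed weight_zero_point itself (asymmetric group or channel quantization), so the raw zero point in the state dict can be skipped just as a symmetric one is. A minimal self-contained sketch of that predicate, using a hypothetical ZPArgs stand-in for QuantizationArgs that models only the two fields involved (the real helper additionally requires the compressor to be the packed one):

from dataclasses import dataclass

@dataclass
class ZPArgs:
    symmetric: bool
    strategy: str  # e.g. "tensor", "channel", "group"

def zp_is_pack_quantized(args: ZPArgs) -> bool:
    # Pack (and therefore keep) the weight zero point only for asymmetric
    # group/channel quantization; symmetric schemes skip it entirely.
    return not args.symmetric and args.strategy in ("group", "channel")

assert zp_is_pack_quantized(ZPArgs(symmetric=False, strategy="group"))
assert not zp_is_pack_quantized(ZPArgs(symmetric=True, strategy="group"))
assert not zp_is_pack_quantized(ZPArgs(symmetric=False, strategy="tensor"))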

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 88 additions & 21 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Dict, Optional, Tuple
+from typing import Dict, Literal, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -21,7 +21,7 @@
     BaseQuantizationCompressor,
 )
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
 from torch import Tensor

@@ -65,10 +65,26 @@ def compression_param_info(
         """
         pack_factor = 32 // quantization_args.num_bits
         packed_size = math.ceil(weight_shape[1] / pack_factor)
-        return {
+        packed_size_zp = math.ceil(weight_shape[0] / pack_factor)
+        output = {
             "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
             "weight_shape": (torch.Size((2,)), torch.int32),
         }
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            zp_factor = (
+                quantization_args.group_size
+                if quantization_args.strategy == QuantizationStrategy.GROUP.value
+                else weight_shape[-1]
+            )
+
+            output["weight_zero_point"] = (
+                torch.Size((packed_size_zp, weight_shape[-1] // zp_factor)),
+                torch.int32,
+            )
+        return output

     def compress_weight(
         self,
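
As a concrete check of the new weight_zero_point entry, the shape arithmetic for an assumed 1024x1024 weight quantized to 4 bits works out as follows: rows are packed eight-per-int32, and the column count is the number of groups (or 1 for channel-wise quantization):

import math

rows, cols = 1024, 1024                      # hypothetical weight_shape
num_bits, group_size = 4, 128
pack_factor = 32 // num_bits                 # 8 zero points per int32
packed_size_zp = math.ceil(rows / pack_factor)   # 128 packed rows

zp_factor_group = group_size                 # group: one zp per group of columns
zp_factor_channel = cols                     # channel: one zp per row
print((packed_size_zp, cols // zp_factor_group))    # -> (128, 8)
print((packed_size_zp, cols // zp_factor_channel))  # -> (128, 1)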

@@ -104,6 +120,7 @@ def compress_weight(
             quantized_weight = weight

         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+
         weight_shape = torch.tensor(weight.shape)
         if device is not None:
             packed_weight = packed_weight.to(device)

@@ -112,6 +129,15 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight

+        # We typically don't compress zp; apart from when using the packed_compressor and when storing group/channel zp
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            packed_zp = pack_to_int32(
+                zero_point, quantization_args.num_bits, packed_dim=0
+            )
+            compressed_dict["weight_zero_point"] = packed_zp
         return compressed_dict

     def decompress_weight(
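
A usage sketch of the packing call added above. The import path and the precondition that the zero point is already an int8 tensor in the signed 4-bit range are assumptions, not shown in this diff:

import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
)

# 1024 output channels, 8 groups -> one 4-bit zero point per (row, group)
zero_point = torch.randint(-8, 8, (1024, 8), dtype=torch.int8)
packed_zp = pack_to_int32(zero_point, num_bits=4, packed_dim=0)
print(packed_zp.shape, packed_zp.dtype)  # torch.Size([128, 8]) torch.int32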

@@ -133,14 +159,33 @@ def decompress_weight(
         original_shape = torch.Size(compressed_data["weight_shape"])
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
+
+        # NOTE: this will fail decompression as we don't currently handle packed zp on decompression
+        if not quantization_args.symmetric and quantization_args.strategy in [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]:
+            raise ValueError(
+                "Decompression of packed zero points is currently not supported"
+            )
+            assert zero_point is not None
+            original_zp_shape = (original_shape[0], scale.shape[-1])
+            zero_point = unpack_from_int32(
+                zero_point, num_bits, original_zp_shape, packed_dim=0
+            )
+
         decompressed_weight = dequantize(
             x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )

         return decompressed_weight


-def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
+def pack_to_int32(
+    value: torch.Tensor,
+    num_bits: int,
+    packed_dim: Union[Literal[0], Literal[1]] = 1,
+) -> torch.Tensor:
     """
     Packs a tensor of quantized weights stored in int8 into int32s with padding


@@ -176,22 +221,30 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
     pack_factor = 32 // num_bits

     # pad input tensor and initialize packed output
-    packed_size = math.ceil(value.shape[1] / pack_factor)
-    padding = packed_size * pack_factor - value.shape[1]
+    packed_size = math.ceil(value.shape[packed_dim] / pack_factor)
+    padding = packed_size * pack_factor - value.shape[packed_dim]
     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

     # pack values
-    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-    for i in range(pack_factor):
-        packed |= value[:, i::pack_factor] << num_bits * i
+    if packed_dim == 1:
+        packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+        for i in range(pack_factor):
+            packed |= value[:, i::pack_factor] << num_bits * i
+    else:
+        packed = np.zeros((packed_size, value.shape[1]), dtype=np.uint32)
+        for i in range(pack_factor):
+            packed |= value[i::pack_factor, :] << num_bits * i

     # convert back to signed and torch
     packed = np.ascontiguousarray(packed).view(np.int32)
     return torch.from_numpy(packed)


 def unpack_from_int32(
-    value: torch.Tensor, num_bits: int, shape: torch.Size
+    value: torch.Tensor,
+    num_bits: int,
+    shape: torch.Size,
+    packed_dim: Union[Literal[0], Literal[1]] = 1,
 ) -> torch.Tensor:
     """
     Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
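
The packing loop itself is easiest to see on a toy example. The following self-contained snippet (plain Python, not the library function) shows eight 4-bit values being OR-ed into one 32-bit word; packed_dim only decides whether that happens along rows or columns:

num_bits = 4
pack_factor = 32 // num_bits        # 8 values per int32
vals = list(range(8))               # eight 4-bit values: 0..7

packed = 0
for i in range(pack_factor):
    packed |= vals[i] << (num_bits * i)   # value i occupies bits [4*i, 4*i + 4)

print(hex(packed))                  # 0x76543210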

@@ -216,17 +269,31 @@ def unpack_from_int32(

     # unpack
     mask = (1 << num_bits) - 1
-    unpacked = torch.zeros(
-        (value.shape[0], value.shape[1] * pack_factor),
-        device=value.device,
-        dtype=torch.int32,
-    )
-    for i in range(pack_factor):
-        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
-
-    # remove padding
-    original_row_size = int(shape[1])
-    unpacked = unpacked[:, :original_row_size]
+
+    if packed_dim == 1:
+        unpacked = torch.zeros(
+            (value.shape[0], value.shape[1] * pack_factor),
+            device=value.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+        # remove padding
+        original_row_size = int(shape[1])
+        unpacked = unpacked[:, :original_row_size]
+    else:
+        unpacked = torch.zeros(
+            (value.shape[0] * pack_factor, value.shape[1]),
+            device=value.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked[i::pack_factor, :] = (value >> (num_bits * i)) & mask
+
+        # remove padding
+        original_row_size = int(shape[0])
+        unpacked = unpacked[:original_row_size, :]

     # bits are packed in unsigned format, reformat to signed
     # update the value range from unsigned to signed
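
A round-trip sketch of the new packed_dim=0 path. The import path and the final cast to int8 are assumptions, since the signed/unsigned conversion at the end of unpack_from_int32 is not shown in this hunk:

import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

original = torch.randint(-8, 8, (64, 16), dtype=torch.int8)  # 4-bit values, rows divisible by 8
packed = pack_to_int32(original, num_bits=4, packed_dim=0)   # -> (8, 16) int32
restored = unpack_from_int32(packed, 4, original.shape, packed_dim=0)
assert torch.equal(restored.to(torch.int8), original)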

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 59 additions & 1 deletion

@@ -29,6 +29,7 @@
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
+    QuantizationStrategy,
     apply_quantization_config,
 )
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize

@@ -76,7 +77,7 @@ def test_quant_format(shape):
     dense_state_dict = {
         "dummy.weight": torch.rand(shape),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
-        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32),
+        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int8),
     }
     quant_config = get_dummy_quant_config()

@@ -203,6 +204,63 @@ def test_reload_match(tmp_path, num_bits):
     shutil.rmtree(tmp_path)


+@pytest.mark.parametrize(
+    "strategy",
+    {QuantizationStrategy.GROUP, QuantizationStrategy.CHANNEL},
+)
+def test_asymmetric_packed_support(strategy):
+    shape = (1024, 1024)
+
+    group_size = None
+    if strategy == QuantizationStrategy.GROUP:
+        group_size = 128
+
+    if strategy == QuantizationStrategy.CHANNEL:
+        expected_shape = (shape[0], 1)
+    elif strategy == QuantizationStrategy.GROUP:
+        num_groups = shape[1] // group_size
+        expected_shape = (shape[0], max(num_groups, 1))
+
+    dense_state_dict = {
+        "dummy.weight": torch.rand(shape),
+        "dummy.weight_scale": torch.rand(expected_shape).to(torch.float32),
+        "dummy.weight_zero_point": torch.rand(expected_shape).to(torch.int8),
+    }
+
+    quant_config = get_dummy_quant_config(
+        strategy=strategy.value, symmetric=False, group_size=group_size
+    )
+
+    compressor = PackedQuantizationCompressor(config=quant_config)
+    quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
+    compressed_state_dict = compressor.compress(
+        dense_state_dict, names_to_scheme=quantized_modules_to_args
+    )
+
+    # compressed state_dict adds one entry for shape
+    assert len(dense_state_dict) + 1 == len(compressed_state_dict)
+    assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32
+    assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32
+    assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32
+
+    # check weight compressed and packed
+    expected_rows = shape[0]
+    expected_columns = math.ceil(shape[1] / 8)  # round each row up to nearest int32
+    assert compressed_state_dict["dummy.weight_packed"].shape == (
+        expected_rows,
+        expected_columns,
+    )
+    assert torch.equal(compressed_state_dict["dummy.weight_shape"], torch.tensor(shape))
+
+    # check zp compressed and packed
+    packed_size_zp = math.ceil(shape[0] / 8)
+    zp_factor = group_size if strategy == QuantizationStrategy.GROUP else shape[-1]
+    assert compressed_state_dict["dummy.weight_zero_point"].shape == (
+        packed_size_zp,
+        shape[-1] // zp_factor,
+    )
+
+
 @pytest.mark.parametrize(
     "actorder",
     [
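
The asserted shapes follow from the same packing arithmetic; assuming the dummy config quantizes weights to 4 bits (the pack factor of 8 used in the test implies this), the group case works out to:

import math

shape = (1024, 1024)
pack_factor = 8                    # 32 bits / 4-bit dummy weights
group_size = 128

weight_packed_shape = (shape[0], math.ceil(shape[1] / pack_factor))            # (1024, 128)
zp_packed_shape = (math.ceil(shape[0] / pack_factor), shape[1] // group_size)  # (128, 8)
print(weight_packed_shape, zp_packed_shape)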
