Commit b9c536d
horheynm and mgoin authored
Use faster operations on packed-quantized, add tests (#211)
* bitwise op to make it faster, add tests
* Update tests/test_compressors/quantized_compressors/test_pack_quant.py
  Co-authored-by: Michael Goin <michael@neuralmagic.com>
* rahul-tuli comment
* comments

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent 890608d commit b9c536d

File tree

2 files changed: +185 -6 lines changed


src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 23 additions & 6 deletions
@@ -138,8 +138,20 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
     """
     Packs a tensor of quantized weights stored in int8 into int32s with padding
 
+    Pseudocode:
+        1. Shift wrt num_bits to convert to unsigned. num_bits=8
+           [1, 2] -> [129, 130]
+        2. Pad to fill in 32 bits
+           [129, 130] -> [129, 130, 0, 0]
+        3. Convert to binary, aligned in order
+           [129, 130, 0, 0] -> 00000000 00000000 10000010 10000001
+        4. Convert the aligned binary to a number
+           00000000000000001000001010000001 -> 33409
+        5. Convert back to uint32
+           33409 -> 33409
+
     :param value: tensor to pack
-    :param num_bits: number of bits used to store underlying data
+    :param num_bits: number of bits used to store underlying data, must be at least 1
     :returns: packed int32 tensor
     """
     if value.dtype is not torch.int8:
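
The docstring's five steps can be reproduced outside the library. A minimal standalone sketch, assuming only NumPy, for num_bits=8 and input [1, 2]; it mirrors the function but is not the repo code itself:

import numpy as np

# step 1: offset by 2**(num_bits - 1) to move int8 values into unsigned range
num_bits = 8
offset = 1 << (num_bits - 1)                                              # 128
unsigned = (np.array([1, 2], dtype=np.int16) + offset).astype(np.uint32)  # [129, 130]

# step 2: pad up to a whole 32-bit word
pack_factor = 32 // num_bits                                  # 4 values per int32
padded = np.pad(unsigned, (0, pack_factor - unsigned.size))   # [129, 130, 0, 0]

# steps 3-5: OR each value into its bit position within the word
packed = np.uint32(0)
for i, v in enumerate(padded):
    packed |= v << np.uint32(num_bits * i)

print(int(packed))  # 33409 == 0b00000000_00000000_10000010_10000001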
@@ -148,19 +160,22 @@ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
     if num_bits > 8:
         raise ValueError("Packing is only supported for less than 8 bits")
 
+    if num_bits < 1:
+        raise ValueError(f"num_bits must be at least 1, got {num_bits}")
+
     # convert to unsigned for packing
-    offset = pow(2, num_bits) // 2
+    offset = 1 << (num_bits - 1)
     value = (value + offset).to(torch.uint8)
     value = value.cpu().numpy().astype(np.uint32)
     pack_factor = 32 // num_bits
 
     # pad input tensor and initialize packed output
     packed_size = math.ceil(value.shape[1] / pack_factor)
-    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
-    padding = packed.shape[1] * pack_factor - value.shape[1]
+    padding = packed_size * pack_factor - value.shape[1]
     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
 
     # pack values
+    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
     for i in range(pack_factor):
         packed |= value[:, i::pack_factor] << num_bits * i
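
Both arithmetic rewrites in this commit are drop-in equivalences: for every supported width, the shift form equals the pow() form it replaces. A quick illustrative check, not part of the commit:

for num_bits in range(1, 9):
    assert 1 << (num_bits - 1) == pow(2, num_bits) // 2   # unsigned offset above
    assert (1 << num_bits) - 1 == pow(2, num_bits) - 1    # unpack mask below

Moving the packed = np.zeros(...) allocation after the padding step also lets the padding width come straight from packed_size instead of the allocated array's shape.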

@@ -174,7 +189,9 @@ def unpack_from_int32(
 ) -> torch.Tensor:
     """
     Unpacks a tensor of packed int32 weights into individual int8s, maintaining the
-    original their bit range
+    original bit range.
+
+    Return tensors in int8
 
     :param value: tensor to unpack
     :param num_bits: number of bits to unpack each data point into
@@ -192,7 +209,7 @@ def unpack_from_int32(
     pack_factor = 32 // num_bits
 
     # unpack
-    mask = pow(2, num_bits) - 1
+    mask = (1 << num_bits) - 1
     unpacked = torch.zeros(
         (value.shape[0], value.shape[1] * pack_factor),
         device=value.device,
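
Unpacking inverts the packing loop: each field is shifted down, masked, and re-centered into the signed range. A hypothetical walk-through on the word 33409 from the packing docstring, written as plain Python rather than the library's tensor code:

num_bits = 8
mask = (1 << num_bits) - 1                        # 0xFF, matching the diff above
packed = 33409
fields = [(packed >> (num_bits * i)) & mask for i in range(32 // num_bits)]
signed = [f - (1 << (num_bits - 1)) for f in fields]
print(signed)  # [1, 2, -128, -128]

The trailing padding decodes to -128, which is why unpack_from_int32 also takes the original shape (see the tests below) to trim the result.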

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 162 additions & 0 deletions
@@ -250,3 +250,165 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
     assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])
 
     shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "num_bits,values,expected_values",
+    [
+        (
+            4,
+            torch.tensor([[1]]),
+            torch.tensor([[9]], dtype=torch.int32),
+        ),
+        (
+            8,
+            torch.tensor([[1]]),
+            torch.tensor([[129]], dtype=torch.int32),
+        ),
+        # 0000 0000 0000 0000 1100 1011 1010 1001
+        (4, torch.tensor([[1, 2, 3, 4]]), torch.tensor([[52137]], dtype=torch.int32)),
+        # 0111 0110 0101 0100 0011 0010 0001 0000
+        (
+            4,
+            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]]),
+            torch.tensor([[1985229328]], dtype=torch.int32),
+        ),
+        # 10000100 10000011 10000010 10000001
+        (
+            8,
+            torch.tensor([[1, 2, 3, 4]]),
+            torch.tensor([[-2071756159]], dtype=torch.int32),
+        ),
+        # 00000011 00000010 00000001 00000000
+        (
+            8,
+            torch.tensor([[-128, -127, -126, -125]]),
+            torch.tensor([[50462976]], dtype=torch.int32),
+        ),
+        (
+            4,
+            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]]),
+            torch.tensor([[1985229328, 52137]], dtype=torch.int32),
+        ),
+        (
+            4,
+            torch.tensor(
+                [
+                    [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
+                    [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
+                ]
+            ),
+            torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
+        ),
+        (
+            8,
+            torch.tensor(
+                [
+                    [1, 2, 3, 4],
+                    [-128, -127, -126, -125],
+                ]
+            ),
+            torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
+        ),
+        (
+            8,
+            torch.tensor(
+                [
+                    [1, 2, 3, 4, -128, -127, -126, -125],
+                    [-128, -127, -126, -125, 1, 2, 3, 4],
+                ]
+            ),
+            torch.tensor(
+                [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
+            ),
+        ),
+    ],
+)
+def test_pack_to_int32(num_bits, values, expected_values):
+    values = values.to(torch.int8)
+    packed_values = pack_to_int32(values, num_bits)
+    assert torch.equal(packed_values, expected_values)
+    assert packed_values.dtype == expected_values.dtype
+
+
+@pytest.mark.parametrize(
+    "num_bits,values,expected_tensor",
+    [
+        (
+            4,
+            torch.tensor([[9]], dtype=torch.int32),
+            torch.tensor([[1]], dtype=torch.int8),
+        ),
+        (
+            8,
+            torch.tensor([[129]], dtype=torch.int32),
+            torch.tensor([[1]], dtype=torch.int8),
+        ),
+        (
+            4,
+            torch.tensor([[52137]], dtype=torch.int32),
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
+        ),
+        (
+            4,
+            torch.tensor([[1985229328]], dtype=torch.int32),
+            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], dtype=torch.int8),
+        ),
+        (
+            8,
+            torch.tensor([[-2071756159]], dtype=torch.int32),
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
+        ),
+        (
+            8,
+            torch.tensor([[50462976]], dtype=torch.int32),
+            torch.tensor([[-128, -127, -126, -125]], dtype=torch.int8),
+        ),
+        (
+            4,
+            torch.tensor([[1985229328, 52137]], dtype=torch.int32),
+            torch.tensor(
+                [[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]], dtype=torch.int8
+            ),
+        ),
+        (
+            4,
+            torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
+            torch.tensor(
+                [
+                    [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
+                    [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
+                ],
+                dtype=torch.int8,
+            ),
+        ),
+        (
+            8,
+            torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
+            torch.tensor(
+                [
+                    [1, 2, 3, 4],
+                    [-128, -127, -126, -125],
+                ],
+                dtype=torch.int8,
+            ),
+        ),
+        (
+            8,
+            torch.tensor(
+                [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
+            ),
+            torch.tensor(
+                [
+                    [1, 2, 3, 4, -128, -127, -126, -125],
+                    [-128, -127, -126, -125, 1, 2, 3, 4],
+                ],
+                dtype=torch.int8,
+            ),
+        ),
+    ],
+)
+def test_unpack_from_int32(num_bits, values, expected_tensor):
+    unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
+    assert torch.equal(unpacked_tensor, expected_tensor)
+    assert unpacked_tensor.dtype == expected_tensor.dtype
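
The expected words in these tables can be checked by hand. For example, the 4-bit case [1, 2, 3, 4] -> 52137, matching the binary comment 1100 1011 1010 1001 above that case (plain arithmetic, not repo code):

nibbles = [v + 8 for v in [1, 2, 3, 4]]                # offset by 2**3 -> [9, 10, 11, 12]
word = sum(n << (4 * i) for i, n in enumerate(nibbles))
print(word, hex(word))                                 # 52137 0xcba9

The unpack table is the same data with inputs and outputs swapped, so the two tests exercise pack_to_int32 and unpack_from_int32 as inverses.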
