Commit 94853d4

Fix GPTQQuantizer test (#53)

Summary: att

Test Plan: python test/quantization/test_quant_api.py -k test_gptq

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]

1 parent 9765bc6 commit 94853d4

File tree: 3 files changed, +58 −12 lines changed


test/quantization/test_quant_api.py (1 addition, 2 deletions)

@@ -130,11 +130,10 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)

-    @unittest.skip("skipping for now and will fix in next PR")
     def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
-        device = "cuda"
+        device = "cpu"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
         model = Transformer.from_name(checkpoint_path.parent.name)
         checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
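
The unskipped test now runs on CPU and loads the checkpoint with mmap=True and weights_only=True. A minimal, self-contained sketch of that loading pattern, with a toy module and a made-up path standing in for the Llama-2 checkpoint (assumes a PyTorch recent enough for mmap support in torch.load):

import torch
import torch.nn as nn

# Toy stand-in for the real checkpoint file.
model = nn.Linear(8, 8)
torch.save(model.state_dict(), "/tmp/toy_model.pth")

# mmap=True maps the file into memory lazily rather than reading it eagerly,
# which keeps peak host memory low for large checkpoints; weights_only=True
# restricts unpickling to tensors and other safe types.
checkpoint = torch.load("/tmp/toy_model.pth", mmap=True, weights_only=True)
model.load_state_dict(checkpoint)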

torchao/quantization/quant_api.py (56 additions, 10 deletions)
@@ -16,6 +16,8 @@
 """

 import torch
+import torch.nn.functional as F
+import torch.nn as nn
 from .dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
@@ -28,7 +30,11 @@
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
 )
-from typing import Dict
+from .quant_primitives import (
+    get_group_qparams_symmetric,
+    per_token_dynamic_quant,
+)
+from typing import Dict, Tuple

 __all__ = [
     "apply_weight_only_int8_quant",
@@ -382,7 +388,7 @@ def quantize(
             self.pad_calibration_inputs,
         )
         model = self._convert_for_runtime(model)
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=False)
         return model

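The load is made non-strict, presumably because the converted runtime module and the captured GPTQ state dict no longer declare exactly the same key set (the names-and-values dict further down now emits separate scales and zeros entries). A toy illustration of what strict=False tolerates, not the quantizer itself:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
sd = m.state_dict()
sd["scales"] = torch.ones(4)  # hypothetical extra key, standing in for the quantizer's

# strict=True would raise on the unexpected key; strict=False reports it instead.
result = m.load_state_dict(sd, strict=False)
print(result.unexpected_keys)  # ['scales']
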
@@ -465,11 +471,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.precision,
         )

+from math import gcd
+from functools import reduce
+
+
+def find_multiple(n: int, *args: Tuple[int]) -> int:
+    # TODO: this change is reverted right now in gpt-fast
+    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+

 def _check_linear_int4_k(k, group_size=1):
     return k % group_size == 0


+def _calc_padded_size_linear_int4(k, groupsize=1):
+    return find_multiple(k, groupsize)
+
+
+def pack_scales_and_zeros(scales, zeros, precision=torch.float32):
+    assert scales.shape == zeros.shape
+    assert scales.dtype == precision
+    assert zeros.dtype == precision
+    return (
+        torch.cat(
+            [
+                scales.reshape(scales.size(0), scales.size(1), 1),
+                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+            ],
+            2,
+        )
+        .transpose(0, 1)
+        .contiguous()
+    )
+
+
+def unpack_scales_and_zeros(scales_and_zeros):
+    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
+    assert scales_and_zeros.dtype == torch.float
+    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+
+
 def replace_linear_8da4w(
     module,
     group_size,
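
Taken together, these new helpers round sizes up to a group boundary and (de)interleave per-group quantization parameters. A quick standalone check of their behavior, assuming the gpt-fast convention of scales and zeros as (out_features, n_groups) tensors (this sketch is illustrative, not part of the commit):

import torch
from math import gcd
from functools import reduce

def find_multiple(n, *args):
    # Smallest multiple of lcm(args) that is >= n, as in the diff above.
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    return n if n % k == 0 else n + k - (n % k)

print(find_multiple(10, 4))     # 12: next multiple of 4 at or above 10
print(find_multiple(24, 4, 6))  # 24: already a multiple of lcm(4, 6) = 12

# Pack/unpack round trip: packing stacks scales and zeros on a trailing
# dim of size 2 and transposes to (n_groups, out_features, 2).
out_features, n_groups = 4, 2
scales = torch.rand(out_features, n_groups)
zeros = torch.zeros(out_features, n_groups)
packed = (
    torch.cat([scales.unsqueeze(-1), zeros.unsqueeze(-1)], 2)
    .transpose(0, 1)
    .contiguous()
)
s, z = torch.split(packed.transpose(0, 1), 1, 2)
assert torch.equal(s.squeeze(-1), scales) and torch.equal(z.squeeze(-1), zeros)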
@@ -554,25 +598,27 @@ def __init__(
         ]
         # skip unless padding_allowed=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
+            _check_linear_int4_k(linear_weight.shape[-1], groupsize)
             or padding_allowed
         )

         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
+            new_k = _calc_padded_size_linear_int4(k, groupsize)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = F.pad(q, pad=(0, delta_k))
-            scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
+            scales = qparams[0].to(self.precision)
+            zeros = qparams[1].to(self.precision)
+            # scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
             # how many new groups we need for padded weight
-            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
+            # delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
             # TODO: split scales and zero_points
-            final_s_and_z = F.pad(
-                scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-            )
-            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
+            # final_s_and_z = F.pad(
+            #     scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
+            # )
+            return {"weight": final_q, "scales": scales, "zeros": zeros}

         self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()
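
The key behavioral change in this hunk: instead of packing scales and zeros into a single scales_and_zeros tensor (and padding it), the returned dict now carries separate scales and zeros entries, and only the weight is padded. A toy sketch of that weight-padding step, with made-up shapes:

import torch
import torch.nn.functional as F

groupsize = 8
q = torch.randint(0, 16, (4, 10))  # k = 10 is not a multiple of groupsize

# Round k up to the next multiple of groupsize, then zero-pad the last dim
# on the right, mirroring what make_names_and_values_dict_func does for final_q.
new_k = -(-q.shape[1] // groupsize) * groupsize  # 16
final_q = F.pad(q, pad=(0, new_k - q.shape[1]))
assert final_q.shape == (4, 16)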

torchao/quantization/quant_primitives.py (1 addition, 0 deletions)

@@ -12,6 +12,7 @@
     quantized_decomposed_lib,
 )
 from torch.library import impl
+from typing import Tuple

 __all__ = [
     "safe_int_mm",
