Skip to content

Commit 00417b8

Browse files
xiaowangintel authored and pytorchmergebot committed
Align scale dtype with model precision in GPTQ (#2403)
Summary: For general usage, align the data type of scale with the model precision instead of the default use of bfloat16. Pull Request resolved: #2403. Approved by: https://github.com/liangan1, https://github.com/jerryzh168
1 parent 2025b75 commit 00417b8

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

torchao/quantization/GPTQ/GPTQ.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __torch_function__(
295295
SQNR(DQ, DQ_from_qtensor),
296296
)
297297

298-
qparams2 = cls.get_qparams_func(W)
298+
qparams2 = cls.get_qparams_func(W, W.dtype)
299299
Q2 = cls.quantize_func(W, qparams2)
300300
DQ2 = cls.dequantize_func(Q2, qparams2).to(W.dtype)
301301
old_q_out = (
@@ -444,7 +444,9 @@ def faster_quant(cls, H, W, device):
444444
group_end = min(group_start + group_size, columns)
445445
if group_start % group_size == 0:
446446
# needed for when group_size == columns so only calculate qparams once
447-
cur_qparams = cls.get_qparams_func(W[:, group_start:group_end])
447+
cur_qparams = cls.get_qparams_func(
448+
W[:, group_start:group_end], orig_dtype
449+
)
448450
all_qparams.append(cur_qparams)
449451

450452
for index in range(group_start, group_end): # within each group
@@ -679,10 +681,11 @@ def __init__(
679681
else:
680682
self.zero_point_domain = ZeroPointDomain.FLOAT
681683

682-
self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
684+
self.get_qparams_func = lambda w, precision: get_groupwise_affine_qparams(
683685
w,
684686
n_bit,
685687
group_size,
688+
dtype=precision,
686689
zero_point_domain=self.zero_point_domain,
687690
)
688691
self.quantize_func = (

0 commit comments

Comments (0)