
Commit a8704f8

[test] get_group_qparams_symmetric matches observer (#94)
1 parent 8c62eb0 commit a8704f8

File tree

2 files changed: +51 -0 lines
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: ignore-errors
+# This test takes a long time to run
+import unittest
+import torch
+from torchao.quantization.quant_primitives import get_group_qparams_symmetric
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+
+class TestQuantPrimitives(unittest.TestCase):
+    SEED = 123
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.3 or lower")
+    def test_get_group_qparams_symmetric(self):
+        """
+        Test that `get_group_qparams_symmetric` produces the exact same scales as
+        `PerChannelMinMaxObserver._calculate_qparams`.
+        """
+        n_bit = 4
+        qmin = -(2 ** (n_bit - 1))
+        qmax = 2 ** (n_bit - 1) - 1
+        eps = torch.finfo(torch.float32).eps
+        groupsize = 256
+        torch.manual_seed(self.SEED)
+        weight = torch.randn(100, 256).to(torch.float16)
+
+        # calculate observer scales
+        obs = torch.ao.quantization.PerChannelMinMaxObserver(
+            ch_axis=0,
+            qscheme=torch.per_channel_symmetric,
+            quant_min=qmin,
+            quant_max=qmax,
+            # This is needed to ensure `min_val` and `max_val` are fp16,
+            # otherwise they default to fp32 and the qparams will be slightly off
+            factory_kwargs={"dtype": torch.float16}
+        )
+        obs(weight)
+        (scale_obs, _) = obs.calculate_qparams()
+        scale_obs = scale_obs.reshape(weight.shape[0], -1)
+
+        # assert that scales are identical
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
+        torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
+
+if __name__ == "__main__":
+    unittest.main()
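
For reference, the equality this test asserts reduces to the standard symmetric quantization scale computed per group of `groupsize` columns: scale = max(|min_val_neg|, max_val_pos) / ((qmax - qmin) / 2). The sketch below is not part of the commit; it is a minimal reimplementation of that formula under stated assumptions (the helper name `reference_group_scales` and the placement and dtype of the eps clamp are not taken from the diff) to make the shapes and math the two code paths must agree on explicit.

# Minimal sketch (not from the commit): the per-group symmetric scale math
# that both `get_group_qparams_symmetric` and `PerChannelMinMaxObserver`
# are expected to agree on for a (100, 256) fp16 weight with groupsize=256.
import torch

def reference_group_scales(w, n_bit=4, groupsize=256, eps=torch.finfo(torch.float16).eps):
    # hypothetical helper, not part of torchao
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    grouped = w.reshape(-1, groupsize)                # one row per group
    min_neg = grouped.min(dim=1).values.clamp(max=0)  # negative part of the range
    max_pos = grouped.max(dim=1).values.clamp(min=0)  # positive part of the range
    scale = torch.max(-min_neg, max_pos) / ((qmax - qmin) / 2)
    return scale.clamp(min=eps).reshape(w.shape[0], -1)

torch.manual_seed(123)
weight = torch.randn(100, 256).to(torch.float16)
print(reference_group_scales(weight).shape)  # torch.Size([100, 1])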

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 0 deletions
@@ -470,6 +470,7 @@ def groupwise_affine_dequantize_tensor(
     )
 
 
+# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver
 def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
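
Going only by how the new test calls it, a hedged usage sketch of the function touched by this hunk could look like the following; treating the second return value as per-group zero-points is an assumption, since the diff only shows the signature and the TODO.

# Usage sketch mirroring the call in the test above; the meaning of the
# second return value (assumed here to be per-group zero-points) is not
# confirmed by this diff.
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric

weight = torch.randn(100, 256).to(torch.float16)
scales, zeros = get_group_qparams_symmetric(weight, n_bit=4, groupsize=256)
print(scales.shape)  # expected (100, 1): one scale per group of 256 columns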
