
Commit 45b39b1

Set eps in end-to-end QAT flow (#2180)

* Set eps in end-to-end QAT flow

**Summary:** This commit does two things:

(1) Allow users to set eps in `FakeQuantizeConfig`

(2) For other parts of the QAT flow, set eps to `torch.finfo(torch.float32).eps` for input linear activations, to match the existing hardcoded input activation scale dtype (which is fp32)

The motivation is to support users who wish to lower their models to XNNPACK. That requires the following combination of dtypes during training for an end-to-end numerical match:

- input activations: bf16
- input activation scales: fp32
- input activation eps: `torch.finfo(torch.float32).eps`
- weight: bf16
- weight scales: bf16
- weight eps: `torch.finfo(torch.bfloat16).eps`

However, today there is no way to specify the above in any of the QAT flows. For the recommended `FakeQuantizeConfig` flow, we always use `torch.finfo(x.dtype).eps`, where x is bf16 in this case, and there is no way for users to configure this. This is resolved by (1).

For the legacy `Int8DynActInt4WeightQATQuantizer` flow, we hardcoded input activation scales to always use fp32 in #2085, but did not set the corresponding eps. Today, this flow also uses `torch.finfo(x.dtype).eps` by default, where x is bf16, so it uses the wrong eps value. This is resolved by (2).

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_config_eps
python test/quantization/test_qat.py -k test_qat_8da4w_eps

* up

---------

Co-authored-by: Scott Roy <161522778+metascroy@users.noreply.github.com>
1 parent b95cf18 commit 45b39b1
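For readers targeting the XNNPACK lowering described above, here is a minimal sketch of how the new `eps` argument expresses that dtype combination in the recommended `FakeQuantizeConfig` flow. The int4 weight dtype, `group_size=32`, and the split into two standalone configs are illustrative assumptions rather than part of this commit, and sub-byte dtypes such as `torch.int4` require a recent PyTorch.

```python
import torch
from torchao.quantization.qat import FakeQuantizeConfig

# Input activations: int8 per-token with fp32 scales and fp32 eps, matching the
# hardcoded fp32 input activation scale dtype mentioned in the commit message.
activation_config = FakeQuantizeConfig(
    torch.int8,
    "per_token",
    is_symmetric=False,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
    eps=torch.finfo(torch.float32).eps,
)

# Weights: bf16 scales and bf16 eps. The int4 dtype and group_size=32 are an
# illustrative choice, not something this commit prescribes.
weight_config = FakeQuantizeConfig(
    torch.int4,
    group_size=32,
    is_symmetric=True,
    scale_precision=torch.bfloat16,
    zero_point_precision=torch.bfloat16,
    eps=torch.finfo(torch.bfloat16).eps,
)
```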

9 files changed: +107 −6 lines changed

test/quantization/test_qat.py

Lines changed: 78 additions & 0 deletions
@@ -1513,6 +1513,84 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         )
         self.assertEqual(len(non_inf_sqnr), 0, fail_message)

+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantize_config_eps(self):
+        """
+        Test that users can set arbitrary eps value in `FakeQuantizeConfig`.
+        """
+        eps = 0.00123
+        x = torch.randn(2, 3).to(torch.float32)
+        scale, zp = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=(1, 3),
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            eps=eps,
+        )
+        expected_out = _fake_quantize_per_token(x, scale, zp, -128, 127)
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+            eps=eps,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        actual_out = fake_quantizer(x)
+        torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_8da4w_eps(self):
+        """
+        Test that the 8da4w QAT flow uses the expected eps.
+        """
+        from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.utils import per_token_dynamic_quant
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+
+        # prepare
+        prepared_model = quantizer.prepare(m)
+        self.assertEqual(
+            prepared_model.linear1.activation_fake_quantizer.config.eps,
+            torch.finfo(torch.float32).eps,
+        )
+
+        # convert
+        converted_model = quantizer.convert(m)
+        x = m.example_inputs()[0]
+        _input = per_token_dynamic_quant(
+            x,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.float32,
+            eps=torch.finfo(torch.float32).eps,
+        )
+        _weight_dq = dequantize_affine(
+            converted_model.linear1.weight,
+            (1, group_size),
+            converted_model.linear1.scales,
+            converted_model.linear1.zeros,
+            torch.int8,
+            quant_min=-8,
+            quant_max=7,
+            output_dtype=torch.float32,
+        )
+        expected_out = torch.nn.functional.linear(
+            _input,
+            _weight_dq,
+            converted_model.linear1.bias,
+        )
+        actual_out = converted_model.linear1(x)
+        torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()

torchao/experimental/quant_passes.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def _get_q_dq_linear_patterns_replacements_and_filters(
     glbs["a_quant_max"] = None
     glbs["a_mapping_type"] = "ASYMMETRIC"
     glbs["a_scale_dtype"] = torch.float32
-    glbs["a_eps"] = None
+    glbs["a_eps"] = torch.finfo(torch.float32).eps

     lcls = {}

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 1 addition & 1 deletion
@@ -361,7 +361,7 @@ def test_export_QDQLayout(self):
         self.assertTrue(torch.allclose(eager_results, exported_results))

         expected_lines = [
-            "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, None, torch.float32, torch.int8)",
+            "torch.ops.torchao.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int8, None, None, 1.1920928955078125e-07, torch.float32, torch.int8)",
             "torch.ops.torchao.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int8)",
             "torch.ops.torchao.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int8)",
             "torch.ops.torchao.dequantize_affine.default",

torchao/quantization/GPTQ.py

Lines changed: 4 additions & 1 deletion
@@ -938,7 +938,10 @@ def linear_forward_8da4w(
     # TODO: in future add ability to specify activation_scale_dtype to PTQ configs
     # and enable similar change here
     x = per_token_dynamic_quant(
-        x, scale_dtype=torch.float32, zero_point_dtype=torch.float32
+        x,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.float32,
+        eps=torch.finfo(torch.float32).eps,
     )

     # TODO: verify and remove following reshape code

torchao/quantization/qat/api.py

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,7 @@ class FakeQuantizeConfig:
     zero_point_domain: ZeroPointDomain
     is_dynamic: bool = True
     range_learning: bool = False
+    eps: Optional[float] = None

     def __init__(
         self,
@@ -96,6 +97,7 @@ def __init__(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         is_dynamic: bool = True,
         range_learning: bool = False,
+        eps: Optional[float] = None,
         *,
         group_size: Optional[int] = None,
         is_symmetric: Optional[bool] = None,
@@ -110,6 +112,7 @@ def __init__(
         self.zero_point_domain = zero_point_domain
         self.is_dynamic = is_dynamic
         self.range_learning = range_learning
+        self.eps = eps

         # Validate dtype
         all_dtypes = [torch.int8, torch.uint8]
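A minimal usage sketch of the new field (the bf16 input tensor and its shape are illustrative, not from the commit): leaving `eps` as `None` keeps the previous behavior of deriving eps from the input dtype, while setting it pins the minimum scale explicitly.

```python
import torch
from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

# Fake-quantize bf16 activations while clamping scales with fp32 eps, as needed
# for the XNNPACK lowering described in the commit message.
config = FakeQuantizeConfig(
    torch.int8,
    "per_token",
    is_symmetric=False,
    eps=torch.finfo(torch.float32).eps,  # eps=None would fall back to torch.finfo(x.dtype).eps
)
fake_quantizer = FakeQuantizer(config)
x = torch.randn(2, 3, dtype=torch.bfloat16)
y = fake_quantizer(x)
```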

torchao/quantization/qat/fake_quantizer.py

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,7 @@ def _per_token_forward(self, x: torch.Tensor):
             target_dtype=self.config.dtype,
             quant_min=qmin,
             quant_max=qmax,
+            eps=self.config.eps,
             scale_dtype=self.config.scale_precision,
             zero_point_dtype=self.config.zero_point_precision,
         )
@@ -117,13 +118,15 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 bit_width,
                 group_size,
                 scale_precision,
+                eps=self.config.eps,
             )
         else:
             (self.scale, self.zero_point) = get_groupwise_affine_qparams(
                 x,
                 bit_width,
                 group_size,
                 scale_precision,
+                eps=self.config.eps,
             )
         self.zero_point = self.zero_point.to(zero_point_precision)

torchao/quantization/qat/linear.py

Lines changed: 7 additions & 1 deletion
@@ -177,6 +177,8 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
         self.scales_precision: torch.dtype = scales_precision
+        # TODO: generalize this
+        self.activation_scales_precision = torch.float32

     def prepare(
         self, model: torch.nn.Module, *args: Any, **kwargs: Any
@@ -247,7 +249,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 self._convert_qat_linear_8da4w(child)

     def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
-        return _get_8da4w_activation_config(self.scales_precision)
+        return _get_8da4w_activation_config(self.activation_scales_precision)

     def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
         return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
@@ -280,6 +282,7 @@ def __init__(
     ) -> None:
         # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
         # which is used in PTQ routines
+        # TODO: generalize this
         activation_config = _get_8da4w_activation_config(torch.float32)
         weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
         super().__init__(
@@ -320,13 +323,16 @@ def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantize
     """
     Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
     """
+    # TODO: generalize this
+    assert qparams_precision == torch.float32
     return FakeQuantizeConfig(
         dtype=torch.int8,
         granularity="per_token",
         is_symmetric=False,
         is_dynamic=True,
         scale_precision=qparams_precision,
         zero_point_precision=qparams_precision,
+        eps=torch.finfo(qparams_precision).eps,
     )
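As `test_qat_8da4w_eps` above verifies, the legacy quantizer now pairs its fp32 input activation scales with fp32 eps. A rough sketch of what this looks like after `prepare`; the toy module and `groupsize=16` are illustrative stand-ins for the test's `M()`, not part of the commit.

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy stand-in for the test's M(); only the `linear1` attribute name matters here.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear1(x)

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
prepared = quantizer.prepare(TinyModel())

# After this commit, the input activation fake quantizer carries fp32 eps to
# match its (already fp32) scale dtype.
cfg = prepared.linear1.activation_fake_quantizer.config
assert cfg.scale_precision == torch.float32
assert cfg.eps == torch.finfo(torch.float32).eps
```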

torchao/quantization/quant_api.py

Lines changed: 2 additions & 0 deletions
@@ -627,13 +627,15 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     scale_dtype = torch.float32
+    eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int8
     if TORCH_VERSION_AT_LEAST_2_6:
         return to_affine_quantized_intx(
             x,
             mapping_type,
             _get_per_token_block_size(x),
             target_dtype,
+            eps=eps,
             scale_dtype=scale_dtype,
             zero_point_dtype=zero_point_dtype,
         )

torchao/quantization/utils.py

Lines changed: 8 additions & 2 deletions
@@ -324,6 +324,7 @@ def get_groupwise_affine_qparams(
     dtype=torch.bfloat16,
     zero_point_domain=ZeroPointDomain.FLOAT,
     preserve_zero=False,
+    eps=None,
 ):
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
@@ -337,7 +338,8 @@ def get_groupwise_affine_qparams(
     block_size = (1, groupsize)
     quant_min = 0
     quant_max = 2**n_bit - 1
-    eps = 1e-6
+    if eps is None:
+        eps = 1e-6
     scale_dtype = dtype
     zero_point_dtype = (
         dtype if zero_point_domain != ZeroPointDomain.INT else torch.int32
@@ -530,6 +532,7 @@ def get_group_qparams_symmetric(
     groupsize=128,
     precision=torch.float32,
     mapping_type=MappingType.SYMMETRIC,
+    eps=None,
 ):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
@@ -540,7 +543,8 @@ def get_group_qparams_symmetric(
     assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

     block_size = (1, groupsize)
-    eps = torch.finfo(w.dtype).eps
+    if eps is None:
+        eps = torch.finfo(w.dtype).eps
     ranges = {}
     ranges[1] = (-1, 0)
     # generating ranges for bit 2 to 8
@@ -591,6 +595,7 @@ def per_token_dynamic_quant(
     input: torch.Tensor,
     scale_dtype: torch.dtype = torch.float32,
     zero_point_dtype: torch.dtype = torch.float32,
+    eps: Optional[float] = None,
 ) -> torch.Tensor:
     mapping_type = MappingType.ASYMMETRIC
     block_size = _get_per_token_block_size(input)
@@ -608,6 +613,7 @@ def per_token_dynamic_quant(
         quant_max,
         scale_dtype=scale_dtype,
         zero_point_dtype=zero_point_dtype,
+        eps=eps,
     )
     q = quantize_affine(
         input,
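A small sketch of the new `eps` pass-through on `per_token_dynamic_quant` (the input shape and bf16 dtype are illustrative). With `eps=None` the function behaves as before, deriving eps from the input dtype; the explicit fp32 value below matches the 8da4w convert path exercised by `test_qat_8da4w_eps`.

```python
import torch
from torchao.quantization.utils import per_token_dynamic_quant

x = torch.randn(8, 64, dtype=torch.bfloat16)

# Previous behavior: eps is derived from x.dtype (bf16 here).
y_default = per_token_dynamic_quant(x)

# New: fp32 scales/zero points with fp32 eps, matching the converted 8da4w linear.
y_fp32_eps = per_token_dynamic_quant(
    x,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
    eps=torch.finfo(torch.float32).eps,
)
```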
