Commit 1b61c82

Enable range learning for QAT
Differential Revision: D72754131
Pull Request resolved: #2033
1 parent 756d074 commit 1b61c82
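
At a high level, this change lets QAT learn quantization ranges: scales and zero points become trainable parameters instead of being recomputed from tensor statistics on every forward pass. The snippet below is a minimal sketch of the intended flow, distilled from the new `test_qat_range_learning` test in this commit; the toy model, loss, and optimizer settings are illustrative only.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

# Range learning requires static qparams (is_dynamic=False) in float32
config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)

# Illustrative toy model; any module with linear layers works the same way
model = torch.nn.Sequential(torch.nn.Linear(512, 256), torch.nn.Linear(256, 512))
example_inputs = (torch.randn(1, 512),)

# Swap in fake-quantized linears with range learning enabled
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Must run before the optimizer is constructed, so that the learned scales
# and zero points are registered as trainable nn.Parameters
initialize_fake_quantizers(model, example_inputs)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = torch.nn.MSELoss()

out = model(*example_inputs)
loss = loss_fn(out, torch.randn(1, 512))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```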

File tree

8 files changed: +217, -19 lines changed

test/quantization/test_qat.py

Lines changed: 117 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1591,6 +1619,95 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()

third_party/cutlass

Submodule cutlass updated 507 files

torchao/quantization/qat/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -4,6 +4,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -17,11 +18,12 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
-    "Int4WeightOnlyQATQuantizer",
+    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntXQuantizationAwareTrainingConfig",
+    "initialize_fake_quantizers",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "IntXQuantizationAwareTrainingConfig",
 ]
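
With this re-export (and the reordering of the `__all__` entries), the new helper is importable directly from the QAT package namespace:

```python
# New in this commit: exported at the QAT package level
from torchao.quantization.qat import initialize_fake_quantizers
```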

torchao/quantization/qat/api.py

Lines changed: 27 additions & 2 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
         is_dynamic: whether to use dynamic (default) or static scale and zero points
-        range_learning: whether to learn scale and zero points during training (coming soon)
+        range_learning: whether to learn scale and zero points during training
+            (default false), not compatible with `is_dynamic`.
 
     kwargs (optional):
         group_size: size of each group in per group fake quantization,
@@ -123,6 +124,10 @@ def __init__(
                 "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
             )
 
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -394,3 +399,23 @@ def convert(
         for quantizer in self.quantizers:
             model = quantizer.convert(model)
         return model
+
+
+def initialize_fake_quantizers(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...],
+) -> None:
+    """
+    Initialize the scales and zero points on all
+    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    in the model based on the provided example inputs.
+    """
+    # avoid circular dependencies
+    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+
+    def _set_initialized(m: torch.nn.Module):
+        if isinstance(m, FakeQuantizer):
+            m._initialized = True
+
+    model.apply(_set_initialized)
+    model(*example_inputs)
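
`initialize_fake_quantizers` works in two steps: `model.apply(_set_initialized)` flips the `_initialized` flag on every `FakeQuantizer` submodule, and the forward pass with the example inputs then computes the initial scales and zero points, which the fake quantizers promote to trainable `torch.nn.Parameter`s when range learning is enabled. The snippet below mirrors the new `test_fake_quantizer_range_learning` test to show the guard it satisfies (tensor shapes are illustrative):

```python
import torch

from torchao.quantization.qat import FakeQuantizeConfig, initialize_fake_quantizers
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
fake_quantizer = FakeQuantizer(config)

try:
    fake_quantizer(torch.randn(2, 3))  # forward before initialization
except ValueError as e:
    print(e)  # error message points the user at initialize_fake_quantizers

# After initialization, the qparams exist as trainable parameters
initialize_fake_quantizers(fake_quantizer, (torch.randn(2, 3),))
assert isinstance(fake_quantizer.scale, torch.nn.Parameter)
assert isinstance(fake_quantizer.zero_point, torch.nn.Parameter)
```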

torchao/quantization/qat/embedding.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@ def to_embedding(self) -> torch.nn.Embedding:
             self.scale_grad_by_freq,
             self.sparse,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -116,6 +117,7 @@ def from_embedding(
             mod.sparse,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
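
The effect of the two added `dtype` arguments is that converting between a plain `torch.nn.Embedding` and its fake-quantized counterpart now preserves the weight dtype instead of falling back to float32, which matters for bf16 training. A sketch of the behavior, assuming the surrounding class is torchao's `FakeQuantizedEmbedding` with the `from_embedding`/`to_embedding` methods shown in the hunks above (the class name itself is not visible in this diff):

```python
import torch

from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding  # assumed class name

weight_config = FakeQuantizeConfig(torch.int8, group_size=32)
emb = torch.nn.Embedding(1000, 64, dtype=torch.bfloat16)

fq = FakeQuantizedEmbedding.from_embedding(emb, weight_config=weight_config)
assert fq.weight.dtype == torch.bfloat16                 # dtype now forwarded
assert fq.to_embedding().weight.dtype == torch.bfloat16  # and preserved on the way back
```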

torchao/quantization/qat/fake_quantizer.py

Lines changed: 41 additions & 7 deletions
@@ -31,6 +31,7 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Round,
 )
 
 
@@ -46,32 +47,43 @@ def __init__(self, config: FakeQuantizeConfig):
         self.scale: Optional[torch.Tensor] = None
         self.zero_point: Optional[torch.Tensor] = None
 
-        # TODO: support range learinng
-        if self.config.range_learning:
-            raise NotImplementedError("Range learning is not supported yet")
+        # For range learning only
+        # TODO: make this configurable?
+        self._scale_eps = 1e-9
+        self._initialized = False
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Apply fake quantization to the tensor based on the bit-width,
         granularity, symmetry, and other properties specified in the config.
         """
         if not self.enabled:
             return x
 
+        if (
+            self.config.range_learning
+            and not self._initialized
+            and (self.scale is None or self.zero_point is None)
+        ):
+            raise ValueError(
+                "Scales and zero points must be initialized for range learning. "
+                "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+                "before initializing the optimizer and beginning training."
+            )
+
         if isinstance(self.config.granularity, PerToken):
             return self._per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
             return self._per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor):
+    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
-
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
             self.scale, self.zero_point = choose_qparams_affine(
@@ -85,9 +97,10 @@ def _per_token_forward(self, x: torch.Tensor):
                 scale_dtype=self.config.scale_precision,
                 zero_point_dtype=self.config.zero_point_precision,
             )
+            self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor):
+    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
@@ -129,6 +142,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 eps=self.config.eps,
             )
             self.zero_point = self.zero_point.to(zero_point_precision)
+            self._maybe_update_qparams_for_range_learning()
 
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_channel_group(
@@ -147,6 +161,26 @@ def _should_compute_qparams(self) -> bool:
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
 
+    def _maybe_update_qparams_for_range_learning(self) -> None:
+        """
+        If range learning is enabled, turn scales and zero points into trainable parameters.
+        This function is idempotent and should only be called once.
+        """
+        if (
+            not self.config.range_learning
+            or isinstance(self.scale, torch.nn.Parameter)
+            or isinstance(self.zero_point, torch.nn.Parameter)
+        ):
+            return
+        scale, zero_point = self.scale, self.zero_point
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+        # Stabilize range learning
+        scale = torch.clamp(scale, min=self._scale_eps)
+        zero_point = _Round.apply(zero_point)
+        zero_point = torch.clamp(zero_point, qmin, qmax)
+        self.scale = torch.nn.Parameter(scale, requires_grad=True)
+        self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
+
     def __repr__(self) -> str:
         """
         Return a human readable representation of this `FakeQuantizer` with config details.
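
Two details of `_maybe_update_qparams_for_range_learning` are worth calling out: the scale is clamped to `_scale_eps` and the zero point is rounded and clamped into `[qmin, qmax]`, so the values being promoted to parameters start out in a valid quantization range. `_Round` itself comes from the QAT utils module and is not part of this diff; a straight-through-estimator round along the following lines is the usual way to keep rounding compatible with gradient flow (a sketch under that assumption, not the actual torchao implementation):

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Round to the nearest integer in the forward pass; pass gradients
    through unchanged in the backward pass (straight-through estimator)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output


# Usage mirrors the diff above: zero_point = RoundSTE.apply(zero_point)
zero_point = torch.tensor([0.4, 1.6], requires_grad=True)
rounded = RoundSTE.apply(zero_point)
rounded.sum().backward()
assert torch.equal(zero_point.grad, torch.ones_like(zero_point))
```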

torchao/quantization/qat/linear.py

Lines changed: 10 additions & 6 deletions
@@ -18,6 +18,7 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            group_size = weight_config.group_size
-            if group_size is not None and in_features % group_size != 0:
-                raise ValueError(
-                    "in_features (%s) %% group_size (%s) must be == 0"
-                    % (in_features, group_size)
-                )
+            if isinstance(weight_config.granularity, PerGroup):
+                group_size = weight_config.group_size
+                if group_size is not None and in_features % group_size != 0:
+                    raise ValueError(
+                        "in_features (%s) %% group_size (%s) must be == 0"
+                        % (in_features, group_size)
+                    )
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
             self.weight_fake_quantizer = None
@@ -108,6 +110,7 @@ def to_linear(self) -> torch.nn.Linear:
             self.out_features,
             self.bias is not None,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -131,6 +134,7 @@ def from_linear(
             activation_config=activation_config,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
        )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
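
The guard change means the `in_features % group_size` check now only runs when the weight config actually uses per-group granularity; per-channel configs (such as the range-learning config used in the new tests) have no group size, so they skip it entirely. A small illustration, with the layer shape chosen only for the example:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

# in_features=100 is not a multiple of the usual group sizes, but per-channel
# granularity has no group size, so preparing this layer for QAT succeeds
model = torch.nn.Sequential(torch.nn.Linear(100, 64))
per_channel = FakeQuantizeConfig(torch.int8, "per_channel")
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=per_channel))

# A per-group config, e.g. FakeQuantizeConfig(torch.int8, group_size=32), would
# still raise ValueError here because 100 % 32 != 0
```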
