Commit 0d9631b
Add Float8ActInt4WeightQATQuantizer (#2289)
**Summary:** This commit adds a QAT quantizer that performs float8 dynamic activation + int4 symmetric per-channel weight fake quantization. Note that there is no corresponding config for float8 QAT yet; this will be added in a future PR.

**Test Plan:**
python test/quantization/test_qat.py -k test_float8_rowwise_fake_quantize
python test/quantization/test_qat.py -k test_qat_fp8a4w_quantizer
1 parent d72a6d1 commit 0d9631b
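
For reference, a minimal usage sketch of the new quantizer (the toy model and training step are illustrative assumptions, not part of this commit; the convert path is not implemented yet):

import torch

from torchao.quantization.qat import Float8ActInt4WeightQATQuantizer

# Hypothetical toy model; any module tree containing nn.Linear works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Swap each nn.Linear for FakeQuantizedLinear: float8 rowwise fake
# quantization for activations, int4 grouped symmetric for weights.
quantizer = Float8ActInt4WeightQATQuantizer(group_size=64)
model = quantizer.prepare(model)

# Train as usual; the straight-through estimator passes gradients
# through the fake-quantize ops.
out = model(torch.randn(8, 64))
out.sum().backward()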

File tree

5 files changed: +228 additions, -8 deletions

test/quantization/test_qat.py

Lines changed: 62 additions & 1 deletion
@@ -17,6 +17,9 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

 from torchao import quantize_
+from torchao.float8.config import ScalingGranularity
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.quantization.granularity import (
     PerAxis,
     PerGroup,
@@ -40,15 +43,18 @@
 )
 from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
+    _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
+    Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATLinear,
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Float8RowwiseFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
@@ -68,6 +74,7 @@
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
+    compute_error,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -1474,7 +1481,6 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         numerics that match exactly over N trials.
         """
         from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.utils import compute_error

         num_trials = 1000
         group_size = 16
@@ -1688,6 +1694,61 @@ def test_qat_range_learning(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

+    def test_float8_rowwise_fake_quantize(self):
+        """
+        Test that `_Float8RowwiseFakeQuantize` is numerically close to `Float8Tensor`.
+        """
+        torch.manual_seed(self.SEED)
+        dtype = torch.float8_e4m3fn
+        x = torch.randn(32, 64)
+        axiswise_dim = 0
+        out = _Float8RowwiseFakeQuantize.apply(x, dtype, axiswise_dim)
+        out_expected = hp_tensor_to_float8_dynamic(
+            x,
+            dtype,
+            LinearMMConfig(),
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        ).to_original_precision()
+        torch.testing.assert_close(out, out_expected, atol=0, rtol=0)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is below 2.6"
+    )
+    def test_qat_fp8a4w_quantizer(self):
+        """
+        Test basic model training with `Float8ActInt4WeightQATQuantizer`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M()
+        qat_quantizer = Float8ActInt4WeightQATQuantizer()
+        qat_model = qat_quantizer.prepare(m)
+        for linear in [m.linear1, m.sub.linear, m.linear2]:
+            self.assertIsInstance(linear, FakeQuantizedLinear)
+            self.assertIsInstance(
+                linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
+            )
+            self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
+        prev_weight = copy.deepcopy(m.linear1.weight)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        optimizer.zero_grad()
+        target = torch.randn(1, 512).float()
+        example_inputs = m.example_inputs()
+        out = qat_model(*example_inputs)
+        loss = loss_fn(out, target)
+        loss.backward()
+        optimizer.step()
+        # Assert that weights have valid gradients and are being updated
+        new_weight = m.linear1.weight
+        self.assertIsNotNone(new_weight.grad)
+        self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+        self.assertFalse(torch.equal(new_weight, prev_weight))
+

 if __name__ == "__main__":
     unittest.main()

torchao/quantization/qat/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,13 +11,15 @@
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .linear import (
+    Float8ActInt4WeightQATQuantizer,
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
 )

 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
+    "Float8ActInt4WeightQATQuantizer",
     "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",

torchao/quantization/qat/fake_quantizer.py

Lines changed: 21 additions & 0 deletions
@@ -32,6 +32,7 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Float8RowwiseFakeQuantize,
 )


@@ -186,3 +187,23 @@ def __repr__(self) -> str:
         Return a human readable representation of this `FakeQuantizer` with config details.
         """
         return "FakeQuantizer(%s)" % self.config
+
+
+class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module):
+    """
+    Simple fake quantizer for float8 rowwise fake quantization, intended for activations only.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.enabled = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.enabled:
+            return _Float8RowwiseFakeQuantize.apply(
+                x,
+                torch.float8_e4m3fn,
+                -1,
+            )
+        else:
+            return x
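
As a standalone illustration (a sketch, not part of the diff): the module fake-quantizes over the last dimension when enabled and is an exact pass-through when disabled:

import torch

from torchao.quantization.qat.fake_quantizer import (
    _Float8RowwiseActivationFakeQuantizer,
)

fq = _Float8RowwiseActivationFakeQuantizer()
x = torch.randn(4, 16)

y = fq(x)  # float8 e4m3 rowwise fake quantization along dim -1
assert y.shape == x.shape and y.dtype == x.dtype

fq.enabled = False
assert torch.equal(fq(x), x)  # disabled: input returned unchanged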

torchao/quantization/qat/linear.py

Lines changed: 111 additions & 7 deletions
@@ -28,7 +28,10 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 from .api import FakeQuantizeConfig
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import (
+    FakeQuantizer,
+    _Float8RowwiseActivationFakeQuantizer,
+)
 from .utils import (
     _get_qmin_qmax,
 )
@@ -145,6 +148,11 @@ def from_linear(
         return new_linear


+# ===========================
+# | QAT quantizer interface |
+# ===========================
+
+
 class _LegacyQATQuantizer(TwoStepQuantizer):
     """
     Base class for sharing common methods across legacy QAT quantizers.
@@ -157,9 +165,30 @@ def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
         return None


-# =========================================================
-# | Linear int8 dynamic activations + int4 weight QAT |
-# =========================================================
+def enable_linear_fake_quant(
+    mod: torch.nn.Module,
+    enabled: bool = True,
+):
+    """
+    Helper function to enable fake quantization in `FakeQuantizedLinear`.
+    """
+    if isinstance(mod, FakeQuantizedLinear):
+        if mod.activation_fake_quantizer is not None:
+            mod.activation_fake_quantizer.enabled = enabled
+        if mod.weight_fake_quantizer is not None:
+            mod.weight_fake_quantizer.enabled = enabled
+
+
+def disable_linear_fake_quant(mod: torch.nn.Module):
+    """
+    Helper function to disable fake quantization in `FakeQuantizedLinear`.
+    """
+    enable_linear_fake_quant(mod, enabled=False)
+
+
+# ===========================================
+# | int8 dynamic activations + int4 weights |
+# ===========================================


 class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
@@ -307,6 +336,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)


+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -315,6 +345,7 @@ def enable_8da4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()


+# TODO: remove in favor of disable_linear_fake_quant
 def disable_8da4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
@@ -357,9 +388,9 @@ def _get_8da4w_weight_config(
     )


-# ===================================
-# |   Linear int4 weight-only QAT   |
-# ===================================
+# ====================
+# | int4 weight-only |
+# ====================


 class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
@@ -501,6 +532,7 @@ def disable_fake_quant(self):
         self.enable_fake_quant(False)


+# TODO: remove these in favor of enable_linear_fake_quant
 def enable_4w_fake_quant(mod: torch.nn.Module):
     """
     Enable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -509,6 +541,7 @@ def enable_4w_fake_quant(mod: torch.nn.Module):
         mod.enable_fake_quant()


+# TODO: remove these in favor of disable_linear_fake_quant
 def disable_4w_fake_quant(mod: torch.nn.Module):
     """
     Disable fake quantization for `Int4WeightOnlyQATLinear`.
@@ -533,3 +566,74 @@ def _get_4w_weight_config(
         zero_point_precision=qparams_precision,
         zero_point_domain=ZeroPointDomain.FLOAT,
     )
+
+
+# =============================================
+# | float8 rowwise activations + int4 weights |
+# =============================================
+
+
+class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
+    """
+    QAT quantizer for applying dynamic rowwise float8 activation + int4
+    per group/channel symmetric weight fake quantization to linear layers
+    in the model. Currently only supports rowwise granularity for float8
+    activations.
+
+    Args:
+        group_size (Optional[int]): the number of elements in each quantized
+            group for weights, defaults to 64. Use None for per-channel.
+        scale_precision: precision of weight scales, defaults to torch.bfloat16.
+    """
+
+    def __init__(
+        self,
+        group_size: Optional[int] = 64,
+        scale_precision: torch.dtype = torch.bfloat16,
+    ):
+        if group_size is not None:
+            weight_granularity = "per_group"
+        else:
+            weight_granularity = "per_channel"
+        self._weight_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            granularity=weight_granularity,
+            group_size=group_size,
+            is_symmetric=True,
+            is_dynamic=True,
+            scale_precision=scale_precision,
+        )
+
+    def prepare(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        """
+        Swap all `nn.Linear` modules for `FakeQuantizedLinear`, using a float8
+        fake quantizer for activations and an int4 fake quantizer for weights.
+        """
+        for name, child in model.named_children():
+            if isinstance(child, torch.nn.Linear):
+                # TODO: add a config for float8?
+                new_linear = FakeQuantizedLinear.from_linear(
+                    child,
+                    weight_config=self._weight_config,
+                )
+                new_linear.activation_fake_quantizer = (
+                    _Float8RowwiseActivationFakeQuantizer()
+                )
+                setattr(model, name, new_linear)
+            else:
+                self.prepare(child)
+        return model
+
+    # TODO: add convert path
+    def convert(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        raise NotImplementedError
+
+    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet")
+
+    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
+        return self._weight_config
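
The new helpers are designed to be broadcast over a prepared model with nn.Module.apply, for example (a sketch; the toy model is an assumption):

import torch

from torchao.quantization.qat import Float8ActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import (
    disable_linear_fake_quant,
    enable_linear_fake_quant,
)

qat_model = Float8ActInt4WeightQATQuantizer().prepare(
    torch.nn.Sequential(torch.nn.Linear(16, 16))
)

# Toggle fake quantization on every FakeQuantizedLinear in the module tree:
qat_model.apply(disable_linear_fake_quant)  # e.g. for a float baseline eval
qat_model.apply(enable_linear_fake_quant)   # restore QAT numerics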

torchao/quantization/qat/utils.py

Lines changed: 32 additions & 0 deletions
@@ -16,6 +16,38 @@
 )


+class _Float8RowwiseFakeQuantize(torch.autograd.Function):
+    """
+    Implementation of float8 rowwise fake quantize with backward STE.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        x: torch.Tensor,
+        float8_dtype: torch.dtype,
+        axiswise_dim: int,
+    ):
+        # compute rowwise scale based on `torchao.float8.float8_utils.tensor_to_scale`
+        eps = 1e-12
+        amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
+        amax = amax.to(torch.float64)
+        scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=eps)
+        scale = scale.to(torch.float32)
+
+        # fake quantize
+        max_value = torch.finfo(float8_dtype).max
+        x_fq = x.to(torch.float32) * scale
+        x_fq = x_fq.clamp(min=-max_value, max=max_value)
+        x_fq = x_fq.to(float8_dtype).to(x.dtype)
+        x_fq = x_fq / scale
+        return x_fq.to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, gy):
+        return gy, None, None
+
+
 # TODO: delete?
 class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
     """
