
Commit c45d975

Add support for quantize_() with Float8Linear module (#1344)
1 parent ed76e9c commit c45d975

2 files changed: +24 −2 lines changed

test/float8/test_base.py

Lines changed: 17 additions & 2 deletions
@@ -14,7 +14,7 @@
 import torch
 import torch.nn as nn
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_89
 
 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -531,6 +531,21 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)
 
+    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+
 
 class TestScaledMM:
     @unittest.skipIf(
@@ -576,7 +591,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
     @unittest.skipIf(not is_cuda_8_9, "CUDA not available")

torchao/quantization/quant_api.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,7 @@
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
 )
+from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
@@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    if isinstance(model, Float8Linear):
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
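
Taken together, the two changes let quantize_() be applied to a model whose layers were already swapped to Float8Linear for float8 training: before running the filter, the helper rebuilds each Float8Linear as a plain nn.Linear on the meta device (so no new storage is allocated) and hands it the existing weight and bias. A minimal end-to-end sketch of the workflow the new test exercises, assuming a CUDA device with compute capability 8.9 and that convert_to_float8_training is exported from torchao.float8 as in the library's float8 examples:

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training  # import path assumed from torchao's float8 examples
from torchao.float8.float8_linear import Float8Linear
from torchao.quantization.quant_api import float8_weight_only, quantize_

# Convert the nn.Linear layers to Float8Linear, as done for float8 training.
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear)

# With this commit, quantize_() unwraps each Float8Linear back into a plain
# nn.Linear (reusing its weight and bias) before applying float8 weight-only
# quantization, so the call below no longer trips over the training wrapper.
quantize_(m, float8_weight_only())
assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn

# Inference still works on the quantized module.
with torch.no_grad():
    m(torch.randn(32, 32, device="cuda"))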
