Skip to content

Commit 76e2ef5

Browse files
authored
Support model.to int8 weight only quantized model (#122)
Summary: registering fields as buffers so they get picked up in `model.to`. Test Plan: python test/quantization/test_quant_api.py -k test_int8_wo_quant_save_load. Reviewers: Subscribers: Tasks: Tags:
1 parent eba4c36 commit 76e2ef5

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

test/quantization/test_quant_api.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# This test takes a long time to run
99
import unittest
1010
import torch
11+
import os
1112
from torch._export import capture_pre_autograd_graph
1213
from torch.ao.quantization.quantize_pt2e import (
1314
prepare_pt2e,
@@ -18,9 +19,10 @@
1819
get_symmetric_quantization_config,
1920
)
2021

21-
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
22-
from torchao.quantization.quant_api import apply_dynamic_quant
2322
from torchao.quantization.quant_api import (
23+
_replace_with_custom_fn_if_matches_filter,
24+
apply_dynamic_quant,
25+
apply_weight_only_int8_quant,
2426
Quantizer,
2527
TwoStepQuantizer,
2628
)
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
137139
compiled = m(*example_inputs)
138140
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
139141

142+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
143+
def test_int8_wo_quant_save_load(self):
144+
m = M().eval().cpu()
145+
apply_weight_only_int8_quant(m)
146+
example_inputs = m.example_inputs()
147+
ref = m(*example_inputs)
148+
_TMP_FN = "_test.pt"
149+
torch.save(m.state_dict(), _TMP_FN)
150+
151+
state_dict = torch.load(_TMP_FN)
152+
os.remove(_TMP_FN)
153+
m2 = M().eval()
154+
apply_weight_only_int8_quant(m2)
155+
m2.load_state_dict(state_dict)
156+
m2 = m2.to(device="cuda")
157+
example_inputs = map(lambda x: x.cuda(), example_inputs)
158+
res = m2(*example_inputs)
159+
160+
torch.testing.assert_close(ref, res.cpu())
161+
140162
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
141163
def test_8da4w_quantizer(self):
142164
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

torchao/quantization/weight_only.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def __init__(self, *args, **kwargs):
2222
scales = kwargs.pop("scales")
2323
super().__init__(*args, **kwargs)
2424

25-
self.w_int8 = w_int8
26-
27-
self.scales = scales
25+
self.register_buffer("w_int8", w_int8)
26+
self.register_buffer("scales", scales)
2827

2928
def forward(self, x, *args, **kwargs):
3029
"""

0 commit comments

Comments (0)