|
8 | 8 | # This test takes a long time to run
|
9 | 9 | import unittest
|
10 | 10 | import torch
|
| 11 | +import os |
11 | 12 | from torch._export import capture_pre_autograd_graph
|
12 | 13 | from torch.ao.quantization.quantize_pt2e import (
|
13 | 14 | prepare_pt2e,
|
|
18 | 19 | get_symmetric_quantization_config,
|
19 | 20 | )
|
20 | 21 |
|
21 |
| -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter |
22 |
| -from torchao.quantization.quant_api import apply_dynamic_quant |
23 | 22 | from torchao.quantization.quant_api import (
|
| 23 | + _replace_with_custom_fn_if_matches_filter, |
| 24 | + apply_dynamic_quant, |
| 25 | + apply_weight_only_int8_quant, |
24 | 26 | Quantizer,
|
25 | 27 | TwoStepQuantizer,
|
26 | 28 | )
|
@@ -137,6 +139,26 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
|
137 | 139 | compiled = m(*example_inputs)
|
138 | 140 | torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
|
139 | 141 |
|
| 142 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 143 | + def test_int8_wo_quant_save_load(self): |
| 144 | + m = M().eval().cpu() |
| 145 | + apply_weight_only_int8_quant(m) |
| 146 | + example_inputs = m.example_inputs() |
| 147 | + ref = m(*example_inputs) |
| 148 | + _TMP_FN = "_test.pt" |
| 149 | + torch.save(m.state_dict(), _TMP_FN) |
| 150 | + |
| 151 | + state_dict = torch.load(_TMP_FN) |
| 152 | + os.remove(_TMP_FN) |
| 153 | + m2 = M().eval() |
| 154 | + apply_weight_only_int8_quant(m2) |
| 155 | + m2.load_state_dict(state_dict) |
| 156 | + m2 = m2.to(device="cuda") |
| 157 | + example_inputs = map(lambda x: x.cuda(), example_inputs) |
| 158 | + res = m2(*example_inputs) |
| 159 | + |
| 160 | + torch.testing.assert_close(ref, res.cpu()) |
| 161 | + |
140 | 162 | @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
|
141 | 163 | def test_8da4w_quantizer(self):
|
142 | 164 | from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
|
|
0 commit comments