@@ -12,6 +12,7 @@
 import torch.nn as nn
 from torch._inductor.utils import run_and_get_code
 from torch._dynamo import config
+import torchao
 from torch.ao.quantization import MinMaxObserver, QConfigMapping

 from torchao.quantization.dynamic_quant import (
@@ -54,6 +55,12 @@
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
+)
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 from parameterized import parameterized
@@ -896,7 +903,43 @@ def _test_lin_weight_subclass_impl(
         )

-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass(self, device, dtype):
+    def test_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                Int8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
+            )
+
+    def test_aq_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_2_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_3_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
+            )
+
-        self._test_lin_weight_subclass_impl(
-            Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
-        )
+    def test_int4_weight_only_quant_subclass(self):
+        for test_dtype in [torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_dtype
+            )
@@ -1293,6 +1336,51 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)

+class TestAutoQuant(unittest.TestCase):
+    def test_autoquant_one_input(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m, k, n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (2**15, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            out = model(example_input)
+            torchao.autoquant(model, example_input)
+            out2 = model(example_input)
+            sqnr = SQNR(out, out2)
+            self.assertTrue(sqnr >= 30)
+
+    def test_autoquant_multi_input(self):
+        m1, m2, k, n = 1, 8, 1024, 1024
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).cuda().to(torch.bfloat16)
+        example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16)
+        example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16)
+        torchao.change_linears_to_autoquantizable(model)
+        out = model(example_input)
+        model(example_input2)
+        torchao.change_autoquantizable_to_quantized(model)
+        out2 = model(example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)

 if __name__ == "__main__":
     unittest.main()
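For reference, the user-facing flow these tests exercise looks roughly like the sketch below. This is a minimal illustration distilled from the two new tests, assuming the `torchao.autoquant`, `change_linears_to_autoquantizable`, and `change_autoquantizable_to_quantized` entry points added in this diff; the toy model and shapes are illustrative only.

```python
import torch
import torchao

def make_model():
    # Toy model; autoquant only rewrites the nn.Linear layers.
    return torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
    ).to("cuda").to(torch.bfloat16)

# One-shot API (test_autoquant_one_input): benchmark the candidate
# quantized kernels against the example input's shape and keep the
# fastest choice per layer.
model = make_model()
example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
torchao.autoquant(model, example_input)
out = model(example_input)

# Multi-shape API (test_autoquant_multi_input): swap Linears for
# autoquantizable wrappers, run every input shape of interest so each
# gets recorded, then finalize the quantization choice.
model = make_model()
torchao.change_linears_to_autoquantizable(model)
model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))
model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
torchao.change_autoquantizable_to_quantized(model)
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```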