|
12 | 12 | import torch.nn as nn
|
13 | 13 | from torch._inductor.utils import run_and_get_code
|
14 | 14 | from torch._dynamo import config
|
15 |
| -import torchao |
16 | 15 | from torch.ao.quantization import MinMaxObserver, QConfigMapping
|
17 | 16 |
|
18 | 17 | from torchao.quantization.dynamic_quant import (
|
|
55 | 54 | _fqn_to_op_to_shape_to_count,
|
56 | 55 | LoggingTensorMode,
|
57 | 56 | )
|
58 |
| -from torchao.quantization.autoquant import ( |
59 |
| - AQInt8DynamicallyQuantizedLinearWeight, |
60 |
| - AQWeightOnlyQuantizedLinearWeight, |
61 |
| - AQWeightOnlyQuantizedLinearWeight2, |
62 |
| - AQWeightOnlyQuantizedLinearWeight3 |
63 |
| - |
64 |
| -) |
65 | 57 | from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
|
66 | 58 | import os
|
67 | 59 |
|
@@ -888,36 +880,6 @@ def test_int8_weight_only_quant_subclass(self):
|
888 | 880 | Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
|
889 | 881 | )
|
890 | 882 |
|
891 |
| - def test_aq_int8_dynamic_quant_subclass(self): |
892 |
| - for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
893 |
| - self._test_lin_weight_subclass_impl( |
894 |
| - AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype |
895 |
| - ) |
896 |
| - |
897 |
| - def test_aq_int8_weight_only_quant_subclass(self): |
898 |
| - for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
899 |
| - self._test_lin_weight_subclass_impl( |
900 |
| - AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype |
901 |
| - ) |
902 |
| - |
903 |
| - def test_aq_int8_weight_only_quant_subclass(self): |
904 |
| - for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
905 |
| - self._test_lin_weight_subclass_impl( |
906 |
| - AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype |
907 |
| - ) |
908 |
| - |
909 |
| - def test_aq_int8_weight_only_quant_2_subclass(self): |
910 |
| - for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
911 |
| - self._test_lin_weight_subclass_impl( |
912 |
| - AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype |
913 |
| - ) |
914 |
| - |
915 |
| - def test_aq_int8_weight_only_quant_3_subclass(self): |
916 |
| - for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: |
917 |
| - self._test_lin_weight_subclass_impl( |
918 |
| - AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype |
919 |
| - ) |
920 |
| - |
921 | 883 | def test_int4_weight_only_quant_subclass(self):
|
922 | 884 | self._test_lin_weight_subclass_impl(
|
923 | 885 | Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
|
@@ -1233,51 +1195,6 @@ def test_on_dummy_distilbert(self):
|
1233 | 1195 | print("sqnr_pt_quant", sqnr_pt_quant)
|
1234 | 1196 | self.assertTrue(sqnr_sq >= 8.0)
|
1235 | 1197 |
|
1236 |
| -class TestAutoQuant(unittest.TestCase): |
1237 |
| - def test_autoquant_one_input(self): |
1238 |
| - torch._inductor.config.epilogue_fusion = False |
1239 |
| - torch._inductor.config.use_mixed_mm = True |
1240 |
| - torch._inductor.config.force_fuse_int_mm_with_mul = True |
1241 |
| - torch._dynamo.config.automatic_dynamic_shapes = False |
1242 |
| - |
1243 |
| - for m,k,n in [ |
1244 |
| - (1, 1024, 1024), |
1245 |
| - (64, 1024, 1024), |
1246 |
| - (2**15, 1024, 1024), |
1247 |
| - (1, 1024, 4096), |
1248 |
| - (64, 1024, 4096), |
1249 |
| - (1, 4096, 1024), |
1250 |
| - (64, 4096, 1024), |
1251 |
| - (4096, 4096, 1024), |
1252 |
| - ]: |
1253 |
| - example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) |
1254 |
| - model = torch.nn.Sequential( |
1255 |
| - torch.nn.ReLU(), |
1256 |
| - torch.nn.Linear(k,n), |
1257 |
| - torch.nn.ReLU(), |
1258 |
| - ).to("cuda").to(torch.bfloat16) |
1259 |
| - out = model(example_input) |
1260 |
| - torchao.autoquant(model, example_input) |
1261 |
| - out2 = model(example_input) |
1262 |
| - sqnr = SQNR(out, out2) |
1263 |
| - self.assertTrue(sqnr >= 30) |
1264 |
| - |
1265 |
| - def test_autoquant_multi_input(self): |
1266 |
| - m1, m2, k, n = 1, 8, 1024, 1024 |
1267 |
| - model = torch.nn.Sequential( |
1268 |
| - torch.nn.ReLU(), |
1269 |
| - torch.nn.Linear(k,n), |
1270 |
| - torch.nn.ReLU(), |
1271 |
| - ).cuda().to(torch.bfloat16) |
1272 |
| - example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16) |
1273 |
| - example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16) |
1274 |
| - torchao.change_linears_to_autoquantizable(model) |
1275 |
| - out=model(example_input) |
1276 |
| - model(example_input2) |
1277 |
| - torchao.change_autoquantizable_to_quantized(model) |
1278 |
| - out2 = model(example_input) |
1279 |
| - sqnr = SQNR(out, out2) |
1280 |
| - self.assertTrue(sqnr >= 30) |
1281 | 1198 |
|
1282 | 1199 | if __name__ == "__main__":
|
1283 | 1200 | unittest.main()
|
0 commit comments