Commit 71300c3

Autoquant
Summary: Adds autoquantization functionality. Using the autoquant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"; also tested on SAM and SDXL (pytorch-labs/segment-anything-fast#114, HDCharles/sdxl-fast@8d9942a)

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: fddbaf2
Pull Request resolved: #38
1 parent 969038f commit 71300c3
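
Conceptually, the per-layer selection described in the summary comes down to timing a few candidate weight representations against a layer's observed activation shape and keeping the fastest. Below is a minimal sketch of that idea only; it is not the actual torchao implementation, and the helper names, the (name, from_float) candidate format, and the use of torch.utils.benchmark are illustrative assumptions.

```
import torch
from torch.utils.benchmark import Timer

def _time_linear(act, weight):
    # Median runtime of a linear op with the given (possibly tensor-subclass) weight.
    return Timer(
        stmt="torch.nn.functional.linear(act, weight)",
        globals={"torch": torch, "act": act, "weight": weight},
    ).blocked_autorange().median

def pick_fastest_for_layer(weight, example_act, candidates):
    # candidates: list of (name, from_float) pairs, e.g. built from the AQ* subclasses
    # added in this commit; from_float wraps a floating point weight in that subclass.
    best_name, best_time = "no quantization", _time_linear(example_act, weight)
    for name, from_float in candidates:
        q_weight = from_float(weight)
        t = _time_linear(example_act, q_weight)
        if t < best_time:
            best_name, best_time = name, t
    return best_name
```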

8 files changed: +514 -13 lines changed

README.md

Lines changed: 31 additions & 10 deletions
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st
 
 This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
 
-### A8W8 Dynamic Quantization
+### Autoquantization
 
-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps in that class. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight-only quantization for each layer.
 
 ```
 import torch
-from torchao.quantization import quant_api
+import torchao
+
+# inductor settings which improve torch.compile runtime for quantized modules
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
 
 # some user model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, (input))
 
 # compile the model to improve performance
 model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
+
+### A8W8 Dynamic Quantization
+
+The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
+
+Example
+
+```
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
+torchao.change_linear_weights_to_int8_dqtensors(model)
+
+# compile the model to improve performance
+...
+```
+
 This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
 
 
@@ -81,7 +102,7 @@ Example
 ...
 
 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+torchao.change_linear_weights_to_int8_woqtensors(model)
 
 # compile the model to improve performance
 ...
@@ -102,7 +123,7 @@ Example
 ...
 
 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+torchao.change_linear_weights_to_int4_woqtensors(model)
 
 # compile the model to improve performance
 ...
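
The README's note about force_fuse_int_mm_with_mul refers to the rescaling step of dynamic quantization: the int8 x int8 matmul accumulates into int32 and must then be multiplied by the activation and weight scales. Below is a rough reference sketch of that arithmetic, assuming a symmetric per-tensor activation scale and per-channel weight scales; torchao's actual kernels and scale layout may differ.

```
import torch

def dynamic_int8_linear_reference(x, w_int8, w_scales):
    # x: (batch, in) float activations, w_int8: (out, in) int8 weights,
    # w_scales: (out,) per-channel weight scales.
    # Dynamically quantize the activation with a symmetric per-tensor scale.
    x_scale = x.abs().amax() / 127.0
    x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    # int8 x int8 matmul accumulated in int32. Real kernels use an int8 matmul kernel
    # (e.g. torch._int_mm); force_fuse_int_mm_with_mul lets inductor fuse that matmul
    # with the rescaling mul below so the int32 intermediate is never materialized.
    acc_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    # Rescale back to floating point.
    return acc_int32.to(x.dtype) * x_scale * w_scales
```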

__init__.py

Whitespace-only changes.

test/test.py

Lines changed: 66 additions & 0 deletions
@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -54,6 +55,13 @@
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3
+
+)
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
 
@@ -880,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self):
             Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
         )
 
+    def test_aq_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_2_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_3_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
+            )
+
     def test_int4_weight_only_quant_subclass(self):
         self._test_lin_weight_subclass_impl(
             Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1195,6 +1233,34 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_autoquant(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m,k,n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (2**15, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k,n),
+                torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            out = model(example_input)
+            do_autoquant(model, example_input)
+            out2 = model(example_input)
+            sqnr = SQNR(out, out2)
+            self.assertTrue(sqnr >= 30)
 
 if __name__ == "__main__":
     unittest.main()
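
For reference, the SQNR check at the end of test_autoquant compares the model's output before and after autoquantization. Below is a minimal sketch of the standard signal-to-quantization-noise-ratio definition; the repo ships its own SQNR helper, so this is illustrative rather than its exact implementation.

```
import torch

def sqnr(reference, quantized):
    # Signal-to-quantization-noise ratio in dB; higher means the quantized output
    # tracks the reference more closely (the test above requires >= 30).
    noise = reference - quantized
    return 20 * torch.log10(reference.norm() / noise.norm())
```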

torchao/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
+    swap_conv2d_1x1_to_linear,
+    autoquant,
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
+)
+
+__all__ = [
+    "apply_weight_only_int8_quant",
+    "apply_dynamic_quant",
+    "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_woqtensors",
+    "change_linear_weights_to_int4_woqtensors",
+    "swap_conv2d_1x1_to_linear",
+    "safe_int_mm",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
+]
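
Judging from the exported names and the autoquant description in the README diff, change_linears_to_autoquantizable and change_autoquantizable_to_quantized appear to expose autoquant as a two-stage flow: first record activation shapes, then benchmark and swap. The sketch below is hypothetical usage, assuming both functions take the model as their only required argument.

```
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

# Stage 1 (assumed): swap linear weights for autoquantizable placeholders that record
# the activation shapes they see.
torchao.change_linears_to_autoquantizable(model)
model(example_input)  # run representative inputs so shapes are observed

# Stage 2 (assumed): benchmark the recorded shapes per layer and swap each placeholder
# for the fastest quantized (or unquantized) option.
torchao.change_autoquantizable_to_quantized(model)
```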

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",
