
Commit e317d0f

Adding int4 quantized tensor subclass
Summary:
Adding int4 quantized tensor subclass support, also refactoring the tensor subclass code to be easier to use with multiple subclasses. This subclass uses the tinygemm int4 mixed dtype gemm that was added to pytorch as _weight_int4pack_mm and _convert_weight_to_int4pack. Also added support for .to for tensor subclasses to get the saving/loading of meta tensors working for int4.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e1fdcb9
Pull Request resolved: #15
1 parent afe65e3 commit e317d0f

7 files changed: +720 −186 lines
README.md

Lines changed: 45 additions & 19 deletions
@@ -32,63 +32,87 @@ Relevant APIs can be found in torchao.quantization.quant_api
 Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
 This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
 
+The following apis use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass' _quantized_op method.
+
+This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
+
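To make the dispatch override concrete, here is a minimal sketch of the pattern described above: a wrapper tensor subclass that stores quantized data plus a scale and reroutes `torch.nn.functional.linear` to its own `_quantized_op`. The class, helper names, and simple int8 weight-only scheme are illustrative assumptions, not the torchao subclasses, and the sketch assumes a recent PyTorch with `_make_wrapper_subclass` and `DisableTorchFunctionSubclass`.

```
import torch
import torch.nn.functional as F

class Int8WeightSketch(torch.Tensor):
    """Hypothetical minimal q-tensor subclass: stores int8 data plus a scale and
    reroutes F.linear to its own _quantized_op (weight-only dequant + matmul)."""

    @staticmethod
    def __new__(cls, int_data, scale):
        # wrapper tensor that advertises the original (dequantized) shape/dtype
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data  # int8 quantized weight values
        self.scale = scale        # per-tensor scale

    @classmethod
    def from_float(cls, w):
        scale = w.abs().max() / 127
        int_data = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale)

    @classmethod
    def _quantized_op(cls, act, w, bias):
        # stand-in kernel: dequantize the weight, then run the ordinary matmul
        return F.linear(act, w.int_data.to(act.dtype) * w.scale, bias)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            act, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            return cls._quantized_op(act, w, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

# replace a Linear module's weight with the subclass; the module graph is untouched
lin = torch.nn.Linear(32, 64)
qweight = Int8WeightSketch.from_float(lin.weight.detach())
del lin.weight          # drop the original Parameter
lin.weight = qweight    # plain attribute lookup still feeds F.linear in forward()
out = lin(torch.randn(8, 32))
```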
 ### A8W8 Dynamic Quantization
 
-Similar to the weight only api above, the `apply_dynamic_quant` function swaps all
-linear modules to dynamically quantized quantized linear modules.
+The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
 
 Example
 
 ```
+import torch
+from torchao.quantization import quant_api
 
 # some user model and example input
-...
+model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
+input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
 # convert linear modules to quantized linear modules
-quant_api.apply_dynamic_quant(model)
+quant_api.change_linear_weights_to_int8_dqtensors(model)
 
 # compile the model to improve performance
-...
+torch.compile(model, mode='max-autotune')
+model(input)
 ```
 
 This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
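For reference, the flag is an attribute on the inductor config, so (assuming a recent PyTorch where the option exists, and reusing `model` and `input` from the example above) it can be enabled before compiling:

```
import torch
import torch._inductor.config

# fuse the int8*int8 -> int32 matmul with the subsequent mul before compiling
torch._inductor.config.force_fuse_int_mm_with_mul = True

model = torch.compile(model, mode='max-autotune')
model(input)
```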
 
+
 ### A16W8 WeightOnly Quantization
 
-The `apply_weight_only_int8_quant` function swaps all
-linear modules to weight-only quantized linear modules.
+The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a weight only quantized linear matmul
 
 Example
 
 ```
-import torch
-from torchao.quantization import quant_api
-
 # some user model and example input
-model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
-input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
+...
 
 # convert linear modules to quantized linear modules
-quant_api.apply_weight_only_int8_quant(model)
+quant_api.change_linear_weights_to_int8_woqtensors(model)
 
 # compile the model to improve performance
-torch.compile(model, mode='max-autotune')
-model(input)
+...
 ```
 
 This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
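Since the example above is abbreviated, here is a sketch of what the full flow might look like, assembled from the model/input setup shown in the dynamic quantization example plus the flag mentioned above; it assumes a CUDA device and a PyTorch build where `use_mixed_mm` is available:

```
import torch
import torch._inductor.config
from torchao.quantization import quant_api

# fuse the weight dequantization into the matmul (see the note above)
torch._inductor.config.use_mixed_mm = True

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# convert linear weights to Int8WeightOnlyQuantizedLinearWeight subclasses
quant_api.change_linear_weights_to_int8_woqtensors(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```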
 
+
+### A16W4 WeightOnly Quantization
+
+The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a weight only quantized linear matmul
+
+Example
+
+```
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
+quant_api.change_linear_weights_to_int4_woqtensors(model)
+
+# compile the model to improve performance
+...
+```
+
+The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
+
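The int4 example above is likewise abbreviated; the sketch below follows the same pattern and adds a rough, illustrative check of the quantization error noted in the added text. The SQNR computation is not a torchao API, and the layer sizes are an assumption chosen to be friendly to the int4 tinygemm kernel's alignment requirements:

```
import copy
import torch
from torchao.quantization import quant_api

# some user model and example input (sizes assumed to satisfy tinygemm alignment)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
input = torch.randn(32, 1024, dtype=torch.bfloat16, device='cuda')
reference = copy.deepcopy(model)  # keep an unquantized copy for comparison

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int4_woqtensors(model)

# rough signal-to-quantization-noise ratio (dB); lower means larger int4 error
with torch.no_grad():
    y_ref = reference(input).float()
    y_int4 = model(input).float()
    sqnr = 10 * torch.log10(y_ref.pow(2).mean() / (y_ref - y_int4).pow(2).mean())
    print(f"int4 SQNR: {sqnr.item():.1f} dB")
```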
 ## Other APIs
 
-### A8W8 Dynamic Quantization by subclasses
+### Module Swap APIs
 
-You can use [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) to do dynamic quantization with the `change_linear_weights_to_dqtensors` function using the exact same formula as above. This avoids modifying the graph and can be more composable with
-other techniques.
+The `apply_dynamic_quant` and `apply_weight_only_int8_quant` apis can be used in the same formula as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
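For completeness, a short sketch of the module swap path using the same setup as the examples above; only the two function names come from this README, everything else mirrors the earlier snippets:

```
import torch
from torchao.quantization import quant_api

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# module swap version of A8W8 dynamic quantization
# (use quant_api.apply_weight_only_int8_quant(model) for the A16W8 variant)
quant_api.apply_dynamic_quant(model)

# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```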
 
 ### A8W8 Dynamic Quantization with Smoothquant
 
 We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
-Due to requiring calibration, the API is slightly more complicated
+Due to requiring calibration, the API is slightly more complicated and currently only exists with a module swap api.
 
 Example
 
@@ -116,6 +140,8 @@ torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
+like the other dynamic quantization apis, the torch._inductor.config.force_fuse_int_mm_with_mul option may significantly improve performance if enabled.
+
 ## License
 
 `torchao` is released under the [BSD 3](https://github.com/pytorch-labs/ao/blob/main/LICENSE) license.
