Summary: Add int4 quantized tensor subclass support and refactor the tensor
subclass code to be easier to use with multiple subclasses. The new
subclass uses the tinygemm int4 mixed-dtype gemm that was added to
PyTorch as _weight_int4pack_mm and _convert_weight_to_int4pack. Also
added support for .to on tensor subclasses so that save/loading of
meta tensors works for int4.
Test Plan: python test/test.py
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: e1fdcb9
Pull Request resolved: #15
README.md (45 additions, 19 deletions)
Relevant APIs can be found in torchao.quantization.quant_api
Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
The following APIs use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass' `_quantized_op` method.
This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
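As a rough illustration of the mechanism, the sketch below shows a weight subclass that stores int8 data and reroutes `F.linear` to its own `_quantized_op`. This is a minimal sketch, not the torchao implementation; the class name, the naive per-row int8 scheme, and the dequantize-then-matmul fallback are all illustrative.

```
import torch
import torch.nn.functional as F

class Int8WeightSketch(torch.Tensor):
    """Minimal sketch of a quantized weight subclass (illustrative, not torchao's)."""

    @staticmethod
    def __new__(cls, float_weight):
        # wrapper subclass: the outer tensor only carries shape/dtype/device metadata
        return torch.Tensor._make_wrapper_subclass(
            cls, float_weight.shape, dtype=float_weight.dtype, device=float_weight.device
        )

    def __init__(self, float_weight):
        # naive symmetric per-row int8 quantization of the original weight
        scale = float_weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
        self.int_data = torch.round(float_weight / scale).to(torch.int8)
        self.scale = scale

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # intercept linear and route it to the subclass' quantized op
        if func is F.linear and isinstance(args[1], cls):
            return cls._quantized_op(*args)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @staticmethod
    def _quantized_op(act, weight, bias=None):
        # stand-in for a fused low-precision kernel: dequantize, then run the normal matmul
        return F.linear(act, weight.int_data.to(act.dtype) * weight.scale, bias)

w = Int8WeightSketch(torch.randn(64, 32))
y = F.linear(torch.randn(8, 32), w)  # dispatched to Int8WeightSketch._quantized_op
```

The torchao subclasses follow this same dispatch pattern but back `_quantized_op` with quantized matmul kernels (e.g. the tinygemm int4 mixed-dtype gemm) instead of dequantizing back to floating point.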
### A8W8 Dynamic Quantization
The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
Example
```
import torch
from torchao.quantization import quant_api
# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")  # illustrative example input

# convert linear weights to the Int8DynamicallyQuantizedLinearWeight subclass
quant_api.change_linear_weights_to_int8_dqtensors(model)
```

This technique works best when the `torch._inductor.config.force_fuse_int_mm_with_mul` option is enabled. This allows fusion of the `int8*int8 -> int32` matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
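For example, continuing the snippet above (a sketch: the flag only needs to be set before the model is compiled, and the compile mode shown is illustrative):

```
import torch._inductor.config

# fuse the int8 matmul with the subsequent mul during compilation
torch._inductor.config.force_fuse_int_mm_with_mul = True

model = torch.compile(model, mode="max-autotune")
model(input)
```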
### A16W8 WeightOnly Quantization
The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.
Example
```
import torch
from torchao.quantization import quant_api

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# convert linear weights to the Int8WeightOnlyQuantizedLinearWeight subclass
quant_api.change_linear_weights_to_int8_woqtensors(model)
```

This technique works best when the `torch._inductor.config.use_mixed_mm` option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
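The flag can be set the same way as in the dynamic quantization example above (again a sketch, assuming compilation follows):

```
import torch._inductor.config

torch._inductor.config.use_mixed_mm = True  # fuse weight dequantization into the matmul
model = torch.compile(model, mode="max-autotune")
```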
### A16W4 WeightOnly Quantization
The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.
Example
```
# some user model and example input
...

# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int4_woqtensors(model)
```

The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
## Other APIs
### Module Swap APIs
The `apply_dynamic_quant` and `apply_weight_only_int8_quant` APIs can be used in the same manner as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
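For example (a sketch: both functions modify the model in place, as described above, though their exact signatures may take additional optional arguments):

```
import torch
from torchao.quantization import quant_api

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# module swap equivalents of the tensor subclass APIs above
quant_api.apply_dynamic_quant(model)             # A8W8 dynamic quantization
# or, alternatively:
# quant_api.apply_weight_only_int8_quant(model)  # A16W8 weight-only quantization
```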
### A8W8 Dynamic Quantization with Smoothquant
We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
Due to requiring calibration, the API is slightly more complicated and currently only exists as a module swap API.
Like the other dynamic quantization APIs, the `torch._inductor.config.force_fuse_int_mm_with_mul` option may significantly improve performance if enabled.
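The calibration flow looks roughly like the sketch below. The helper names `swap_linear_with_smooth_fq_linear` and `smooth_fq_linear_to_inference` and their import path are assumptions based on torchao's smoothquant module and may differ; `calibration_batches` stands in for your own representative data.

```
import torch
from torchao.quantization.smoothquant import (
    swap_linear_with_smooth_fq_linear,
    smooth_fq_linear_to_inference,
)

# some user model and representative calibration data (illustrative)
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
calibration_batches = [torch.randn(32, 32, dtype=torch.bfloat16, device="cuda") for _ in range(8)]

# swap in smoothquant linear modules in calibration mode
swap_linear_with_smooth_fq_linear(model)

# run calibration data through the model to collect activation statistics
for batch in calibration_batches:
    model(batch)

# freeze the collected scales and switch to the quantized inference path
smooth_fq_linear_to_inference(model)
```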
## License
`torchao` is released under the [BSD 3](https://github.com/pytorch-labs/ao/blob/main/LICENSE) license.