Summary: Adding autoquantization functionality. Using the do_quant api
we can test kernel speeds and pick the best quantization type (or no
quantization) for each layer.
Test Plan: python test/test.py -k "autoquant"
also tested on SAM and SDXL:
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a
Reviewers:
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]
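
To make the Summary concrete: the selection step amounts to timing each candidate kernel on the activation shapes a given linear layer actually sees and keeping the winner. The sketch below is illustrative only — it is not the torchao implementation, and `pick_fastest`, `benchmark_ms`, and `candidate_transforms` are names invented for this example.

```
import copy

import torch
from torch.utils.benchmark import Timer


def benchmark_ms(fn, *args):
    # median wall time of fn(*args) in milliseconds
    timer = Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return timer.blocked_autorange().median * 1e3


def pick_fastest(linear, example_input, candidate_transforms):
    # candidate_transforms: callables that take a copy of the linear module and
    # return a (possibly quantized) replacement; include an identity transform
    # so "no quantization" also competes
    timed = []
    for transform in candidate_transforms:
        candidate = transform(copy.deepcopy(linear))
        timed.append((benchmark_ms(candidate, example_input), candidate))
    return min(timed, key=lambda t: t[0])[1]
```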
README.md: 31 additions & 10 deletions
````diff
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st
 This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.

-### A8W8 Dynamic Quantization
+### Autoquantization

-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally, once the best class is found for each layer, it swaps the linear weight to that class. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight-only quantization for each layer.

 ```
 import torch
-from torchao.quantization import quant_api
+import torchao
+
+# inductor settings which improve torch.compile runtime for quantized modules
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True

 # some user model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
````
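
The example in the hunk above is cut off after the model definition. A minimal end-to-end sketch of the intended flow, assuming an entry point along the lines of `torchao.autoquant(model, example_input)` (the exact call and signature should be taken from the merged README, not from this sketch), would be:

```
import torch
import torchao

# inductor settings which improve torch.compile runtime for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# benchmark the candidate kernels for each linear layer and swap in the fastest
# (assumed entry point -- confirm against the merged README)
torchao.autoquant(model, input)

# compile as usual; the swapped-in tensor subclasses are meant to compose with torch.compile
model = torch.compile(model, mode="max-autotune")
model(input)
```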
````diff
+The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
+
+Example
+
+```
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
 This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.

@@ -81,7 +102,7 @@ Example
 ...

 # convert linear modules to quantized linear modules
````
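
The relocated A8W8 `Example` block is likewise truncated in this excerpt, stopping at the conversion comment. Based on the function named in the description, a rough sketch of the complete flow, including the inductor flag the last paragraph recommends, might look like this (the exact argument list of `change_linear_weights_to_int8_dqtensors` should be checked against the source):

```
import torch
from torchao.quantization import quant_api

# enable fusion of the int8*int8 -> int32 matmul with the subsequent mul,
# avoiding materialization of the int32 intermediary tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True

# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# convert linear modules to dynamically quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)

model = torch.compile(model, mode="max-autotune")
model(input)
```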