
Commit 1553bd4

committed
Update on "Add GPTQQuantizer"

Summary: Implement GPTQQuantizer with the unified quantizer API
Test Plan: python test/quantization/test_quant_api.py
Reviewers:
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]

2 parents 215d07b + a224003, commit 1553bd4

File tree

7 files changed: +176 -91 lines changed


README.md

Lines changed: 46 additions & 73 deletions
@@ -1,8 +1,21 @@
-# torchao
+# torchao: PyTorch Architecture Optimization
 
-**Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue or reach out. We'd love to hear about how you're using the APIs.**
+**Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open a GitHub issue**
+
+The `torchao` package allows you to quantize and prune your models using native PyTorch.
+
+The repo hosts both
+1. Lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
+2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant
+3. Sparsity [algorithms](./torchao/sparsity) such as Wanda
+
+## Success stories
+Our kernels have been used to achieve SOTA inference performance on
+
+1. Image segmentation models with [sam-fast](pytorch.org/blog/accelerating-generative-ai)
+2. Language models with [gpt-fast](pytorch.org/blog/accelerating-generative-ai-2)
+3. Diffusion models with [sd-fast](pytorch.org/blog/accelerating-generative-ai-3)
 
-The torchao package contains apis and workflows used to apply AO techniques like quantization and pruning to models using only native pytorch.
 
 ## Installation
 

@@ -18,43 +31,23 @@ pip install torchao
 ```Shell
 git clone https://github.com/pytorch-labs/ao
 cd ao
-python setup.py install
-```
-
-Verify Installation:
-
-```Shell
-pip list | grep torchao
-```
-
-Expected Output
-```Shell
-torchao 0.0.1 <install dir>
+pip install -e .
 ```
 
-## Usage
+## Examples
 
-Relevant APIs can be found in torchao.quantization.quant_api
-
-Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
-This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
-
-The following apis use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass' _quantized_op method.
-
-This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
+Typically quantization algorithms will have different schemes for how the activation and weights are quantized, so A16W8, for instance, means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a one-line change.
 
 ### A8W8 Dynamic Quantization
 
-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
-
 ```Python
 import torch
 from torchao.quantization import quant_api
 
-# some user model and example input
+# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+
+# Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
@@ -66,78 +59,54 @@ model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
-This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
-
-
 ### A16W8 WeightOnly Quantization
 
-The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a weight only quantized linear matmul
-
-Example
-
-```Python
-# some user model and example input
-...
-
-# convert linear modules to quantized linear modules
+```python
 quant_api.change_linear_weights_to_int8_woqtensors(model)
-
-# compile the model to improve performance
-...
 ```
 
 This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
 
 
 ### A16W4 WeightOnly Quantization
 
-The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a weight only quantized linear matmul
-
-Example
-
-```Python
-# some user model and example input
-...
-
-# convert linear modules to quantized linear modules
+```python
 quant_api.change_linear_weights_to_int4_woqtensors(model)
-
-# compile the model to improve performance
-...
 ```
 
-The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
-
-## Other APIs
+Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
 
-### Module Swap APIs
-
-The `apply_dynamic_quant` and `apply_weight_only_int8_quant` apis can be used in the same formula as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
 
 ### A8W8 Dynamic Quantization with Smoothquant
 
-We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
-Due to requiring calibration, the API is slightly more complicated and currently only exists with a module swap api.
+We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above. Due to requiring calibration, the API is more complicated.
 
 Example
 
 ```Python
 import torch
 from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
 
-# some user model
+# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+
+# plug in your model
 model = get_model()
 
 # convert linear modules to smoothquant
 # linear module in calibration mode
 swap_linear_with_smooth_fq_linear(model)
 
-# calibration
-for i in range(calibration_amount):
-    input = get_input()
-    model(input)
+# Create a data loader for calibration
+calibration_data = get_calibration_data()
+calibration_dataset = MyDataset(calibration_data)
+calibration_loader = DataLoader(calibration_dataset, batch_size=32, shuffle=True)
+
+# Calibrate the model
+model.train()
+for batch in calibration_loader:
+    inputs = batch
+    model(inputs)
 
 # set it to inference mode
 smooth_fq_linear_to_inference(model)
@@ -147,7 +116,11 @@ model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
-like the other dynamic quantization apis, the torch._inductor.config.force_fuse_int_mm_with_mul option may significantly improve performance if enabled.
+## Sharp edges
+
+1. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
+2. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), which are preferred over older module swap based methods because they don't modify the graph and are generally more composable and flexible.
+
 
 ## License
 
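For reference, here is a minimal end-to-end sketch of the A16W8 weight-only path the updated README describes, assembled only from APIs the README itself names (`quant_api.change_linear_weights_to_int8_woqtensors`, `torch._inductor.config.use_mixed_mm`, `torch.compile`); the tiny `nn.Sequential` model and input are stand-ins, not part of this commit.

```python
import torch
from torchao.quantization import quant_api

# Fuse the dequantization into the matmul instead of materializing a large
# floating point weight tensor (per the use_mixed_mm note above).
torch._inductor.config.use_mixed_mm = True

# Stand-in model and example input; any model with nn.Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# Swap the linear weights for the int8 weight-only tensor subclass.
quant_api.change_linear_weights_to_int8_woqtensors(model)

# Compile to pick up the fused kernels, then run as usual.
model = torch.compile(model, mode='max-autotune')
model(input)
```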
test/dtypes/test_uint4.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
     compute_error,
 )
 from torchao.quantization.quant_api import (
-    replace_with_custom_fn_if_matches_filter,
+    _replace_with_custom_fn_if_matches_filter,
 )
 from torch.ao.quantization.observer import ObserverBase
 from torch import nn
@@ -36,7 +36,7 @@ def fn(mod):
         mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
         return mod
 
-    replace_with_custom_fn_if_matches_filter(
+    _replace_with_custom_fn_if_matches_filter(
         model,
         lambda mod: fn(mod),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
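For context on the helper being renamed above, here is a small self-contained sketch of the same call pattern the test uses; the bfloat16 weight cast is a deliberately trivial stand-in for the `PerChannelSymmetricWeightUInt4Tensor.from_float` conversion in the test body.

```python
import torch
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

# Toy model; anything with nn.Linear submodules works.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

def fn(mod):
    # Stand-in transformation: the uint4 test swaps in
    # PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight) here instead.
    mod.weight = torch.nn.Parameter(mod.weight.to(torch.bfloat16), requires_grad=False)
    return mod

# Walk the module tree and apply `fn` wherever the filter matches.
_replace_with_custom_fn_if_matches_filter(
    model,
    lambda mod: fn(mod),
    lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)

assert model[0].weight.dtype == torch.bfloat16
```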

test/modules/test_nf4_linear.py

Lines changed: 68 additions & 13 deletions
@@ -4,9 +4,10 @@
 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
-from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
+from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
 import torch.nn.functional as F
-
+import io
+from collections import OrderedDict
 
 bnb_available = False
 
@@ -44,11 +45,19 @@ def _build_bnb_linear(input_weight, device):
 
 
 class TestNF4Linear(TestCase):
+    class TestMod(nn.Module):
+        def __init__(self, tensor, block_size, scaler_block_size):
+            super().__init__()
+            self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size))
+
+    def save_state_dict_to_buffer(self, state_dict: OrderedDict):
+        buffer = io.BytesIO()
+        torch.save(state_dict, buffer)
+        buffer.seek(0)
+        return buffer
 
     def test_register_nf4_as_param(self):
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
 
         # Would raise if nn.Parameter registration fails, such as no detach()
         # impl when calling __torch_dispatch__
@@ -58,18 +67,14 @@ def test_register_nf4_as_param(self):
     def test_output_bf16(self):
         # Test to ensure W4 A16 produces A16
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         out = linear_nf4(input=inp, weight=nf4_tensor)
         assert out.dtype == torch.bfloat16
 
     def test_backward_bf16(self):
         # Test to ensure backward pass gives activation a bf16 gradient and no gradient
         # to the linear's weight, as it is frozen.
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
         linear_nf4(inp, nf4_tensor).sum().backward()
         assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
@@ -83,7 +88,7 @@ def test_reconstruction_qlora_vs_bnb(self):
         device = "cuda"
         embed_dim = 512
         input_weight = _build_input_weight(embed_dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
         bnb_reconstruction = bnb_linear(
             torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
@@ -107,7 +112,7 @@ def test_nf4_bnb_linear(self):
         dim = 512
         device = "cuda"
         input_weight = _build_input_weight(dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
 
         inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
@@ -121,6 +126,56 @@ def test_nf4_bnb_linear(self):
         assert err_native < 0.5 * dim
         assert err_bnb < 0.5 * dim
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_bfloat16(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+
+        bf16_dummy_dict = {"param": inpt_tensor}
+        base_mod.load_state_dict(bf16_dummy_dict)
+
+        assert base_mod.param.block_size == 32
+        assert base_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_same_meta(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 32, 2)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 32
+        assert other_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_diff_meta(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 64, 1)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 64
+        assert other_mod.param.scaler_block_size == 1
+
+    def test_to_copy(self):
+        inpt_tensor = torch.rand(128, device='cpu')
+        inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
+        inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
+        torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+
+        if torch.cuda.is_available():
+            inpt_tensor = torch.rand(128, device='cuda')
+            inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
+            inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
+            torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+
 
 if __name__ == "__main__":
     unittest.main()
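Taken together, the new tests sketch the intended `to_nf4` workflow; the snippet below simply restates the calls exercised above in one place (shapes, dtypes, and assertions copied from the tests).

```python
import torch
from torch import nn
from torchao.dtypes import to_nf4
from torchao.dtypes.nf4tensor import linear_nf4

# Quantize a bf16 weight to NF4, as test_register_nf4_as_param does.
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))

# It registers cleanly as a frozen nn.Parameter (this is what TestMod relies on).
param = nn.Parameter(nf4_tensor, requires_grad=False)

# W4 A16: a bf16 activation against the NF4 weight stays bf16 ...
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
out = linear_nf4(input=inp, weight=nf4_tensor)
assert out.dtype == torch.bfloat16

# ... and the backward pass only reaches the activation, since the weight is frozen.
out.sum().backward()
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
```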

torchao/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from . import dtypes
+
+__all__ = [
+    "dtypes"
+]

torchao/dtypes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
+from .nf4tensor import NF4Tensor, to_nf4
 from .uint4 import UInt4Tensor
 
 __all__ = [
+    "NF4Tensor",
+    "to_nf4",
     "UInt4Tensor"
 ]
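With the two `__init__.py` additions above, the new dtypes are importable from the package namespace; a quick sanity check, as a sketch assuming the torchao build from this commit is installed:

```python
import torch
import torchao
from torchao.dtypes import NF4Tensor, UInt4Tensor, to_nf4  # all three are now exported

# `dtypes` is re-exported at the package root by torchao/__init__.py ...
assert "dtypes" in torchao.__all__

# ... and the NF4 helpers round-trip a tensor, mirroring the tests above.
t = torch.randn(512, 512, dtype=torch.bfloat16)
assert isinstance(to_nf4(t), NF4Tensor)
```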

0 commit comments
