
Commit 29214a9

Update on "Autoquant"
Summary: Adding autoquantization functionality. Using the do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"
Also tested on SAM and SDXL: pytorch-labs/segment-anything-fast#114 HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent 97733c2 commit 29214a9

File tree

7 files changed, +120 -59 lines changed


README.md

Lines changed: 31 additions & 10 deletions
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st

 This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.

-### A8W8 Dynamic Quantization
+### Autoquantization

-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
-converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
+The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
+of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer.

 ```
 import torch
-from torchao.quantization import quant_api
+import torchao
+
+# inductor settings which improve torch.compile runtime for quantized modules
+torch._inductor.config.force_fuse_int_mm_with_mul
+torch._inductor.config.use_mixed_mm

 # some user model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

-# convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+# perform autoquantization
+torchao.autoquant(model, (input))

 # compile the model to improve performance
 model = torch.compile(model, mode='max-autotune')
 model(input)
 ```

+
+### A8W8 Dynamic Quantization
+
+The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
+converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
+
+Example
+
+```
+# some user model and example input
+...
+
+# convert linear modules to quantized linear modules
+torchao.change_linear_weights_to_int8_dqtensors(model)
+
+# compile the model to improve performance
+...
+```
+
 This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.


@@ -81,7 +102,7 @@ Example
 ...

 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+torchao.change_linear_weights_to_int8_woqtensors(model)

 # compile the model to improve performance
 ...
@@ -102,7 +123,7 @@ Example
 ...

 # convert linear modules to quantized linear modules
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+torchao.change_linear_weights_to_int4_woqtensors(model)

 # compile the model to improve performance
 ...
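
One note on the new Autoquantization example: `torch._inductor.config.force_fuse_int_mm_with_mul` and `torch._inductor.config.use_mixed_mm` are boolean inductor options, and the surrounding text says the technique works best when they are enabled, so presumably the intent is to assign them rather than merely reference them. A minimal sketch of that assumption:

```
# assumed intent of the example above: actually enable the inductor options
import torch

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
```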

__init__.py

Whitespace-only changes.

torchao/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from torchao.quantization import (
+    apply_weight_only_int8_quant,
+    apply_dynamic_quant,
+    change_linear_weights_to_int8_dqtensors,
+    change_linear_weights_to_int8_woqtensors,
+    change_linear_weights_to_int4_woqtensors,
+    swap_conv2d_1x1_to_linear,
+    autoquant,
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
+)
+
+__all__ = [
+    "apply_weight_only_int8_quant",
+    "apply_dynamic_quant",
+    "change_linear_weights_to_int8_dqtensors",
+    "change_linear_weights_to_int8_woqtensors",
+    "change_linear_weights_to_int4_woqtensors",
+    "swap_conv2d_1x1_to_linear",
+    "safe_int_mm",
+    "autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
+]

torchao/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
-    "do_autoquant",
+    "autoquant",
     "change_linears_to_autoquantizable",
     "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",

torchao/quantization/autoquant.py

Lines changed: 57 additions & 21 deletions
@@ -1,6 +1,4 @@
 import torch
-import os
-from subprocess import check_output
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -79,26 +77,56 @@ def to_quantized(self, error_on_unseen, **kwargs):
             # default back to non-quantized weight if not seen
             self = AQFloatLinearWeight.from_float(self.weight)
             return self
+
+        # only want to do shape+final print a single time if multiple layers
+        # see/have same shapes so we gate on check_cache being empty for
+        # at least one of the class/shape combinations.
+        do_final_print = False
+        print_once = True
+
+        def count_shapes(self, do_print=True):
+            differe_shape_count=0
+            for shapes_and_dtype, times_seen in self.logged_data.items():
+                differe_shape_count += 1
+                if do_print:
+                    act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype
+                    print(f"activation_shapes: {act_shape}, times_seen: {times_seen}")
+            if do_print:
+                print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}")
+            return differe_shape_count
+
+        # check each class
         best_time = torch.inf
         best_cls = None
-        do_print=False
-        # check each class
         for q_cls in self.qtensor_class_list:
             # for each logged shape+dtype, benchmark
-            cls_res=0
+            cur_time=0
+            shape_count = count_shapes(self, do_print=False)
             for shapes_and_dtype, times_seen in self.logged_data.items():
                 if check_cache(q_cls, shapes_and_dtype) is None:
-                    do_print=True
-                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
+                    # only do final print if we have to autotune at least one cls/shape pair
+                    do_final_print=True
+
+                    # only print shapes once
+                    if print_once == True:
+                        print_once = False
+                        count_shapes(self, do_print=True)
+
+                    time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
+                    time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape
+                    self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
                     torch._dynamo.reset()
-                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
-            if best_time >= cls_res:
-                best_time = cls_res
+                cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
+            if shape_count is not None and shape_count > 1:
+                print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
+            if best_time >= cur_time:
+                best_time = cur_time
                 best_cls = q_cls
         # only print if this is the first time seeing some cls+shape combo,
         # otherwise we will print the same thing for every layer.
-        if do_print:
-            print(f"for {self.logged_data}, best_cls={best_cls}")
+        if do_final_print:
+            print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
         self = best_cls.from_float(self.weight)
         return self
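
For readers following the new selection loop above: per class and per recorded activation shape, a benchmark time is cached, weighted by how many times that shape was seen, and the class with the lowest weighted total is kept. A self-contained sketch of that rule with made-up timings (the names `bench`, `logged_data` and the candidate strings are hypothetical stand-ins, not the torchao API):

```
# made-up timings, only to illustrate the weighted selection rule in the hunk above
import math

logged_data = {"shape_a": 10, "shape_b": 2}          # activation shape -> times seen
fake_ms = {                                          # pretend benchmark results (ms)
    ("no_quant", "shape_a"): 1.00, ("no_quant", "shape_b"): 0.50,
    ("int8_dynamic", "shape_a"): 0.70, ("int8_dynamic", "shape_b"): 0.60,
    ("int8_weight_only", "shape_a"): 0.80, ("int8_weight_only", "shape_b"): 0.40,
}
cache = {}

def bench(cls, shape):
    # stand-in for tune_autoquant/check_cache: time each (cls, shape) pair once, then reuse
    if (cls, shape) not in cache:
        cache[(cls, shape)] = fake_ms[(cls, shape)]
    return cache[(cls, shape)]

best_time, best_cls = math.inf, None
for cls in ("no_quant", "int8_dynamic", "int8_weight_only"):
    total = sum(bench(cls, shape) * times for shape, times in logged_data.items())
    if total <= best_time:
        best_time, best_cls = total, cls

print(best_cls)  # int8_dynamic: 0.70*10 + 0.60*2 = 8.2ms is the lowest weighted total
```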
@@ -145,21 +173,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

 def do_autoquant_bench(op, *args, **kwargs):
+    """
+    runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
+    """
     rep = kwargs.pop("rep", 100)
     warmup = kwargs.pop("warmup", 25)
     with torch.no_grad():
         torch.cuda.synchronize()
         stream = torch.cuda.Stream()
         stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(stream):
-            op(*args)
+            op(*args, **kwargs)
         stream.synchronize()
         torch.cuda.current_stream().wait_stream(stream)
         torch.cuda.synchronize()

         graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(graph, stream=stream):
-            op(*args)
+            op(*args, **kwargs)
         res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
     return res
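A hedged usage sketch of the benchmarking helper above, assuming it is importable from `torchao.quantization.autoquant` as shown in the diff; it needs a CUDA device and the shapes below are arbitrary:

```
import torch
from torchao.quantization.autoquant import do_autoquant_bench

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# warmup/rep are popped inside the helper; the remaining args are forwarded to the op,
# which is captured in a CUDA graph and replayed by do_bench (median time, in ms)
median_ms = do_autoquant_bench(torch.mm, a, b, warmup=25, rep=100)
print(f"torch.mm median: {median_ms:0.3f}ms")
```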

@@ -180,11 +211,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         else:
             func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
         q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
-        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
+        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
         if res < best_time*1.1:
             res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
             res=(res2*.9+res*.1)
-        print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
+        print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
         return res

 class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
@@ -196,7 +227,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         if not _is_interpolate_mode(mode):
             return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

-        # SAM best is between .8 to 1, SDXL also performs best in this range
+        # SAM best is between .8 and 1, SDXL also performs best in this range
         INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
         x_vals_int8, x_scales = quantize_activation_per_token_absmax(
@@ -209,7 +240,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
             res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
-        print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
+        print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
         if res_matmul>=best_time:
@@ -220,7 +251,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         res = super()._autoquant_test(act_mat, weight, bias, to_beat)
         max_int_const_win = (best_time-res_matmul)/(res-res_matmul)
         res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
-        print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
+        print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
         return res_f

 class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
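
To make the interpolation in the hunk above concrete, a small worked example with made-up timings (not measured numbers):

```
# made-up numbers, only to illustrate the two formulas printed above
INTERPOLATION_CONSTANT = 0.85   # mode[1], e.g. in the .8-1 range noted in the comment
best_time = 1.00                # ms, current best time for this layer
res_matmul = 0.60               # ms, the int8 matmul kernel alone
res = 0.95                      # ms, the full quantized op including quantization overhead

# weighted estimate assuming some of the overhead fuses away in the full model
res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul  # 0.8975 ms

# largest interpolation constant at which this class would still beat best_time
max_int_const_win = (best_time - res_matmul) / (res - res_matmul)                 # ~1.14
```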
@@ -252,6 +283,10 @@ def _autoquant_test(cls, act_mat, *args):
         return super()._autoquant_test(act_mat, *args)

 class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
+    uses a different kernel
+    """
     def _quantized_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
         y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
@@ -265,7 +300,8 @@ class AQFloatLinearWeight(torch.Tensor, AQMixin):
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
     default/non-quantized option. Only implements the bare minimum needed to work with the
     AutoQuantizableLinearWeight class using the same interfaces that would normally be
-    used by QTensor subclasses but for a default linear op instead.
+    used by QTensor subclasses but for a default linear op instead. Result of from_float
+    is not a tensor subclass, but rather the float tensor.
     """
     def __init__(self):
         super().__init__()
@@ -284,5 +320,5 @@ def from_float(cls, weight):
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
     # AQWeightOnlyQuantizedLinearWeight3,
-    # 3rd version gets picked in situations where it is slower for the interpolation mode
+    # TODO this gets picked in places where it makes perf worse, why?
 ]

torchao/quantization/quant_api.py

Lines changed: 7 additions & 6 deletions
@@ -37,7 +37,7 @@
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
     "swap_conv2d_1x1_to_linear",
-    "do_autoquant",
+    "autoquant",
     "change_linears_to_autoquantizable",
     "change_autoquantizable_to_quantized",
 ]
@@ -182,6 +182,9 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     on benchmark results. Expectation is that these modules are
     torch.compiled afterwards.
     """
+    hold = torch._dynamo.config.automatic_dynamic_shapes
+    torch._dynamo.config.automatic_dynamic_shapes = False
+
     filter_fn = kwargs.pop(
         "filter_fn",
         lambda mod, *args:
@@ -195,24 +198,22 @@ def change_autoquantizable_to_quantized(model, **kwargs):
         ),
         filter_fn,
     )
+    torch._dynamo.config.automatic_dynamic_shapes = hold
+    torch._dynamo.reset()

 @torch.no_grad()
-def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs):
+def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs):
     """
     Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
     across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
     and applies that type of quantization.
     """
-    hold = torch._dynamo.config.automatic_dynamic_shapes
-    torch._dynamo.config.automatic_dynamic_shapes = False
     change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
     if not isinstance(example_input, (tuple, list)):
         assert isinstance(example_input, torch.Tensor)
         example_input = [example_input]
     model(*example_input)
     change_autoquantizable_to_quantized(model, **kwargs)
-    torch._dynamo.config.automatic_dynamic_shapes = hold
-    torch._dynamo.reset()
     return model

 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
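
Since the renamed `autoquant` entry point accepts either a single tensor or a tuple/list of example inputs (it calls `model(*example_input)`), a hedged sketch of using it on a model with two inputs; the module and shapes here are illustrative only, not part of the commit:

```
import torch
import torchao

# illustrative two-input model (made up for this sketch)
class TwoInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(32, 64)
        self.b = torch.nn.Linear(16, 64)

    def forward(self, x, y):
        return self.a(x) + self.b(y)

model = TwoInput().cuda().to(torch.bfloat16)
example_inputs = (
    torch.randn(8, 32, device="cuda", dtype=torch.bfloat16),
    torch.randn(8, 16, device="cuda", dtype=torch.bfloat16),
)

# record shapes, pick the best subclass per linear, then compile as usual
model = torchao.autoquant(model, example_inputs)
model = torch.compile(model, mode="max-autotune")
model(*example_inputs)
```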

torchao/quantization/utils.py

Lines changed: 0 additions & 21 deletions
@@ -7,7 +7,6 @@

 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
-from torch.utils.benchmark import Timer

 __all__ = [
     "find_multiple",
@@ -87,23 +86,3 @@ def get_model_size_in_bytes(model):
     for b in model.buffers():
         s += b.nelement() * b.element_size()
     return s
-
-
-def benchmark(f, *args, **kwargs):
-    if "best_time" in kwargs:
-        best_time = kwargs.pop("best_time")
-    else:
-        best_time = torch.inf
-    t0 = Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-
-    # warmup
-    t0.timeit(10)
-    res=t0.adaptive_autorange(min_run_time=.1)
-    # run more if median vs median minus iqr (interpolated based on number of runs left) is lower than best_time,
-    # stop if good res.iqr/res.median or have 20 samples
-    while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20):
-        res2 = t0.adaptive_autorange(min_run_time=.5)
-        res=res.merge([res2, res])[0]
-    return res.median * 1e3

0 commit comments
