
Commit 97733c2: Update on "Autoquant"

Summary: Adding autoquantization functionality. Using the do_autoquant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"
Also tested on SAM and SDXL:
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]

1 parent c6d59e5 commit 97733c2
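For context, the workflow the new test exercises looks roughly like this. This is a minimal sketch based on the diffs below; the exact import path for do_autoquant is an assumption, since this commit does not show it:

    import torch
    from torchao.quantization import do_autoquant  # import path assumed

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
    ).to("cuda").to(torch.bfloat16)
    example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    # benchmarks each layer's candidate kernels on the logged shapes and swaps
    # each weight for the fastest quantized (or unquantized) subclass
    do_autoquant(model, example_input)
    out = model(example_input)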

3 files changed: +160 −79 lines

test/test.py

Lines changed: 37 additions & 5 deletions
@@ -54,7 +54,13 @@
     compute_error as SQNR,
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
-    benchmark
+)
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3
+
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -882,6 +888,30 @@ def test_int8_weight_only_quant_subclass(self):
             Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
         )
 
+    def test_aq_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_2_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_3_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
+            )
+
     def test_int4_weight_only_quant_subclass(self):
         self._test_lin_weight_subclass_impl(
             Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1197,19 +1233,17 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
-# TODO FINISH TEST CODE
 class TestAutoQuant(unittest.TestCase):
-    def test_auto_quant(self):
+    def test_autoquant(self):
         torch._inductor.config.epilogue_fusion = False
         torch._inductor.config.use_mixed_mm = True
         torch._inductor.config.force_fuse_int_mm_with_mul = True
-        torch._inductor.config.coordinate_descent_tuning = True
         torch._dynamo.config.automatic_dynamic_shapes = False
 
         for m,k,n in [
             (1, 1024, 1024),
             (64, 1024, 1024),
-            (4096, 1024, 1024),
+            (2**15, 1024, 1024),
             (1, 1024, 4096),
             (64, 1024, 4096),
             (1, 4096, 1024),
@@ -1222,7 +1256,11 @@ def test_auto_quant(self):
                 torch.nn.Linear(k,n),
                 torch.nn.ReLU(),
             ).to("cuda").to(torch.bfloat16)
+            out = model(example_input)
             do_autoquant(model, example_input)
+            out2 = model(example_input)
+            sqnr = SQNR(out, out2)
+            self.assertTrue(sqnr >= 30)
 
 if __name__ == "__main__":
     unittest.main()
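SQNR above is compute_error imported under an alias. A minimal sketch of the usual signal-to-quantized-noise metric, assuming the standard definition rather than torchao's exact implementation:

    import torch

    def sqnr(ref, res):
        # ratio of signal power to quantization-error power, in dB;
        # the 30dB threshold used by the test corresponds to roughly 3% relative error
        return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - res))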

torchao/quantization/autoquant.py

Lines changed: 102 additions & 69 deletions
@@ -1,34 +1,34 @@
 import torch
-
+import os
+from subprocess import check_output
 from .subclass import ( # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from .utils import benchmark
 from .quant_primitives import (
     quantize_activation_per_token_absmax,
     safe_int_mm,
 )
 import torch.nn.functional as F
-
+from torch._inductor.utils import do_bench
 aten = torch.ops.aten
 
 AUTOQUANT_CACHE = {}
 
-def check_cache(cls, shape, dtype):
-    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)
+def check_cache(cls, shapes_and_dtype):
+    return AUTOQUANT_CACHE.get((cls,)+shapes_and_dtype, None)
 
-def update_cache(cls, shape, dtype, res):
-    AUTOQUANT_CACHE[(cls, shape, dtype)] = res
+def update_cache(cls, shapes_and_dtype, res):
+    AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res
 
 class AutoQuantizableLinearWeight(torch.Tensor):
     """
     when run, finds best type of quantization for this tensor and swaps itself with that
     """
     @staticmethod
-    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
+    def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
         kwargs["device"] = weight.device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
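The cache key now folds the activation shape, weight shape, bias shape, and activation dtype into one tuple per candidate class. A hypothetical entry, with shapes invented for illustration:

    # (act_shape, w_shape, bias_shape, act_dtype), as assembled by log_shape below
    key = ((64, 1024), (4096, 1024), (4096,), torch.bfloat16)
    update_cache(AQInt8DynamicallyQuantizedLinearWeight, key, 0.123)  # 0.123ms benchmark result
    assert check_cache(AQInt8DynamicallyQuantizedLinearWeight, key) == 0.123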
@@ -40,11 +40,11 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
         shape = kwargs.pop("shape", weight.shape)
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
+    def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
         self.weight = weight
         self.qtensor_class_list = qtensor_class_list
-        self.logged_shape = None
-        self.logged_dtype = None
+        self.logged_data = {}
+        self.mode = mode
 
     def __repr__(self):
         return (
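The new mode argument defaults to ["relu", None]; the only other form recognized downstream (see _is_interpolate_mode later in this file) is ["interpolate", <float>]. A sketch of both, with an illustrative 0.85 constant; the list name DEFAULT_CLASS_LIST is assumed, since the commit only shows the list's entries:

    w = torch.randn(4096, 1024)
    # default: candidates are benchmarked with relu applied to the op's input and output
    aq_w = AutoQuantizableLinearWeight.from_float(w, DEFAULT_CLASS_LIST, mode=["relu", None])
    # interpolate: score dynamic quant between its bare-matmul and full-op times
    aq_w = AutoQuantizableLinearWeight.from_float(w, DEFAULT_CLASS_LIST, mode=["interpolate", 0.85])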
@@ -54,72 +54,72 @@ def __repr__(self):
 
     @staticmethod
     def log_shape(act_mat, w_autoquant, bias):
-        orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1])
-        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
         logged_dtype = act_mat.dtype
-        w_autoquant.logged_shape = logged_shape
-        w_autoquant.logged_dtype = logged_dtype
+        logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,)
+        shapes_and_dtype = logged_shapes + (logged_dtype,)
+        w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0)
         for q_cls in w_autoquant.qtensor_class_list:
-            if check_cache(q_cls, logged_shape, logged_dtype) is None:
-                update_cache(q_cls, logged_shape, logged_dtype, None)
-        y = torch.mm(act_mat, w_autoquant.weight.t())
-        y = y.reshape(*orig_shape[:-1], y.shape[-1])
-        if bias is not None:
-            y += bias
-        return y
+            if check_cache(q_cls, shapes_and_dtype) is None:
+                update_cache(q_cls, shapes_and_dtype, None)
 
-    def tune_autoquant(self, q_cls, best_time):
-        act_shape, w_shape, bias_shape = self.logged_shape
-        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+    def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
+        act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
+        if check_cache(q_cls, shapes_and_dtype) is None:
             with torch.no_grad():
-                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
-                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
-                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
-                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
+                act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
+                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
+                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
+                update_cache(q_cls, shapes_and_dtype, res)
 
     def to_quantized(self, error_on_unseen, **kwargs):
-        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
+        if error_on_unseen and self.logged_data == {}:
             raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
-        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
+        elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
             self = AQFloatLinearWeight.from_float(self.weight)
-                return self
+            return self
         best_time = torch.inf
         best_cls = None
         do_print=False
+        # check each class
         for q_cls in self.qtensor_class_list:
-            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
-                do_print=True
-                self.tune_autoquant(q_cls, best_time)
-                torch._dynamo.reset()
-            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
+            # for each logged shape+dtype, benchmark
+            cls_res=0
+            for shapes_and_dtype, times_seen in self.logged_data.items():
+                if check_cache(q_cls, shapes_and_dtype) is None:
+                    do_print=True
+                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
+                    torch._dynamo.reset()
+                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
             if best_time >= cls_res:
                 best_time = cls_res
                 best_cls = q_cls
+        # only print if this is the first time seeing some cls+shape combo,
+        # otherwise we will print the same thing for every layer.
         if do_print:
-            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
-        # TODO handle random cls args/kwargs? or should they be curried
+            print(f"for {self.logged_data}, best_cls={best_cls}")
+        # TODO handle random cls args/kwargs? or should they be curried?
         self = best_cls.from_float(self.weight)
         return self
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
+            fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode
         )
 
     def __tensor_flatten__(self):
-        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]
+        return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape]
 
     @classmethod
     def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
         weight = tensor_data_dict["weight"]
-        qtensor_class_list, dtype, shape = tensor_attributes[0]
-        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
+        qtensor_class_list, mode, dtype, shape = tensor_attributes[0]
+        return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
 
     @classmethod
-    def from_float(cls, weight, qtensor_class_list):
-        return cls(weight, qtensor_class_list)
+    def from_float(cls, weight, qtensor_class_list, **kwargs):
+        return cls(weight, qtensor_class_list, **kwargs)
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
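With the change above, to_quantized scores each candidate class by summing its per-shape benchmark time weighted by how often that shape was logged. A toy illustration with invented numbers:

    logged_data = {"shape_a": 3, "shape_b": 1}   # how often each shape+dtype was seen
    bench_ms = {"shape_a": 0.2, "shape_b": 1.0}  # per-shape benchmark results for one class
    cls_res = sum(bench_ms[s] * n for s, n in logged_data.items())  # 0.2*3 + 1.0*1 = 1.6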
@@ -131,8 +131,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args)>2 else None
             )
-            return cls.log_shape(mat1, w_autoquant, bias)
-
+            cls.log_shape(mat1, w_autoquant, bias)
+            return func(mat1, w_autoquant.weight, bias)
         try:
             with torch._C.DisableTorchFunctionSubclass():
                 return func(*args, **kwargs)
@@ -144,28 +144,60 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func is aten.detach.default:
             return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
+def do_autoquant_bench(op, *args, **kwargs):
+    rep = kwargs.pop("rep", 100)
+    warmup = kwargs.pop("warmup", 25)
+    with torch.no_grad():
+        torch.cuda.synchronize()
+        stream = torch.cuda.Stream()
+        stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(stream):
+            op(*args)
+        stream.synchronize()
+        torch.cuda.current_stream().wait_stream(stream)
+        torch.cuda.synchronize()
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=stream):
+            op(*args)
+        res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
+    return res
+
+def _is_interpolate_mode(mode):
+    if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float):
+        return True
+    return False
+
 class AQMixin():
     """
     Mixin to turn normal quantized subclasses into autoquantizable ones
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
-        with torch.no_grad():
-            torch.cuda.synchronize()
-            res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
-        print(cls, res)
+        if _is_interpolate_mode(mode):
+            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+        else:
+            func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
+        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
+        if res < best_time*1.1:
+            res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
+            res = (res2*.9 + res*.1)
+        print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms")
         return res
 
 class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time):
-        # SAM best is between .51 to .60, SDXL also performs best in this range
-        INTERPOLATION_CONSTANT=.55
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 to 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
         x_vals_int8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
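The do_autoquant_bench helper added above captures the op in a CUDA graph and reports the median replay time via do_bench. A hypothetical standalone use, assuming a CUDA device and invented shapes:

    x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
    t_ms = do_autoquant_bench(torch.nn.functional.linear, x, w, warmup=25, rep=100)
    print(f"{t_ms:0.3f}ms median")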
@@ -174,10 +206,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time):
             lambda x_vals_int8, x_scales, w_vals_int8:
                 safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
         )
-        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
+        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
-        print(cls, "matmul", res_matmul)
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
 
         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
         if res_matmul>=best_time:
@@ -186,9 +218,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time):
         # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
         to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
         res = super()._autoquant_test(act_mat, weight, bias, to_beat)
-        print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul)
-        return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
-
+        max_int_const_win = (best_time-res_matmul)/(res-res_matmul)
+        res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
+        print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
+        return res_f
 
 class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
     """
@@ -206,17 +239,17 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
         y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
         if bias is not None:
             y += bias
         return y.to(orig_dtype)
 
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time):
-        # if act_mat has batchsize>2 don't use this kernel
-        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2:
+    def _autoquant_test(cls, act_mat, *args):
+        # if act_mat has batchsize>32 don't use this kernel
+        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32:
             return torch.inf
-        return super()._autoquant_test(act_mat, weight, bias, best_time)
+        return super()._autoquant_test(act_mat, *args)
 
 class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
     def _quantized_op(act_mat, w_qtensor, bias):
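This weight-only variant computes the matmul as a broadcast multiply followed by a sum, which lets the int8 weight be consumed without a dedicated mm kernel. A minimal equivalence check of the reshape/sum pattern, using small float tensors with scales and dtype casts omitted:

    import torch

    act = torch.randn(4, 8)
    w = torch.randn(8, 16)
    # (b, k, 1) * (1, k, n) broadcasts to (b, k, n); summing over k gives (b, n)
    y1 = (act.reshape(-1, 8, 1) * w.unsqueeze(0)).sum(dim=-2)
    y2 = act @ w
    assert torch.allclose(y1, y2, atol=1e-5)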
@@ -227,7 +260,6 @@ def _quantized_op(act_mat, w_qtensor, bias):
             y += bias
         return y
 
-
 class AQFloatLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -251,5 +283,6 @@ def from_float(cls, weight):
     AQInt8DynamicallyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
-    AQWeightOnlyQuantizedLinearWeight3,
+    # AQWeightOnlyQuantizedLinearWeight3,
+    # disabled: in interpolation mode the 3rd version sometimes gets picked even when it is slower
 ]
