
Commit afe646e

Autoquant

Summary:
Test Plan: python test/test.py -k "autoquant"
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 13ee908
Pull Request resolved: #38

1 parent 969038f commit afe646e

File tree

7 files changed: +427 -4 lines changed

test/test.py

Lines changed: 36 additions & 0 deletions

@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -53,6 +54,7 @@
     compute_error as SQNR,
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
+    benchmark
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -1195,6 +1197,40 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_auto_quant(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.coordinate_descent_tuning = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m, k, n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (4096, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            print("testing", m, k, n)
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                # torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1280, 3840),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(3840, 1280),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1280, 1024),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1024, 4096),
+                # torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            do_autoquant(model, example_input)
 
 if __name__ == "__main__":
     unittest.main()

test/test_autoquant.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: ignore-errors
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+from torchao.quantization.quant_api import (
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized,
+)
+from torchao.quantization.autoquant import do_autoquant
+from torch._dynamo import config
+torch.manual_seed(0)
+config.cache_size_limit = 100
+
+
+class AutoquantTests(unittest.TestCase):
+    def test_autoquant_e2e(self):
+        model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 32)).cuda().to(torch.bfloat16)
+        print(model, model[0].weight)
+        example_input = torch.randn((1, 64, 32), dtype=torch.bfloat16, device="cuda")
+        out = model(example_input)
+        print(out.sum())
+        do_autoquant(model, example_input)
+        print(model, model[0].weight)
+        print(model(example_input).sum())
+
+if __name__ == "__main__":
+    unittest.main()

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",

torchao/quantization/autoquant.py

Lines changed: 287 additions & 0 deletions

@@ -0,0 +1,287 @@
+import torch
+
+from .subclass import ( # noqa
+    Int8DynamicallyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
+    QuantizedLinearWeightBase,
+)
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from .utils import benchmark
+from .quant_primitives import (
+    quantize_activation_per_token_absmax,
+    safe_int_mm,
+)
+import torch.nn.functional as F
+
+aten = torch.ops.aten
+
+# maps (quantized-weight class, logged shapes, logged dtype) -> measured runtime
+AUTOQUANT_CACHE = {}
+
+def check_cache(cls, shape, dtype):
+    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)
+
+def update_cache(cls, shape, dtype, res):
+    AUTOQUANT_CACHE[(cls, shape, dtype)] = res
+
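
An aside, not part of the diff: the cache doubles as a "seen" set. log_shape inserts None for every (class, shape, dtype) combination it encounters, and tune_autoquant later replaces the None with a measured runtime. A small illustrative sketch of that convention (the shapes and the stand-in candidate class are made up):

    import torch
    from torchao.quantization.autoquant import check_cache, update_cache

    class FakeCandidate:  # stand-in for an AQ* weight subclass
        pass

    # keys are (quantized-weight class, (act_shape, w_shape, bias_shape), activation dtype)
    key_shape = ((64, 1024), (4096, 1024), (4096,))
    dtype = torch.bfloat16

    update_cache(FakeCandidate, key_shape, dtype, None)          # marked as "seen" by log_shape
    assert check_cache(FakeCandidate, key_shape, dtype) is None  # still needs benchmarking
    update_cache(FakeCandidate, key_shape, dtype, 0.12)          # measured runtime from tune_autoquant
    assert check_cache(FakeCandidate, key_shape, dtype) == 0.12
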
+class AutoQuantizableLinearWeight(torch.Tensor):
+    """
+    A tensor subclass that, when run, finds the best type of quantization for
+    this tensor and swaps itself with that quantized weight.
+    """
+    @staticmethod
+    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
+        kwargs["device"] = weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
+        )
+        kwargs["dtype"] = (
+            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
+        )
+        kwargs["requires_grad"] = False
+        shape = kwargs.pop("shape", weight.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
+        self.weight = weight
+        self.qtensor_class_list = qtensor_class_list
+        self.logged_shape = None
+        self.logged_dtype = None
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
+        )
+
+    @staticmethod
+    def log_shape(act_mat, w_autoquant, bias):
+        # record the (activation, weight, bias) shapes and activation dtype seen at
+        # runtime, seed the benchmark cache for each candidate class, then fall back
+        # to the unquantized matmul so the model still produces normal outputs
+        orig_shape = act_mat.shape
+        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
+        logged_dtype = act_mat.dtype
+        w_autoquant.logged_shape = logged_shape
+        w_autoquant.logged_dtype = logged_dtype
+        for q_cls in w_autoquant.qtensor_class_list:
+            if check_cache(q_cls, logged_shape, logged_dtype) is None:
+                update_cache(q_cls, logged_shape, logged_dtype, None)
+        y = torch.mm(act_mat, w_autoquant.weight.t())
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y
+
+    def tune_autoquant(self, q_cls):
+        act_shape, w_shape, bias_shape = self.logged_shape
+        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+            with torch.no_grad():
+                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
+                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
+                res = q_cls._autoquant_test(act_mat, self.weight, bias)
+                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
+
+    def to_quantized(self, error_on_unseen, **kwargs):
+        # benchmark every candidate class on the logged shape/dtype and
+        # return the weight converted to the fastest one
+        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
+            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
+        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
+            # default back to non-quantized weight if not seen
+            self = AQFloatLinearWeight.from_float(self.weight)
+            return self
+        best_time = torch.inf
+        best_cls = None
+        do_print = False
+        for q_cls in self.qtensor_class_list:
+            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+                do_print = True
+                self.tune_autoquant(q_cls)
+                torch._dynamo.reset()
+            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
+            if best_time >= cls_res:
+                best_time = cls_res
+                best_cls = q_cls
+        if do_print:
+            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
+        # TODO handle random cls args/kwargs? or should they be curried
+        self = best_cls.from_float(self.weight)
+        return self
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
+        )
+
+    def __tensor_flatten__(self):
+        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]
+
+    @classmethod
+    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
+        weight = tensor_data_dict["weight"]
+        qtensor_class_list, dtype, shape = tensor_attributes
+        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
+
+    @classmethod
+    def from_float(cls, weight, qtensor_class_list):
+        return cls(weight, qtensor_class_list)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_autoquant, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            return cls.log_shape(mat1, w_autoquant, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except Exception:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
+
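
Another aside to make the flow concrete: a hand-driven sketch (not part of the diff) of how the subclass is meant to be used. Wrap a linear weight, run the module once so __torch_function__ records shapes via log_shape, then call to_quantized to benchmark the candidates and swap in the fastest one. The module and shapes are illustrative:

    import torch
    from torchao.quantization.autoquant import (
        AutoQuantizableLinearWeight,
        DEFAULT_CLASS_LIST,
    )

    lin = torch.nn.Linear(1024, 1024).cuda().to(torch.bfloat16)

    # wrap the weight so F.linear calls are intercepted and shapes get logged
    lin.weight = torch.nn.Parameter(
        AutoQuantizableLinearWeight.from_float(lin.weight, DEFAULT_CLASS_LIST),
        requires_grad=False,
    )

    x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
    lin(x)  # log_shape records shapes/dtype and seeds the benchmark cache

    # benchmark every candidate for the logged shape and keep the fastest
    lin.weight = torch.nn.Parameter(
        lin.weight.to_quantized(error_on_unseen=True),
        requires_grad=False,
    )

The change_linears_to_autoquantizable / change_autoquantizable_to_quantized helpers imported in test/test_autoquant.py presumably perform these two swaps across all the linear layers of a model; their implementation lives in quant_api.py, which is not shown in this section.
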
+class AQMixin():
+    """
+    Mixin to turn normal quantized subclasses into autoquantizable ones
+    """
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        # compile the quantized op wrapped between relu ops and benchmark it
+        # on the given activation/weight/bias
+        w_qtensor = cls.from_float(weight)
+        func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
+        q_c_op = torch.compile(func, mode="max-autotune")
+        # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        with torch.no_grad():
+            torch.cuda.synchronize()
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+        print(cls, res)
+        return res
+
+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+    """
+    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        res = super()._autoquant_test(act_mat, weight, bias)
+        # additionally benchmark just the int8 matmul portion and print it
+        w_qtensor = cls.from_float(weight)
+        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_int8, x_scales, w_vals_int8:
+                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
+        with torch.no_grad():
+            res2 = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(cls, "matmul", res2)
+        # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930
+        return res
+
+
+class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    """
+
+class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
+    uses a different kernel
+    """
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        # broadcast-multiply + sum kernel, intended for very small batch sizes
+        orig_dtype = act_mat.dtype
+        orig_shape = act_mat.shape
+        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
+        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
+        if bias is not None:
+            y += bias
+        return y.to(orig_dtype)
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        # if act_mat has batchsize>2 don't use this kernel
+        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 2:
+            return torch.inf
+        return super()._autoquant_test(act_mat, weight, bias)
+
+class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
+    dequantizes the weight up front and uses a regular matmul
+    """
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        orig_shape = act_mat.shape
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y
+
+
+class AQFloatLinearWeight(torch.Tensor, AQMixin):
+    """
+    A class to be used in concert with AutoQuantizableLinearWeight to provide a
+    default/non-quantized option. Only implements the bare minimum needed to work with the
+    AutoQuantizableLinearWeight class using the same interfaces that would normally be
+    used by QTensor subclasses but for a default linear op instead.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor, bias)
+
+    @classmethod
+    def from_float(cls, weight):
+        return weight
+
+DEFAULT_CLASS_LIST = [
+    AQFloatLinearWeight,
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
+]
+
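
Because AQMixin only needs from_float and _quantized_op from the underlying subclass, adding another candidate is mostly a matter of mixing it in and extending the class list. A hypothetical sketch (the variant class below is illustrative, not part of this diff):

    from torchao.quantization.autoquant import AQMixin, DEFAULT_CLASS_LIST
    from torchao.quantization.subclass import Int8WeightOnlyQuantizedLinearWeight

    class MyAQWeightOnlyVariant(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
        """Extra autoquant candidate; storage and _quantized_op are inherited unchanged."""

    # this list is what AutoQuantizableLinearWeight.from_float takes as qtensor_class_list,
    # so each entry gets benchmarked per logged shape/dtype
    my_class_list = DEFAULT_CLASS_LIST + [MyAQWeightOnlyVariant]
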
+if False:
+    # def _get_to_kwargs(self, *args, **kwargs):
+    #     device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+    #     device = self.device if device is None else device
+    #     dtype = self.dtype if dtype is None else dtype
+    #     memory_format = (
+    #         memory_format if memory_format is not None else torch.preserve_format
+    #     )
+    #     kwargs = {
+    #         "device": device,
+    #         "dtype": dtype,
+    #         "memory_format": memory_format,
+    #     }
+    #     return kwargs
+
+    # def to(self, *args, **kwargs):
+    #     kwargs = self._get_to_kwargs(*args, **kwargs)
+    #     return self.__class__(
+    #         self.int_data.to(kwargs["device"]),
+    #         self.q_scales.to(kwargs["device"]),
+    #         self.transposed,
+    #         self.shape,
+    #         **kwargs,
+    #     )
+
+    # def _apply_fn_to_data(self, fn):
+    #     return self.__class__(
+    #         fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype
+    #     )
+
+    # def _change_shape(self, shape):
+    #     return self.__class__(
+    #         self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
+    #     )
+
+    # def half(self):
+    #     return self.to(torch.float16)
+    pass
