
Commit a37928a

Autoquant

Summary: Adding autoquantization functionality. Using the do_autoquant API we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"; also tested on SAM (pytorch-labs/segment-anything-fast#114) and SDXL (HDCharles/sdxl-fast@8d9942a)

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 0dbb2ff
Pull Request resolved: #38
1 parent 969038f commit a37928a
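A minimal usage sketch, mirroring the new test added in test/test.py below (the model, shapes, and dtype are placeholders taken from that test; a CUDA device and an installed torchao are assumed):

import torch
from torchao.quantization import do_autoquant

model = torch.nn.Sequential(
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

do_autoquant(model, example_input)   # benchmarks candidate kernels and swaps in the fastest per layer
out = model(example_input)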

File tree

6 files changed: +443 −3 lines changed

test/test.py

Lines changed: 66 additions & 0 deletions
@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -54,6 +55,13 @@
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
 )
+from torchao.quantization.autoquant import (
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
+)
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os

@@ -880,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self):
             Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype
         )
 
+    def test_aq_int8_dynamic_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_2_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype
+            )
+
+    def test_aq_int8_weight_only_quant_3_subclass(self):
+        for test_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            self._test_lin_weight_subclass_impl(
+                AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype
+            )
+
     def test_int4_weight_only_quant_subclass(self):
         self._test_lin_weight_subclass_impl(
             Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8]
@@ -1195,6 +1233,34 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_autoquant(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m, k, n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (2**15, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            out = model(example_input)
+            do_autoquant(model, example_input)
+            out2 = model(example_input)
+            sqnr = SQNR(out, out2)
+            self.assertTrue(sqnr >= 30)
 
 if __name__ == "__main__":
     unittest.main()
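The SQNR helper used above comes from the shared test utilities (its import is outside this hunk). A typical signal-to-quantization-noise-ratio definition, which is what the ">= 30" dB threshold assumes, looks like the following sketch (not necessarily the exact helper used here):

import torch

def sqnr_db(signal: torch.Tensor, noisy: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||signal|| / ||signal - noisy||), reported in decibels
    signal, noisy = signal.float(), noisy.float()
    noise = signal - noisy
    return 20 * torch.log10(torch.linalg.norm(signal) / torch.linalg.norm(noise))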

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",

torchao/quantization/autoquant.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@ (new file; all 288 lines are additions)

import torch
import os
from subprocess import check_output
from .subclass import ( # noqa
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
    quantize_activation_per_token_absmax,
    safe_int_mm,
)
import torch.nn.functional as F
from torch._inductor.utils import do_bench
aten = torch.ops.aten

AUTOQUANT_CACHE = {}

def check_cache(cls, shapes_and_dtype):
    return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)

def update_cache(cls, shapes_and_dtype, res):
    AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res
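For orientation, a hypothetical cache entry (all shapes and the timing below are made up) has a flat key of candidate class plus the logged shapes and dtype, and a median benchmark time as its value:

# key:   (AQInt8DynamicallyQuantizedLinearWeight,   # candidate subclass
#         torch.Size([64, 1024]),                   # flattened activation shape
#         torch.Size([4096, 1024]),                 # weight shape
#         torch.Size([4096]),                       # bias shape (or None)
#         torch.bfloat16)                           # activation dtype
# value: 0.042                                      # median benchmark time in ms (illustrative)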
class AutoQuantizableLinearWeight(torch.Tensor):
    """
    when run, finds best type of quantization for this tensor and swaps itself with that
    """
    @staticmethod
    def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
        kwargs["device"] = weight.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
        )
        kwargs["dtype"] = (
            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
        )
        kwargs["requires_grad"] = False
        shape = kwargs.pop("shape", weight.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
        self.weight = weight
        self.qtensor_class_list = qtensor_class_list
        self.logged_data = {}
        self.mode = mode

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
        )

    @staticmethod
    def log_shape(act_mat, w_autoquant, bias):
        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
        logged_dtype = act_mat.dtype
        logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,)
        shapes_and_dtype = logged_shapes + (logged_dtype,)
        w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0)
        for q_cls in w_autoquant.qtensor_class_list:
            if check_cache(q_cls, shapes_and_dtype) is None:
                update_cache(q_cls, shapes_and_dtype, None)

    def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
        act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
        if check_cache(q_cls, shapes_and_dtype) is None:
            with torch.no_grad():
                act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
                update_cache(q_cls, shapes_and_dtype, res)

    def to_quantized(self, error_on_unseen, **kwargs):
        if error_on_unseen and self.logged_data == {}:
            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
        elif (self.logged_data == {}) and not error_on_unseen:
            # default back to non-quantized weight if not seen
            self = AQFloatLinearWeight.from_float(self.weight)
            return self
        best_time = torch.inf
        best_cls = None
        do_print = False
        # check each class
        for q_cls in self.qtensor_class_list:
            # for each logged shape+dtype, benchmark
            cls_res = 0
            for shapes_and_dtype, times_seen in self.logged_data.items():
                if check_cache(q_cls, shapes_and_dtype) is None:
                    do_print = True
                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
                    torch._dynamo.reset()
                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
            if best_time >= cls_res:
                best_time = cls_res
                best_cls = q_cls
        # only print if this is the first time seeing some cls+shape combo,
        # otherwise we will print the same thing for every layer.
        if do_print:
            print(f"for {self.logged_data}, best_cls={best_cls}")
        # TODO handle random cls args/kwargs? or should they be curried?
        self = best_cls.from_float(self.weight)
        return self

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode
        )

    def __tensor_flatten__(self):
        return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        weight = tensor_data_dict["weight"]
        qtensor_class_list, mode, dtype, shape = tensor_attributes
        return cls(weight, qtensor_class_list, mode=mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

    @classmethod
    def from_float(cls, weight, qtensor_class_list, **kwargs):
        return cls(weight, qtensor_class_list, **kwargs)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_autoquant, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None
            )
            cls.log_shape(mat1, w_autoquant, bias)
            return func(mat1, w_autoquant.weight, bias)
        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func is aten.detach.default:
            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
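To make the selection rule in to_quantized() concrete, here is a toy walk-through with made-up counts and timings (illustrative only):

# Suppose one Linear layer was hit with two activation shapes during the logging pass:
# (64, 1024) seen 10 times and (1, 1024) seen 2 times. With made-up per-shape results (ms):
#   AQFloatLinearWeight:                    0.20*10 + 0.05*2 = 2.10
#   AQInt8DynamicallyQuantizedLinearWeight: 0.12*10 + 0.07*2 = 1.34
#   AQWeightOnlyQuantizedLinearWeight:      0.15*10 + 0.04*2 = 1.58
# to_quantized() picks the class with the smallest usage-weighted total, so this layer
# would be swapped to AQInt8DynamicallyQuantizedLinearWeight.from_float(weight).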
def do_autoquant_bench(op, *args, **kwargs):
    rep = kwargs.pop("rep", 100)
    warmup = kwargs.pop("warmup", 25)
    with torch.no_grad():
        torch.cuda.synchronize()
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            op(*args)
        stream.synchronize()
        torch.cuda.current_stream().wait_stream(stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            op(*args)
        res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
    return res

def _is_interpolate_mode(mode):
    if isinstance(mode, list) and mode[0] == "interpolate" and len(mode) == 2 and isinstance(mode[1], float):
        return True
    return False
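As a rough usage sketch (assuming a CUDA device and the definitions above; the shapes are arbitrary), the CUDA-graph benchmark helper can time any compiled callable:

a = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
mm = torch.compile(torch.mm, mode="max-autotune-no-cudagraphs")
median_ms = do_autoquant_bench(mm, a, b, warmup=25, rep=100)   # median runtime in ms
print(f"median time: {median_ms:0.3f}ms")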
class AQMixin():
    """
    Mixin to turn normal quantized subclasses into autoquantizable ones
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        w_qtensor = cls.from_float(weight)
        if _is_interpolate_mode(mode):
            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
        else:
            func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
            q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
        if res < best_time * 1.1:
            res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
            res = (res2 * .9 + res * .1)
        print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
        return res
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
    """
    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        if not _is_interpolate_mode(mode):
            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

        # SAM performs best with a constant between .8 and 1; SDXL also performs best in this range
        INTERPOLATION_CONSTANT = mode[1]
        w_qtensor = cls.from_float(weight)
        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
            act_mat.reshape(-1, act_mat.shape[-1])
        )
        quantized_matmul = (
            lambda x_vals_int8, x_scales, w_vals_int8:
                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
        )
        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
        with torch.no_grad():
            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
        print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

        # if the (much faster) matmul kernel is already beaten, don't bother benchmarking the full op
        if res_matmul >= best_time:
            return res_matmul

        # calculate what time the full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
        max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
        print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
        return res_f
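To make the interpolation arithmetic concrete, here is a toy calculation applying the two formulas above (every number, including the 0.85 constant, is made up for illustration):

# Illustrative only: mode = ["interpolate", 0.85], best full-op time seen so far is 1.00 ms,
# and the bare int8 matmul benchmarks at 0.40 ms.
c = 0.85
best_time, res_matmul = 1.00, 0.40

# Relaxed target handed to the full-op benchmark (controls whether the longer re-benchmark runs):
to_beat = best_time + c / (1 - c) * (best_time - res_matmul)   # = 1.00 + 5.667 * 0.60 = 4.40 ms

# If the full dynamic-quant op then measures 1.10 ms, the score returned to to_quantized()
# interpolates between the full op and the bare matmul:
res = 1.10
res_f = c * res + (1 - c) * res_matmul                         # = 0.935 + 0.060 = 0.995 ms
# so this class is recorded at 0.995 ms rather than 1.10 ms.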
class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
    """

class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """
    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_dtype = act_mat.dtype
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
        if bias is not None:
            y += bias
        return y.to(orig_dtype)

    @classmethod
    def _autoquant_test(cls, act_mat, *args):
        # if act_mat has batchsize > 32 don't use this kernel
        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 32:
            return torch.inf
        return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_shape = act_mat.shape
        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y
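Kernel 2's broadcast-multiply-and-sum computes the same product as kernel 3's plain matmul; it is just written so the compiler sees the per-column rescale differently. A quick toy check (float stand-ins for the int weight, arbitrary shapes) illustrates the equivalence:

import torch

a = torch.randn(4, 8)        # activations (batch, in_features)
w = torch.randn(8, 16)       # stand-in for int_data laid out as (in_features, out_features)
scales = torch.randn(16)     # stand-in for per-column q_scales

y2 = (a.reshape(-1, 8, 1) * w.unsqueeze(0)).sum(dim=-2) * scales   # kernel-2 style
y3 = torch.mm(a, w * scales)                                       # kernel-3 style
assert torch.allclose(y2, y3, atol=1e-4)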
class AQFloatLinearWeight(torch.Tensor, AQMixin):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
    AutoQuantizableLinearWeight class using the same interfaces that would normally be
    used by QTensor subclasses but for a default linear op instead.
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
    def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight2,
    # AQWeightOnlyQuantizedLinearWeight3,
    # 3rd version gets picked in situations where it is slower for the interpolation mode
]
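The do_autoquant driver itself lives in quant_api.py, one of the six changed files that is not shown in this excerpt. Based on the names exported from __init__.py and the class above, the overall flow is roughly the following sketch; the function signatures are assumptions, not the actual quant_api.py implementation:

from torchao.quantization import (
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

def autoquant_sketch(model, example_input):
    # 1. wrap each nn.Linear weight in an AutoQuantizableLinearWeight
    change_linears_to_autoquantizable(model)
    # 2. calibration pass: log_shape() records activation shapes/dtypes per layer
    model(example_input)
    # 3. benchmark the candidate classes per layer and swap in the fastest
    change_autoquantizable_to_quantized(model)
    return model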
