Commit ee063b7
Autoquant
Summary: currently an issue where, for multiple linear layers, we get very slow dynamic quant results on the later linear layers; unclear why.

Test Plan: python test/test.py -k "autoquant"

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
187.4432 0
AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840)
  bias_addmm 2.9764 ms 100.0%
  triton_mm_1 3.6858 ms 80.8%
  triton_mm_2 3.7502 ms 79.4%
  addmm 3.7887 ms 78.6%
  triton_mm_3 4.1547 ms 71.6%
  triton_mm_4 4.2022 ms 70.8%
  triton_mm_0 4.7970 ms 62.0%
  triton_mm_8 4.9596 ms 60.0%
  triton_mm_7 5.4343 ms 54.8%
  triton_mm_10 6.9352 ms 42.9%
SingleProcess AUTOTUNE takes 5.6320 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f98800eb760>
f(*args, **kwargs)
  3.08 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.07677136734128
1311.548416 0
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
1311.548416 0
AUTOTUNE mixed_mm(65536x1280, 1280x3840)
  fallback_mixed_mm 2.5089 ms 100.0%
  triton_mm_13 6.4153 ms 39.1%
  triton_mm_14 6.6832 ms 37.5%
  triton_mm_12 7.0896 ms 35.4%
  triton_mm_16 7.5022 ms 33.4%
  triton_mm_15 7.8426 ms 32.0%
  triton_mm_19 9.5269 ms 26.3%
  triton_mm_20 11.2033 ms 22.4%
  triton_mm_17 13.1675 ms 19.1%
  triton_mm_18 13.8004 ms 18.2%
SingleProcess AUTOTUNE takes 2.4977 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff12050>
f(*args, **kwargs)
  3.68 ms
  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff27b80>
f(*args, **kwargs)
  3.10 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.6846738075837493 3.1023880932480097
2144.447488 25
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
2144.447488 25
AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840)
  triton_mm_43 2.0319 ms 100.0%
  triton_mm_35 2.8135 ms 72.2%
  triton_mm_42 3.1552 ms 64.4%
  triton_mm_36 3.1754 ms 64.0%
  triton_mm_44 3.3460 ms 60.7%
  triton_mm_41 3.4036 ms 59.7%
  triton_mm_37 3.5030 ms 58.0%
  triton_mm_34 3.6553 ms 55.6%
  triton_mm_38 3.9232 ms 51.8%
  triton_mm_40 9.1934 ms 22.1%
SingleProcess AUTOTUNE takes 8.1948 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f40>
f(*args, **kwargs)
  3.13 ms
  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986cfd33a0>
f(*args, **kwargs)
  2.21 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 3.1286065466701984 2.210085652768612
2144.447488 22
<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2144.447488 22
AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280)
  bias_addmm 2.7966 ms 100.0%
  addmm 3.0447 ms 91.9%
  triton_mm_57 3.5612 ms 78.5%
  triton_mm_58 3.6919 ms 75.7%
  triton_mm_59 4.1908 ms 66.7%
  triton_mm_60 4.2350 ms 66.0%
  triton_mm_56 4.7210 ms 59.2%
  triton_mm_64 4.9001 ms 57.1%
  triton_mm_63 5.5218 ms 50.6%
  triton_mm_66 7.1417 ms 39.2%
SingleProcess AUTOTUNE takes 6.3734 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9888dd2b30>
f(*args, **kwargs)
  3.33 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.329739556647837
2228.913664 39
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2228.913664 39
AUTOTUNE mixed_mm(65536x3840, 3840x1280)
  fallback_mixed_mm 2.3987 ms 100.0%
  triton_mm_70 6.9153 ms 34.7%
  triton_mm_72 7.1634 ms 33.5%
  triton_mm_69 7.3164 ms 32.8%
  triton_mm_68 7.5070 ms 32.0%
  triton_mm_71 7.5631 ms 31.7%
  triton_mm_76 10.7759 ms 22.3%
  triton_mm_75 11.0692 ms 21.7%
  triton_mm_73 12.8898 ms 18.6%
  triton_mm_77 13.3715 ms 17.9%
SingleProcess AUTOTUNE takes 6.2342 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9880133fd0>
f(*args, **kwargs)
  3.48 ms
  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f988175b610>
f(*args, **kwargs)
  3.22 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.4762858413159847 3.2240213360637426
2228.913664 38
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2228.913664 38
AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280)
  triton_mm_99 1.4307 ms 100.0%
  triton_mm_100 1.9041 ms 75.1%
  triton_mm_91 2.6079 ms 54.9%
  triton_mm_98 2.6363 ms 54.3%
  triton_mm_92 2.6691 ms 53.6%
  triton_mm_93 3.0178 ms 47.4%
  triton_mm_97 3.0233 ms 47.3%
  triton_mm_94 3.1872 ms 44.9%
  triton_mm_90 3.6072 ms 39.7%
  triton_mm_96 8.4695 ms 16.9%
SingleProcess AUTOTUNE takes 8.1095 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9881782f80>
f(*args, **kwargs)
  145.38 ms
  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f70>
f(*args, **kwargs)
  143.98 ms
  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 145.37517526187003 143.98446583654732
2230.364672 79

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ac88c07
Pull Request resolved: #38
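Reading the log: each AUTOTUNE table is Inductor's max-autotune kernel selection for one op/shape, and each "f(*args, **kwargs) ... 1 measurement, 20 runs" block is a torch.utils.benchmark Measurement of the compiled op. A rough sketch of how such a number can be reproduced (bench_ms is a hypothetical helper, not the repo's benchmark util; assumes a CUDA device):

import torch
import torch.utils.benchmark as benchmark

def bench_ms(f, *args, **kwargs):
    # one Measurement of 20 runs, matching the "1 measurement, 20 runs" lines above
    t = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    return t.timeit(20).mean * 1e3  # mean seconds/run -> milliseconds

lin = torch.nn.Linear(1280, 3840, device="cuda", dtype=torch.bfloat16)
x = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)
f = torch.compile(lin, mode="max-autotune")
f(x)                  # first call compiles and prints the AUTOTUNE table
print(bench_ms(f, x))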
1 parent 969038f commit ee063b7

File tree: 6 files changed, +374 −2 lines changed

test/test.py

Lines changed: 19 additions & 0 deletions
@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -1195,6 +1196,24 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_auto_quant(self):
+        model = torch.nn.Sequential(
+            # torch.nn.Linear(5120,1280),
+            # torch.nn.ReLU(),
+            torch.nn.Linear(1280,3840),
+            torch.nn.ReLU(),
+            torch.nn.Linear(3840,1280),
+            torch.nn.ReLU(),
+        ).to("cuda").to(torch.bfloat16)
+        example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.coordinate_descent_tuning = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+        torch._dynamo.reset()  # TODO use in autoquantizer
+        do_autoquant(model, example_input)
 
 if __name__ == "__main__":
     unittest.main()
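For context, the inductor/dynamo flags set in this test gate the kernel families that appear in the autotune log above; the annotations below are my reading of the torch 2.x config options, not documented guarantees:

torch._inductor.config.epilogue_fusion = False            # skip epilogue fusion while max-autotune benchmarks kernels
torch._inductor.config.use_mixed_mm = True                # allow bf16-activation x int8-weight "mixed_mm" kernels (weight-only path)
torch._inductor.config.force_fuse_int_mm_with_mul = True  # fuse the int8 matmul with the rescale mul (dynamic-quant path)
torch._inductor.config.coordinate_descent_tuning = True   # extra coordinate-descent search over kernel configs
torch._dynamo.config.automatic_dynamic_shapes = False     # keep shapes static so each logged shape is tuned separately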

test/test_autoquant.py

Lines changed: 35 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# mypy: ignore-errors
import copy
import unittest

import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
    do_autoquant,  # do_autoquant lives in quant_api in this commit
)
from torch._dynamo import config
torch.manual_seed(0)
config.cache_size_limit = 100


class AutoquantTests(unittest.TestCase):
    def test_autoquant_e2e(self):
        model = torch.nn.Sequential(torch.nn.Linear(32,32), torch.nn.ReLU(), torch.nn.Linear(32,32)).cuda().to(torch.bfloat16)
        print(model, model[0].weight)
        example_input = torch.randn((1,64,32), dtype=torch.bfloat16, device="cuda")
        out = model(example_input)
        print(out.sum())
        do_autoquant(model, example_input)
        print(model, model[0].weight)
        print(model(example_input).sum())

if __name__ == "__main__":
    unittest.main()

torchao/quantization/autoquant.py

Lines changed: 207 additions & 0 deletions
import torch

from .subclass import (  # noqa
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .utils import benchmark

aten = torch.ops.aten

# maps (q_cls, logged_shape, logged_dtype) -> benchmarked runtime (None while pending)
AUTOQUANT_CACHE = {}

def check_cache(cls, shape, dtype):
    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

def update_cache(cls, shape, dtype, res):
    AUTOQUANT_CACHE[(cls, shape, dtype)] = res

class AutoQuantizableLinearWeight(torch.Tensor):
    """
    when run, finds best type of quantization for this tensor and swaps itself with that
    """
    @staticmethod
    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
        kwargs["device"] = weight.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
        )
        kwargs["dtype"] = (
            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
        )
        kwargs["requires_grad"] = False
        shape = kwargs.pop("shape", weight.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
        self.weight = weight
        self.qtensor_class_list = qtensor_class_list
        self.logged_shape = None
        self.logged_dtype = None

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
        )

    @staticmethod
    def log_shape(act_mat, w_autoquant, bias):
        # records the runtime (activation, weight, bias) shapes and dtype for this
        # linear, seeds the cache for every candidate class, then runs the
        # unquantized matmul so the model's output is unchanged
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
        logged_dtype = act_mat.dtype
        w_autoquant.logged_shape = logged_shape
        w_autoquant.logged_dtype = logged_dtype
        for q_cls in w_autoquant.qtensor_class_list:
            if check_cache(q_cls, logged_shape, logged_dtype) is None:
                update_cache(q_cls, logged_shape, logged_dtype, None)
        y = torch.mm(act_mat, w_autoquant.weight.t())
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y

    def tune_autoquant(self, q_cls):
        act_shape, w_shape, bias_shape = self.logged_shape
        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
            with torch.no_grad():
                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
                print(q_cls, self.logged_shape, self.logged_dtype)
                print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
                res = q_cls._autoquant_test(act_mat, self.weight, bias)
                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

    def to_quantized(self):
        if self.logged_shape is None or self.logged_dtype is None:
            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
        best_time = torch.inf
        best_cls = None
        for q_cls in self.qtensor_class_list:
            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
                self.tune_autoquant(q_cls)
            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
            if best_time >= cls_res:
                best_time = cls_res
                best_cls = q_cls
        # TODO handle random cls args/kwargs? or should they be curried
        self = best_cls.from_float(self.weight)
        return self

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
        )

    def __tensor_flatten__(self):
        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        weight = tensor_data_dict["weight"]
        qtensor_class_list, dtype, shape = tensor_attributes
        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

    @classmethod
    def from_float(cls, weight, qtensor_class_list):
        return cls(weight, qtensor_class_list)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_autoquant, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None
            )
            return cls.log_shape(mat1, w_autoquant, bias)

        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func is aten.detach.default:
            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))


class DefaultLinear(torch.Tensor):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
    AutoQuantizableLinearWeight class using the same interfaces that would normally be
    used by QTensor subclasses but for a default linear op instead.
    """
    def __init__(self):
        super().__init__()

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias):
        w_qtensor = cls.from_float(weight)
        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
        with torch.no_grad():
            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
        print(cls, res)
        return res

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
    def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
    Int8DynamicallyQuantizedLinearWeight,
    DefaultLinear,
    Int8WeightOnlyQuantizedLinearWeight,
]

if False:
    # def _get_to_kwargs(self, *args, **kwargs):
    #     device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
    #     device = self.device if device is None else device
    #     dtype = self.dtype if dtype is None else dtype
    #     memory_format = (
    #         memory_format if memory_format is not None else torch.preserve_format
    #     )
    #     kwargs = {
    #         "device": device,
    #         "dtype": dtype,
    #         "memory_format": memory_format,
    #     }
    #     return kwargs

    # def to(self, *args, **kwargs):
    #     kwargs = self._get_to_kwargs(*args, **kwargs)
    #     return self.__class__(
    #         self.int_data.to(kwargs["device"]),
    #         self.q_scales.to(kwargs["device"]),
    #         self.transposed,
    #         self.shape,
    #         **kwargs,
    #     )

    # def _apply_fn_to_data(self, fn):
    #     return self.__class__(
    #         fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype
    #     )

    # def _change_shape(self, shape):
    #     return self.__class__(
    #         self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
    #     )

    # def half(self):
    #     return self.to(torch.float16)
    pass
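To make the control flow concrete, here is a minimal sketch of the two-step path that do_autoquant (added in quant_api.py below) wraps; it assumes a CUDA machine and the candidate list shipped in this commit:

import torch
from torchao.quantization.autoquant import DEFAULT_CLASS_LIST
from torchao.quantization.quant_api import (
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

model = torch.nn.Sequential(
    torch.nn.Linear(1280, 3840),
    torch.nn.ReLU(),
    torch.nn.Linear(3840, 1280),
).to("cuda", torch.bfloat16)
x = torch.randn(1024, 1280, device="cuda", dtype=torch.bfloat16)

# 1. swap each nn.Linear weight for an AutoQuantizableLinearWeight wrapper
change_linears_to_autoquantizable(model, qtensor_class_list=DEFAULT_CLASS_LIST)
# 2. one normal forward pass: __torch_function__ intercepts F.linear and
#    log_shape records each layer's (act, weight, bias) shapes and dtype
model(x)
# 3. for every (cls, shape, dtype) not yet in AUTOQUANT_CACHE, benchmark the
#    candidate via _autoquant_test, then swap in the fastest class's weight
change_autoquantizable_to_quantized(model)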

torchao/quantization/quant_api.py

Lines changed: 38 additions & 1 deletion
@@ -28,6 +28,7 @@
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
 )
+from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST
 
 __all__ = [
     "apply_weight_only_int8_quant",
@@ -95,9 +96,11 @@ def apply_dynamic_quant(model, filter_fn=None):
 
 
 def _get_subclass_inserter(cls, **kwargs):
+    method = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
         lin.weight = torch.nn.Parameter(
-            cls.from_float(lin.weight, **kwargs), requires_grad=False
+            # cls.from_float(...)
+            getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
         )
         return lin
 
@@ -153,6 +156,40 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
         filter_fn,
     )
 
+
+def change_linears_to_autoquantizable(model, **kwargs):
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
+        filter_fn if filter_fn is not None else _is_linear,
+    )
+
+def change_autoquantizable_to_quantized(model, **kwargs):
+    filter_fn = kwargs.pop(
+        "filter_fn",
+        lambda mod, *args:
+            _is_linear(mod, *args) and
+            isinstance(mod.weight, AutoQuantizableLinearWeight)
+    )
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(
+            AutoQuantizableLinearWeight, method="to_quantized", **kwargs
+        ),
+        filter_fn,
+    )
+
+@torch.no_grad()
+def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
+    change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
+    if not isinstance(example_input, (tuple, list)):
+        assert isinstance(example_input, torch.Tensor)
+        example_input = [example_input]
+    model(*example_input)
+    change_autoquantizable_to_quantized(model)
+    return model
+
 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
     """
     Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
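One subtlety in the _get_subclass_inserter change: getattr(cls, method) looks the function up on the class, so with method="to_quantized" the existing weight (already an AutoQuantizableLinearWeight after the observation pass) is passed in as self, while the default "from_float" resolves to a classmethod taking a plain weight. A toy illustration of this Python behavior (W is hypothetical, not repo code):

class W:
    def __init__(self, v):
        self.v = v
    @classmethod
    def from_float(cls, w):   # default path: build a wrapper from a plain weight
        return cls(w)
    def to_quantized(self):   # swap path: called with the wrapped weight as self
        return self.v

w = W.from_float(3.0)                          # classmethod: W.from_float(3.0)
assert getattr(W, "to_quantized")(w) == 3.0    # unbound call: W.to_quantized(w)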
