
Commit 7995409

Autoquant
Summary: Adding autoquantization functionality. Using the do_autoquant API we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"
Also tested on SAM and SDXL (pytorch-labs/segment-anything-fast#114, huggingface/diffusion-fast@176e85f)

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 3986099
Pull Request resolved: #38
1 parent 969038f commit 7995409
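For orientation, here is a minimal usage sketch of the API this commit adds, pieced together from the test and quant_api changes below. The toy model and shapes are illustrative only, and it assumes a CUDA machine, a bfloat16 model, and that torchao.quantization re-exports do_autoquant as the updated __all__ suggests:

    import torch
    from torchao.quantization import do_autoquant

    # purely illustrative toy model
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
        torch.nn.Linear(4096, 1024),
    ).to("cuda").to(torch.bfloat16)

    example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    # benchmarks the candidate quantization subclasses for each linear layer at the
    # observed shape/dtype and swaps in the fastest one (possibly no quantization)
    model = do_autoquant(model, example_input)
    out = model(example_input)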

File tree

6 files changed: +356 −3 lines changed


test/test.py

Lines changed: 28 additions & 0 deletions

@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,

@@ -53,6 +54,7 @@
     compute_error as SQNR,
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
+    benchmark
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os

@@ -1195,6 +1197,32 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)

+# TODO FINISH TEST CODE
+class TestAutoQuant(unittest.TestCase):
+    def test_auto_quant(self):
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.coordinate_descent_tuning = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+
+        for m, k, n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (4096, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                torch.nn.ReLU(),
+                torch.nn.Linear(k, n),
+                torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            do_autoquant(model, example_input)

 if __name__ == "__main__":
     unittest.main()
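The test flips several inductor/dynamo knobs before calling do_autoquant; a standalone script that wants comparable kernel selection would presumably set the same flags first. These are exactly the settings used in the test above — whether they are strictly required outside the test is an assumption:

    import torch

    torch._inductor.config.epilogue_fusion = False
    torch._inductor.config.use_mixed_mm = True
    torch._inductor.config.force_fuse_int_mm_with_mul = True
    torch._inductor.config.coordinate_descent_tuning = True
    torch._dynamo.config.automatic_dynamic_shapes = False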

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",

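With these names added to __all__, the intent is that the new entry points are importable from the subpackage root, e.g. (assuming __init__.py pulls them in from quant_api the same way it does the existing names):

    from torchao.quantization import (
        do_autoquant,
        change_linears_to_autoquantizable,
        change_autoquantizable_to_quantized,
    )
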
torchao/quantization/autoquant.py (new file)

Lines changed: 255 additions & 0 deletions

@@ -0,0 +1,255 @@
import torch

from .subclass import ( # noqa
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .utils import benchmark
from .quant_primitives import (
    quantize_activation_per_token_absmax,
    safe_int_mm,
)
import torch.nn.functional as F

aten = torch.ops.aten

AUTOQUANT_CACHE = {}

def check_cache(cls, shape, dtype):
    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

def update_cache(cls, shape, dtype, res):
    AUTOQUANT_CACHE[(cls, shape, dtype)] = res

class AutoQuantizableLinearWeight(torch.Tensor):
    """
    when run, finds best type of quantization for this tensor and swaps itself with that
    """
    @staticmethod
    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
        kwargs["device"] = weight.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
        )
        kwargs["dtype"] = (
            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
        )
        kwargs["requires_grad"] = False
        shape = kwargs.pop("shape", weight.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
        self.weight = weight
        self.qtensor_class_list = qtensor_class_list
        self.logged_shape = None
        self.logged_dtype = None

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
        )

    @staticmethod
    def log_shape(act_mat, w_autoquant, bias):
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
        logged_dtype = act_mat.dtype
        w_autoquant.logged_shape = logged_shape
        w_autoquant.logged_dtype = logged_dtype
        for q_cls in w_autoquant.qtensor_class_list:
            if check_cache(q_cls, logged_shape, logged_dtype) is None:
                update_cache(q_cls, logged_shape, logged_dtype, None)
        y = torch.mm(act_mat, w_autoquant.weight.t())
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y

    def tune_autoquant(self, q_cls, best_time):
        act_shape, w_shape, bias_shape = self.logged_shape
        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
            with torch.no_grad():
                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

    def to_quantized(self, error_on_unseen, **kwargs):
        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
            # default back to non-quantized weight if not seen
            self = AQFloatLinearWeight.from_float(self.weight)
            return self
        best_time = torch.inf
        best_cls = None
        do_print = False
        for q_cls in self.qtensor_class_list:
            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
                do_print = True
                self.tune_autoquant(q_cls, best_time)
                torch._dynamo.reset()
            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
            if best_time >= cls_res:
                best_time = cls_res
                best_cls = q_cls
        if do_print:
            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
        # TODO handle random cls args/kwargs? or should they be curried
        self = best_cls.from_float(self.weight)
        return self

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
        )

    def __tensor_flatten__(self):
        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        weight = tensor_data_dict["weight"]
        qtensor_class_list, dtype, shape = tensor_attributes[0]
        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

    @classmethod
    def from_float(cls, weight, qtensor_class_list):
        return cls(weight, qtensor_class_list)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_autoquant, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None
            )
            return cls.log_shape(mat1, w_autoquant, bias)

        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func is aten.detach.default:
            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

class AQMixin():
    """
    Mixin to turn normal quantized subclasses into autoquantizable ones
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
        w_qtensor = cls.from_float(weight)
        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
        with torch.no_grad():
            torch.cuda.synchronize()
            res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
        print(cls, res)
        return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
    """
    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time):
        # SAM best is between .51 to .60, SDXL also performs best in this range
        INTERPOLATION_CONSTANT = .55
        w_qtensor = cls.from_float(weight)
        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
            act_mat.reshape(-1, act_mat.shape[-1])
        )
        quantized_matmul = (
            lambda x_vals_int8, x_scales, w_vals_int8:
                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
        )
        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
        with torch.no_grad():
            res_matmul = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
        print(cls, "matmul", res_matmul)

        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
        if res_matmul >= best_time:
            return res_matmul

        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
        to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
        print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul)
        return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul


class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
    """

class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """
    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_dtype = act_mat.dtype
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
        y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y.to(orig_dtype)

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time):
        # if act_mat has batchsize>2 don't use this kernel
        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 2:
            return torch.inf
        return super()._autoquant_test(act_mat, weight, bias, best_time)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_shape = act_mat.shape
        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y


class AQFloatLinearWeight(torch.Tensor, AQMixin):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
    AutoQuantizableLinearWeight class using the same interfaces that would normally be
    used by QTensor subclasses but for a default linear op instead.
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
    def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight2,
    AQWeightOnlyQuantizedLinearWeight3,
]
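To make the scoring in AQInt8DynamicallyQuantizedLinearWeight._autoquant_test concrete, here is a small standalone sketch of the arithmetic; only the formulas and the cache keying come from the file above, the timings are made-up numbers:

    INTERPOLATION_CONSTANT = 0.55   # same constant as in the class above

    best_time = 1.00    # hypothetical: fastest candidate benchmarked so far (ms)
    res_matmul = 0.40   # hypothetical: time of the bare int8 matmul kernel (ms)

    # time budget handed down to AQMixin._autoquant_test for the full op
    to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
    print(to_beat)      # 1.00 + (0.55 / 0.45) * 0.60 ≈ 1.73

    res = 1.20          # hypothetical: time of the full quantize -> int8 matmul -> rescale op (ms)
    score = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
    print(score)        # 0.55 * 1.20 + 0.45 * 0.40 = 0.84, beats best_time even though res alone does not

    # each result is memoized per (qtensor class, logged shape, logged dtype),
    # mirroring the AUTOQUANT_CACHE keying used by check_cache/update_cache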

torchao/quantization/quant_api.py

Lines changed: 48 additions & 2 deletions

@@ -28,14 +28,18 @@
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
 )
+from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST

 __all__ = [
     "apply_weight_only_int8_quant",
     "apply_dynamic_quant",
     "change_linear_weights_to_int8_dqtensors",
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
-    "swap_conv2d_1x1_to_linear"
+    "swap_conv2d_1x1_to_linear",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
 ]

@@ -95,9 +99,11 @@ def apply_dynamic_quant(model, filter_fn=None):


 def _get_subclass_inserter(cls, **kwargs):
+    method = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
         lin.weight = torch.nn.Parameter(
-            cls.from_float(lin.weight, **kwargs), requires_grad=False
+            # cls.from_float(...)
+            getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
         )
         return lin

@@ -153,6 +159,46 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
     filter_fn,
 )

+
+def change_linears_to_autoquantizable(model, **kwargs):
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+    kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
+        filter_fn if filter_fn is not None else _is_linear,
+    )
+
+def change_autoquantizable_to_quantized(model, **kwargs):
+    filter_fn = kwargs.pop(
+        "filter_fn",
+        lambda mod, *args:
+            _is_linear(mod, *args) and
+            isinstance(mod.weight, AutoQuantizableLinearWeight)
+    )
+    error_on_unseen = kwargs.pop("error_on_unseen", True)
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(
+            AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs
+        ),
+        filter_fn,
+    )
+
+@torch.no_grad()
+def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
+    hold = torch._dynamo.config.automatic_dynamic_shapes
+    torch._dynamo.config.automatic_dynamic_shapes = False
+    change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
+    if not isinstance(example_input, (tuple, list)):
+        assert isinstance(example_input, torch.Tensor)
+        example_input = [example_input]
+    model(*example_input)
+    change_autoquantizable_to_quantized(model)
+    torch._dynamo.config.automatic_dynamic_shapes = hold
+    torch._dynamo.reset()
+    return model
+
 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
     """
     Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
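do_autoquant is a convenience wrapper around a two-step flow that can also be driven by hand, for example when the calibration forward pass lives elsewhere in a larger harness. A minimal sketch using the two functions added above (model and example_input are placeholders; the real do_autoquant additionally pins automatic_dynamic_shapes and calls torch._dynamo.reset(), which this sketch omits):

    from torchao.quantization.quant_api import (
        change_linears_to_autoquantizable,
        change_autoquantizable_to_quantized,
    )

    def autoquant_manually(model, example_input):
        # 1) swap every eligible nn.Linear weight for an AutoQuantizableLinearWeight
        change_linears_to_autoquantizable(model)
        # 2) run the model once so each wrapped weight logs the activation shape/dtype it sees
        model(example_input)
        # 3) benchmark the candidate subclasses per logged shape and swap in the fastest
        change_autoquantizable_to_quantized(model)
        return model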
