
Commit 7f4150f

Update on "Autoquant"
Summary: There is currently an issue where, for models with multiple linear layers, dynamic quant benchmarks very slowly on the later linear layers; it is unclear why. (See the last Int8DynamicallyQuantizedLinearWeight block below: its int_mm kernel autotunes to ~1.4 ms, yet the measured op takes ~145 ms.)

Test Plan: python test/test.py -k "autoquant"

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0
AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840)
  bias_addmm 2.9764 ms 100.0%
  triton_mm_1 3.6858 ms 80.8%
  triton_mm_2 3.7502 ms 79.4%
  addmm 3.7887 ms 78.6%
  triton_mm_3 4.1547 ms 71.6%
  triton_mm_4 4.2022 ms 70.8%
  triton_mm_0 4.7970 ms 62.0%
  triton_mm_8 4.9596 ms 60.0%
  triton_mm_7 5.4343 ms 54.8%
  triton_mm_10 6.9352 ms 42.9%
SingleProcess AUTOTUNE takes 5.6320 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f98800eb760>
f(*args, **kwargs)  3.08 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.07677136734128 1311.548416 0

<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0
AUTOTUNE mixed_mm(65536x1280, 1280x3840)
  fallback_mixed_mm 2.5089 ms 100.0%
  triton_mm_13 6.4153 ms 39.1%
  triton_mm_14 6.6832 ms 37.5%
  triton_mm_12 7.0896 ms 35.4%
  triton_mm_16 7.5022 ms 33.4%
  triton_mm_15 7.8426 ms 32.0%
  triton_mm_19 9.5269 ms 26.3%
  triton_mm_20 11.2033 ms 22.4%
  triton_mm_17 13.1675 ms 19.1%
  triton_mm_18 13.8004 ms 18.2%
SingleProcess AUTOTUNE takes 2.4977 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff12050>
f(*args, **kwargs)  3.68 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff27b80>
f(*args, **kwargs)  3.10 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.6846738075837493 3.1023880932480097 2144.447488 25

<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25
AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840)
  triton_mm_43 2.0319 ms 100.0%
  triton_mm_35 2.8135 ms 72.2%
  triton_mm_42 3.1552 ms 64.4%
  triton_mm_36 3.1754 ms 64.0%
  triton_mm_44 3.3460 ms 60.7%
  triton_mm_41 3.4036 ms 59.7%
  triton_mm_37 3.5030 ms 58.0%
  triton_mm_34 3.6553 ms 55.6%
  triton_mm_38 3.9232 ms 51.8%
  triton_mm_40 9.1934 ms 22.1%
SingleProcess AUTOTUNE takes 8.1948 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f40>
f(*args, **kwargs)  3.13 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986cfd33a0>
f(*args, **kwargs)  2.21 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 3.1286065466701984 2.210085652768612 2144.447488 22

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22
AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280)
  bias_addmm 2.7966 ms 100.0%
  addmm 3.0447 ms 91.9%
  triton_mm_57 3.5612 ms 78.5%
  triton_mm_58 3.6919 ms 75.7%
  triton_mm_59 4.1908 ms 66.7%
  triton_mm_60 4.2350 ms 66.0%
  triton_mm_56 4.7210 ms 59.2%
  triton_mm_64 4.9001 ms 57.1%
  triton_mm_63 5.5218 ms 50.6%
  triton_mm_66 7.1417 ms 39.2%
SingleProcess AUTOTUNE takes 6.3734 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9888dd2b30>
f(*args, **kwargs)  3.33 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.329739556647837 2228.913664 39

<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39
AUTOTUNE mixed_mm(65536x3840, 3840x1280)
  fallback_mixed_mm 2.3987 ms 100.0%
  triton_mm_70 6.9153 ms 34.7%
  triton_mm_72 7.1634 ms 33.5%
  triton_mm_69 7.3164 ms 32.8%
  triton_mm_68 7.5070 ms 32.0%
  triton_mm_71 7.5631 ms 31.7%
  triton_mm_76 10.7759 ms 22.3%
  triton_mm_75 11.0692 ms 21.7%
  triton_mm_73 12.8898 ms 18.6%
  triton_mm_77 13.3715 ms 17.9%
SingleProcess AUTOTUNE takes 6.2342 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9880133fd0>
f(*args, **kwargs)  3.48 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f988175b610>
f(*args, **kwargs)  3.22 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.4762858413159847 3.2240213360637426 2228.913664 38

<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38
AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280)
  triton_mm_99 1.4307 ms 100.0%
  triton_mm_100 1.9041 ms 75.1%
  triton_mm_91 2.6079 ms 54.9%
  triton_mm_98 2.6363 ms 54.3%
  triton_mm_92 2.6691 ms 53.6%
  triton_mm_93 3.0178 ms 47.4%
  triton_mm_97 3.0233 ms 47.3%
  triton_mm_94 3.1872 ms 44.9%
  triton_mm_90 3.6072 ms 39.7%
  triton_mm_96 8.4695 ms 16.9%
SingleProcess AUTOTUNE takes 8.1095 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9881782f80>
f(*args, **kwargs)  145.38 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f70>
f(*args, **kwargs)  143.98 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 145.37517526187003 143.98446583654732 2230.364672 79

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
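For orientation, the test plan above boils down to the following repro. This is a sketch, not part of the commit: it assumes a CUDA device, and the import path for do_autoquant (defined in torchao/quantization/quant_api.py in this stack) is an assumption.

```python
import torch
from torchao.quantization.quant_api import do_autoquant  # assumed import path

# Two stacked linears, mirroring the updated test in this commit.
model = torch.nn.Sequential(
    torch.nn.Linear(1280, 3840),
    torch.nn.ReLU(),
    torch.nn.Linear(3840, 1280),
    torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)

# The second Linear is where dynamic quant benchmarks at ~145 ms even though
# its int_mm kernel autotunes to ~1.4 ms.
do_autoquant(model, example_input)
```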
1 parent 73ad4a0 commit 7f4150f

File tree: 5 files changed (+68, −38)

test/test.py

Lines changed: 4 additions & 1 deletion
@@ -1204,12 +1204,15 @@ def test_auto_quant(self):
             torch.nn.Linear(1280,3840),
             torch.nn.ReLU(),
             torch.nn.Linear(3840,1280),
+            torch.nn.ReLU(),
         ).to("cuda").to(torch.bfloat16)
-        example_input = torch.randn(65536,1280, device="cuda", dtype=torch.bfloat16)
+        example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)
         torch._inductor.config.epilogue_fusion = False
         torch._inductor.config.use_mixed_mm = True
         torch._inductor.config.force_fuse_int_mm_with_mul = True
         torch._inductor.config.coordinate_descent_tuning = True
+        torch._dynamo.config.automatic_dynamic_shapes = False
+        torch._dynamo.reset() # TODO use in autoquantizer
         do_autoquant(model, example_input)

 if __name__ == "__main__":
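The two new `_dynamo` lines are the interesting part of this hunk. A minimal illustration of what they do, as I read the intent (the TODO says this should eventually move into the autoquantizer itself):

```python
import torch

# Keep dynamo from promoting shapes to dynamic after the first few compiles:
# candidate kernels should be compiled and timed at the exact static shapes
# that were logged.
torch._dynamo.config.automatic_dynamic_shapes = False

# Drop previously compiled graphs so each candidate subclass is compiled and
# benchmarked fresh, rather than reusing or guarding against stale graphs.
torch._dynamo.reset()
```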

torchao/quantization/autoquant.py

Lines changed: 36 additions & 29 deletions
@@ -12,16 +12,11 @@

 AUTOQUANT_CACHE = {}

-def check_cache(shape, cls):
-    if shape in AUTOQUANT_CACHE:
-        return AUTOQUANT_CACHE[shape].get(cls, None)
-    else:
-        return None
+def check_cache(cls, shape, dtype):
+    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

-def update_cache(shape, cls, res):
-    if not shape in AUTOQUANT_CACHE:
-        AUTOQUANT_CACHE[shape] = {}
-    AUTOQUANT_CACHE[shape][cls] = res
+def update_cache(cls, shape, dtype, res):
+    AUTOQUANT_CACHE[(cls, shape, dtype)] = res

 class AutoQuantizableLinearWeight(torch.Tensor):
     """
@@ -43,7 +38,8 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
     def __init__(self, weight, qtensor_class_list, *args, **kwargs):
         self.weight = weight
         self.qtensor_class_list = qtensor_class_list
-        self.cache_shape = None
+        self.logged_shape = None
+        self.logged_dtype = None

     def __repr__(self):
         return (
@@ -52,36 +48,46 @@ def __repr__(self):
         )

     @staticmethod
-    def tune_autoquant(act_mat, w_autoquant, bias):
+    def log_shape(act_mat, w_autoquant, bias):
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1])
-        cache_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
-        w_autoquant.cache_shape = cache_shape
-        for cur_cls in w_autoquant.qtensor_class_list:
-            if check_cache(cache_shape, cur_cls) is None:
-                with torch.no_grad():
-                    print(cur_cls, cache_shape)
-                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
-                    res = cur_cls._autoquant_test(act_mat.clone(), w_autoquant.weight.clone(), None if bias is None else bias.clone())
-                    update_cache(cache_shape, cur_cls, res)
-                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
+        logged_dtype = act_mat.dtype
+        w_autoquant.logged_shape = logged_shape
+        w_autoquant.logged_dtype = logged_dtype
+        for q_cls in w_autoquant.qtensor_class_list:
+            if check_cache(q_cls, logged_shape, logged_dtype) is None:
+                update_cache(q_cls, logged_shape, logged_dtype, None)
         y = torch.mm(act_mat, w_autoquant.weight.t())
         y = y.reshape(*orig_shape[:-1], y.shape[-1])
         if bias is not None:
             y += bias
         return y

+    def tune_autoquant(self, q_cls):
+        act_shape, w_shape, bias_shape = self.logged_shape
+        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+            with torch.no_grad():
+                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
+                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
+                print(q_cls, self.logged_shape, self.logged_dtype)
+                print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+                res = q_cls._autoquant_test(act_mat, self.weight, bias)
+                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
+
     def to_quantized(self):
-        if self.cache_shape is None or self.cache_shape not in AUTOQUANT_CACHE:
-            raise RuntimeError("must run module normally to find best quantization option")
+        if self.logged_shape is None or self.logged_dtype is None:
+            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
         best_time = torch.inf
         best_cls = None
-        for cur_cls in self.qtensor_class_list:
-            cls_res = AUTOQUANT_CACHE[self.cache_shape].get(cur_cls, torch.inf)
+        for q_cls in self.qtensor_class_list:
+            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+                self.tune_autoquant(q_cls)
+            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
             if best_time >= cls_res:
                 best_time = cls_res
-                best_cls = cur_cls
-        # need to handle random cls args/kwargs?
+                best_cls = q_cls
+        # TODO handle random cls args/kwargs? or should they be curried
         self = best_cls.from_float(self.weight)
         return self

@@ -113,7 +119,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args)>2 else None
             )
-            return cls.tune_autoquant(mat1, w_autoquant, bias)
+            return cls.log_shape(mat1, w_autoquant, bias)

         try:
             with torch._C.DisableTorchFunctionSubclass():
@@ -155,9 +161,10 @@ def from_float(cls, weight):
         return weight

 DEFAULT_CLASS_LIST = [
+    Int8DynamicallyQuantizedLinearWeight,
     DefaultLinear,
     Int8WeightOnlyQuantizedLinearWeight,
-    Int8DynamicallyQuantizedLinearWeight,
+
 ]

 if False:
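The net effect of this diff is a two-phase split: during the model's real forward pass, `log_shape` only records `(shape, dtype)` keys, and the expensive benchmarking now happens lazily in `to_quantized` / `tune_autoquant` on freshly generated random tensors instead of clones of the live activations. A condensed sketch of that flow, with hypothetical stand-in names for everything outside this file:

```python
import torch

AUTOQUANT_CACHE = {}  # (q_cls, shape, dtype) -> measured time, or None

def log_shape_only(q_cls_list, act, weight, bias):
    # Phase 1, inside the real forward: record keys, benchmark nothing.
    shape = (act.shape, weight.shape, None if bias is None else bias.shape)
    for q_cls in q_cls_list:
        AUTOQUANT_CACHE.setdefault((q_cls, shape, act.dtype), None)

def tune_pending(benchmark_one):
    # Phase 2, at to_quantized() time: time each un-measured candidate on
    # synthetic tensors rebuilt from the logged shape/dtype.
    # benchmark_one is a hypothetical callback standing in for _autoquant_test.
    for (q_cls, shape, dtype), res in list(AUTOQUANT_CACHE.items()):
        if res is None:
            act_shape, w_shape, bias_shape = shape
            act = torch.randn(act_shape, dtype=dtype)
            AUTOQUANT_CACHE[(q_cls, shape, dtype)] = benchmark_one(q_cls, act)
```

Deferring the benchmark also removes the `act_mat.clone()` / `weight.clone()` calls from the old `tune_autoquant`, which should lower peak memory during the logging forward pass.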

torchao/quantization/quant_api.py

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ def change_autoquantizable_to_quantized(model, **kwargs):
 def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
     change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
     if not isinstance(example_input, (tuple, list)):
+        assert isinstance(example_input, torch.Tensor)
         example_input = [example_input]
     model(*example_input)
     change_autoquantizable_to_quantized(model)
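The added assert tightens the contract of `do_autoquant`: a non-tuple/list input must be a bare tensor. A usage sketch under that assumption (CUDA device and import path assumed):

```python
import torch
from torchao.quantization.quant_api import do_autoquant  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda").to(torch.bfloat16)

# A bare tensor is checked, then wrapped in a list before the calibration run.
example_input = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
do_autoquant(model, example_input)

# Passing any other non-tuple/list (e.g. a dict) now fails fast at the assert.
```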

torchao/quantization/subclass.py

Lines changed: 4 additions & 2 deletions
@@ -206,6 +206,7 @@ def _autoquant_test(cls, act_mat, weight, bias):
         q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
         with torch.no_grad():
             res=benchmark(q_c_op, act_mat, w_qtensor, bias)
+
         x_vals_int8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
         )
@@ -217,6 +218,7 @@ def _autoquant_test(cls, act_mat, weight, bias):
         with torch.no_grad():
             res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
         print(cls, res, res2)
+        breakpoint()
         return (res+res2)/2

     def dequantize(self, dtype=None):
@@ -331,8 +333,8 @@ def _autoquant_test(cls, act_mat, weight, bias):
             q_c_matmul,
             act_mat.reshape(-1, act_mat.shape[-1]),
             w_qtensor.int_data)
-        print(cls, res, res2
-        )
+
+        print(cls, res, res2)
         return (res+res2)/2

 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase):
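Both `_autoquant_test` implementations score a candidate as the average of two timings: `res`, the full compiled quantized op, and `res2`, just the core matmul on pre-quantized inputs. In effect (a simplified reading of the diff, not the literal code):

```python
def autoquant_score(res_ms: float, res2_ms: float) -> float:
    # res_ms:  full compiled op (quantize activations + matmul + rescale)
    # res2_ms: core matmul alone, with inputs already quantized
    # The mean splits the difference between the two measurements.
    return (res_ms + res2_ms) / 2
```

(The `breakpoint()` added in the second hunk reads as leftover debugging.)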

torchao/quantization/utils.py

Lines changed: 23 additions & 6 deletions
@@ -88,14 +88,31 @@ def get_model_size_in_bytes(model):
         s += b.nelement() * b.element_size()
     return s

+import time
+
+def benchmark_torch_function(iters, f, *args, **kwargs):
+    f(*args, **kwargs)
+    f(*args, **kwargs)
+    f(*args, **kwargs)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+    else:
+        t0 = time.time()
+    for i in range(iters):
+        f(*args, **kwargs)
+    if torch.cuda.is_available():
+        end_event.record()
+        torch.cuda.synchronize()
+        return start_event.elapsed_time(end_event)
+    else:
+        return (time.time() - t0)
+
 def benchmark(f, *args, **kwargs):
     t0 = Timer(
         stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
     )
     # warmup
-    t0.timeit(10).median
-    t0.blocked_autorange()
-    res = t0.timeit(20)
-    print(res)
-
-    return res.median * 1e3
+    return benchmark_torch_function(10, f, *args, **kwargs)
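One thing to watch, reading the diff as committed: the CUDA path returns total elapsed time in milliseconds across all `iters` (`Event.elapsed_time` reports ms), while the CPU fallback returns seconds, and `benchmark` now reports the total for 10 iterations where it previously reported the median per-call time in ms. A usage sketch under those assumptions, on a CUDA device:

```python
import torch
from torchao.quantization.utils import benchmark_torch_function

x = torch.randn(1024, 1024, device="cuda")
total_ms = benchmark_torch_function(10, torch.mm, x, x)  # CUDA events -> ms for 10 iters
print(total_ms / 10, "ms per torch.mm call")
```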
