
Commit 20c81b9

Update on "Autoquant"
Summary: There is currently an issue where, for models with multiple linear layers, we get very slow dynamic quant results on the later linear layers; it is unclear why.

Test Plan: python test/test.py -k "autoquant"

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0
AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840)
  bias_addmm 2.9764 ms 100.0%
  triton_mm_1 3.6858 ms 80.8%
  triton_mm_2 3.7502 ms 79.4%
  addmm 3.7887 ms 78.6%
  triton_mm_3 4.1547 ms 71.6%
  triton_mm_4 4.2022 ms 70.8%
  triton_mm_0 4.7970 ms 62.0%
  triton_mm_8 4.9596 ms 60.0%
  triton_mm_7 5.4343 ms 54.8%
  triton_mm_10 6.9352 ms 42.9%
SingleProcess AUTOTUNE takes 5.6320 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f98800eb760> f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.07677136734128 1311.548416 0

<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0
AUTOTUNE mixed_mm(65536x1280, 1280x3840)
  fallback_mixed_mm 2.5089 ms 100.0%
  triton_mm_13 6.4153 ms 39.1%
  triton_mm_14 6.6832 ms 37.5%
  triton_mm_12 7.0896 ms 35.4%
  triton_mm_16 7.5022 ms 33.4%
  triton_mm_15 7.8426 ms 32.0%
  triton_mm_19 9.5269 ms 26.3%
  triton_mm_20 11.2033 ms 22.4%
  triton_mm_17 13.1675 ms 19.1%
  triton_mm_18 13.8004 ms 18.2%
SingleProcess AUTOTUNE takes 2.4977 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff12050> f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff27b80> f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.6846738075837493 3.1023880932480097 2144.447488 25

<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25
AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840)
  triton_mm_43 2.0319 ms 100.0%
  triton_mm_35 2.8135 ms 72.2%
  triton_mm_42 3.1552 ms 64.4%
  triton_mm_36 3.1754 ms 64.0%
  triton_mm_44 3.3460 ms 60.7%
  triton_mm_41 3.4036 ms 59.7%
  triton_mm_37 3.5030 ms 58.0%
  triton_mm_34 3.6553 ms 55.6%
  triton_mm_38 3.9232 ms 51.8%
  triton_mm_40 9.1934 ms 22.1%
SingleProcess AUTOTUNE takes 8.1948 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f40> f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986cfd33a0> f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 3.1286065466701984 2.210085652768612 2144.447488 22

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22
AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280)
  bias_addmm 2.7966 ms 100.0%
  addmm 3.0447 ms 91.9%
  triton_mm_57 3.5612 ms 78.5%
  triton_mm_58 3.6919 ms 75.7%
  triton_mm_59 4.1908 ms 66.7%
  triton_mm_60 4.2350 ms 66.0%
  triton_mm_56 4.7210 ms 59.2%
  triton_mm_64 4.9001 ms 57.1%
  triton_mm_63 5.5218 ms 50.6%
  triton_mm_66 7.1417 ms 39.2%
SingleProcess AUTOTUNE takes 6.3734 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9888dd2b30> f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.329739556647837 2228.913664 39
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39
AUTOTUNE mixed_mm(65536x3840, 3840x1280)
  fallback_mixed_mm 2.3987 ms 100.0%
  triton_mm_70 6.9153 ms 34.7%
  triton_mm_72 7.1634 ms 33.5%
  triton_mm_69 7.3164 ms 32.8%
  triton_mm_68 7.5070 ms 32.0%
  triton_mm_71 7.5631 ms 31.7%
  triton_mm_76 10.7759 ms 22.3%
  triton_mm_75 11.0692 ms 21.7%
  triton_mm_73 12.8898 ms 18.6%
  triton_mm_77 13.3715 ms 17.9%
SingleProcess AUTOTUNE takes 6.2342 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9880133fd0> f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f988175b610> f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.4762858413159847 3.2240213360637426 2228.913664 38

<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38
AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280)
  triton_mm_99 1.4307 ms 100.0%
  triton_mm_100 1.9041 ms 75.1%
  triton_mm_91 2.6079 ms 54.9%
  triton_mm_98 2.6363 ms 54.3%
  triton_mm_92 2.6691 ms 53.6%
  triton_mm_93 3.0178 ms 47.4%
  triton_mm_97 3.0233 ms 47.3%
  triton_mm_94 3.1872 ms 44.9%
  triton_mm_90 3.6072 ms 39.7%
  triton_mm_96 8.4695 ms 16.9%
SingleProcess AUTOTUNE takes 8.1095 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9881782f80> f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f70> f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 145.37517526187003 143.98446583654732 2230.364672 79

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent 7f4150f commit 20c81b9

File tree

6 files changed (+147 lines, −96 lines)

test/test.py

Lines changed: 28 additions & 11 deletions
@@ -54,6 +54,7 @@
     compute_error as SQNR,
     _fqn_to_op_to_shape_to_count,
     LoggingTensorMode,
+    benchmark
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -1198,22 +1199,38 @@ def test_on_dummy_distilbert(self):

 class TestAutoQuant(unittest.TestCase):
     def test_auto_quant(self):
-        model = torch.nn.Sequential(
-            # torch.nn.Linear(5120,1280),
-            # torch.nn.ReLU(),
-            torch.nn.Linear(1280,3840),
-            torch.nn.ReLU(),
-            torch.nn.Linear(3840,1280),
-            torch.nn.ReLU(),
-        ).to("cuda").to(torch.bfloat16)
-        example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)
         torch._inductor.config.epilogue_fusion = False
         torch._inductor.config.use_mixed_mm = True
         torch._inductor.config.force_fuse_int_mm_with_mul = True
         torch._inductor.config.coordinate_descent_tuning = True
         torch._dynamo.config.automatic_dynamic_shapes = False
-        torch._dynamo.reset() # TODO use in autoquantizer
-        do_autoquant(model, example_input)
+
+        for m,k,n in [
+            (1, 1024, 1024),
+            (64, 1024, 1024),
+            (4096, 1024, 1024),
+            (1, 1024, 4096),
+            (64, 1024, 4096),
+            (1, 4096, 1024),
+            (64, 4096, 1024),
+            (4096, 4096, 1024),
+        ]:
+            print("testing", m, k, n)
+            example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
+            model = torch.nn.Sequential(
+                # torch.nn.ReLU(),
+                torch.nn.Linear(k,n),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1280,3840),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(3840,1280),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1280,1024),
+                # torch.nn.ReLU(),
+                # torch.nn.Linear(1024,4096),
+                # torch.nn.ReLU(),
+            ).to("cuda").to(torch.bfloat16)
+            do_autoquant(model, example_input)

 if __name__ == "__main__":
     unittest.main()
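For reference, the updated test exercises the new autoquant entry point roughly as follows. This is a minimal sketch distilled from the diff above, not the exact test: it assumes a CUDA device with bfloat16 support and that do_autoquant is re-exported from torchao.quantization (the __init__.py change below adds it to __all__).

import torch
from torchao.quantization import do_autoquant  # assumed re-export, see __init__.py diff

# inductor/dynamo settings used by the test (see diff above)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.coordinate_descent_tuning = True
torch._dynamo.config.automatic_dynamic_shapes = False

m, k, n = 64, 1024, 4096  # one of the (m, k, n) shapes swept by the test
model = torch.nn.Sequential(torch.nn.Linear(k, n)).to("cuda").to(torch.bfloat16)
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)

# swaps Linear weights to autoquantizable tensors, runs the example input to
# record shapes/dtypes, then picks the fastest quantization option per layer
model = do_autoquant(model, example_input)
out = model(example_input)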

torchao/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,9 @@
     "dynamically_quantize_per_channel",
     "dequantize_per_tensor",
     "dequantize_per_channel",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
     "quant_int8_dynamic_linear",
     "quant_int8_matmul",
     "quant_int8_dynamic_per_token_linear",

torchao/quantization/autoquant.py

Lines changed: 99 additions & 19 deletions
@@ -7,6 +7,11 @@
 )
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .utils import benchmark
+from .quant_primitives import (
+    quantize_activation_per_token_absmax,
+    safe_int_mm,
+)
+import torch.nn.functional as F

 aten = torch.ops.aten

@@ -70,23 +75,30 @@ def tune_autoquant(self, q_cls):
         with torch.no_grad():
             act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
             bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
-            print(q_cls, self.logged_shape, self.logged_dtype)
-            print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
             res = q_cls._autoquant_test(act_mat, self.weight, bias)
             update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

-    def to_quantized(self):
-        if self.logged_shape is None or self.logged_dtype is None:
+    def to_quantized(self, error_on_unseen, **kwargs):
+        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
             raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
+        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
+            # default back to non-quantized weight if not seen
+            self = AQFloatLinearWeight.from_float(self.weight)
+            return self
         best_time = torch.inf
         best_cls = None
+        do_print=False
         for q_cls in self.qtensor_class_list:
             if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+                do_print=True
                 self.tune_autoquant(q_cls)
+                torch._dynamo.reset()
             cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
             if best_time >= cls_res:
                 best_time = cls_res
                 best_cls = q_cls
+        if do_print:
+            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
         # TODO handle random cls args/kwargs? or should they be curried
         self = best_cls.from_float(self.weight)
         return self
@@ -132,26 +144,93 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func is aten.detach.default:
             return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

-
-class DefaultLinear(torch.Tensor):
+class AQMixin():
     """
-    An class to be used in concert with AutoQuantizableLinearWeight to provide a
-    default/non-quantized option. Only implements the bare minimum needed to work with the
-    AutoQuantizableLinearWeight class using the same interfaces that would normally be
-    used by QTensor subclasses but for a default linear op instead.
+    Mixin to turn normal quantized subclasses into autoquantizable ones
     """
-    def __init__(self):
-        super().__init__()
-
     @classmethod
     def _autoquant_test(cls, act_mat, weight, bias):
         w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        func = lambda act_mat, w_qtensor, bias: F.relu(cls._quantized_op(F.relu(act_mat), w_qtensor, bias))
+        q_c_op = torch.compile(func, mode="max-autotune")
+        # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
         with torch.no_grad():
-            res=benchmark(q_c_op, act_mat, w_qtensor, bias)
+            torch.cuda.synchronize()
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
         print(cls, res)
         return res

+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+    """
+    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        res = super()._autoquant_test(act_mat, weight, bias)
+        w_qtensor = cls.from_float(weight)
+        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_int8, x_scales, w_vals_int8:
+                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+        )
+        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
+        with torch.no_grad():
+            res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(cls, "matmul", res2)
+        # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930
+        return res
+
+
+class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    """
+
+class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
+    uses a different kernel
+    """
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        orig_dtype = act_mat.dtype
+        orig_shape = act_mat.shape
+        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
+        y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y.to(orig_dtype)
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        # if act_mat has batchsize>2 don't use this kernel
+        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2:
+            return torch.inf
+        return super()._autoquant_test(act_mat, weight, bias)
+
+class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    def _quantized_op(act_mat, w_qtensor, bias):
+        orig_shape = act_mat.shape
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
+        y=y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y
+
+
+class AQFloatLinearWeight(torch.Tensor, AQMixin):
+    """
+    A class to be used in concert with AutoQuantizableLinearWeight to provide a
+    default/non-quantized option. Only implements the bare minimum needed to work with the
+    AutoQuantizableLinearWeight class using the same interfaces that would normally be
+    used by QTensor subclasses but for a default linear op instead.
+    """
+    def __init__(self):
+        super().__init__()
+
     @staticmethod
     def _quantized_op(act_mat, w_qtensor, bias):
         return torch.nn.functional.linear(act_mat, w_qtensor, bias)
@@ -161,10 +240,11 @@ def from_float(cls, weight):
         return weight

 DEFAULT_CLASS_LIST = [
-    Int8DynamicallyQuantizedLinearWeight,
-    DefaultLinear,
-    Int8WeightOnlyQuantizedLinearWeight,
-
+    AQFloatLinearWeight,
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
 ]

 if False:

torchao/quantization/quant_api.py

Lines changed: 11 additions & 2 deletions
@@ -36,7 +36,10 @@
     "change_linear_weights_to_int8_dqtensors",
     "change_linear_weights_to_int8_woqtensors",
     "change_linear_weights_to_int4_woqtensors",
-    "swap_conv2d_1x1_to_linear"
+    "swap_conv2d_1x1_to_linear",
+    "do_autoquant",
+    "change_linears_to_autoquantizable",
+    "change_autoquantizable_to_quantized",
 ]


@@ -159,6 +162,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):

 def change_linears_to_autoquantizable(model, **kwargs):
     filter_fn = kwargs.pop("filter_fn", _is_linear)
+    kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
@@ -172,22 +176,27 @@ def change_autoquantizable_to_quantized(model, **kwargs):
         _is_linear(mod, *args) and
         isinstance(mod.weight, AutoQuantizableLinearWeight)
     )
+    error_on_unseen=kwargs.pop("error_on_unseen", True)
     _replace_with_custom_fn_if_matches_filter(
         model,
         _get_subclass_inserter(
-            AutoQuantizableLinearWeight, method="to_quantized", **kwargs
+            AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs
         ),
         filter_fn,
     )

 @torch.no_grad()
 def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
+    hold = torch._dynamo.config.automatic_dynamic_shapes
+    torch._dynamo.config.automatic_dynamic_shapes = False
     change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
     if not isinstance(example_input, (tuple, list)):
         assert isinstance(example_input, torch.Tensor)
         example_input = [example_input]
     model(*example_input)
     change_autoquantizable_to_quantized(model)
+    torch._dynamo.config.automatic_dynamic_shapes = hold
+    torch._dynamo.reset()
     return model

 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
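Because the two lower-level entry points are now exported, the autoquant flow can also be driven manually, for example when the calibration input only becomes available inside an existing eval loop. A hedged sketch of that flow, based on the do_autoquant body above (the import path and the small model here are illustrative assumptions):

import torch
from torchao.quantization.quant_api import (  # assumed import path
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).to("cuda").to(torch.bfloat16)
calibration_batch = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

# 1) replace Linear weights with AutoQuantizableLinearWeight wrappers
change_linears_to_autoquantizable(model)

# 2) run the model normally so each wrapper can log its shape/dtype
with torch.no_grad():
    model(calibration_batch)

# 3) benchmark the candidate classes and swap in the fastest one per layer;
#    error_on_unseen=False falls back to the float weight for layers that
#    never saw an input (per the to_quantized change in autoquant.py)
change_autoquantizable_to_quantized(model, error_on_unseen=False)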

torchao/quantization/subclass.py

Lines changed: 1 addition & 42 deletions
@@ -200,27 +200,6 @@ def _quantized_op(act_mat, w_qtensor, bias):
             act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype
         )

-    @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias):
-        w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
-        with torch.no_grad():
-            res=benchmark(q_c_op, act_mat, w_qtensor, bias)
-
-        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
-            act_mat.reshape(-1, act_mat.shape[-1])
-        )
-        quantized_matmul = (
-            lambda x_vals_int8, x_scales, w_vals_int8:
-                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
-        )
-        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
-        with torch.no_grad():
-            res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
-        print(cls, res, res2)
-        breakpoint()
-        return (res+res2)/2
-
     def dequantize(self, dtype=None):
         """
         Obtain the dequantized version of the quantized tensor subclass
@@ -293,7 +272,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
         # however the external representation of our tensor will maintain the correct
         # shape attribute which needs to be tracked directly.
         int_data = w_int_repr.contiguous().t()
-        if cls is not Int8DynamicallyQuantizedLinearWeight:
+        if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight):
             int_data = int_data.contiguous()
         return cls(
             int_data, w_scales, False, input_float.shape, dtype=input_float.dtype
@@ -316,26 +295,6 @@ def _quantized_op(act_mat, w_qtensor, bias):
             y += bias
         return y.to(orig_dtype)

-    @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias):
-        w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
-        with torch.no_grad():
-            res=benchmark(q_c_op, act_mat, w_qtensor, bias)
-
-        quantized_matmul = (
-            lambda act_mat, w_vals_int8:
-                torch.mm(act_mat, w_vals_int8.to(act_mat.dtype))
-        )
-        q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
-        with torch.no_grad():
-            res2=benchmark(
-                q_c_matmul,
-                act_mat.reshape(-1, act_mat.shape[-1]),
-                w_qtensor.int_data)
-
-        print(cls, res, res2)
-        return (res+res2)/2

 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase):
     """

torchao/quantization/utils.py

Lines changed: 5 additions & 22 deletions
@@ -88,31 +88,14 @@ def get_model_size_in_bytes(model):
         s += b.nelement() * b.element_size()
     return s

-import time
-
-def benchmark_torch_function(iters, f, *args, **kwargs):
-    f(*args, **kwargs)
-    f(*args, **kwargs)
-    f(*args, **kwargs)
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-        start_event.record()
-    else:
-        t0 = time.time()
-    for i in range(iters):
-        f(*args, **kwargs)
-    if torch.cuda.is_available():
-        end_event.record()
-        torch.cuda.synchronize()
-        return start_event.elapsed_time(end_event)
-    else:
-        return (time.time() - t0)


 def benchmark(f, *args, **kwargs):
     t0 = Timer(
         stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
     )
+
     # warmup
-    return benchmark_torch_function(10, f, *args, **kwargs)
+    t0.timeit(10)
+
+    res=t0.blocked_autorange(min_run_time=.5)
+    return res.median * 1e3
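The rewritten helper returns the median runtime in milliseconds from torch.utils.benchmark.Timer's blocked_autorange, which is the number _autoquant_test compares across candidate classes. A small usage sketch, assuming benchmark is importable from torchao.quantization.utils (as the "from .utils import benchmark" line in autoquant.py suggests) and a CUDA device is present:

import torch
from torchao.quantization.utils import benchmark  # assumed import path

a = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)

# time a compiled matmul the same way _autoquant_test times its candidate ops
compiled_mm = torch.compile(torch.mm, mode="max-autotune")
with torch.no_grad():
    ms = benchmark(compiled_mm, a, b)  # median time in milliseconds
print(f"torch.mm: {ms:.3f} ms")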
