Skip to content

Commit 0823e95

Browse files
committed
Update on "Autoquant"
Summary: currently issue where for multiple linear layers, get very slow dynamic quant results on layer linear layers. unclear why. Test Plan: python test/test.py -k "autoquant" <class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0 AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840) bias_addmm 2.9764 ms 100.0% triton_mm_1 3.6858 ms 80.8% triton_mm_2 3.7502 ms 79.4% addmm 3.7887 ms 78.6% triton_mm_3 4.1547 ms 71.6% triton_mm_4 4.2022 ms 70.8% triton_mm_0 4.7970 ms 62.0% triton_mm_8 4.9596 ms 60.0% triton_mm_7 5.4343 ms 54.8% triton_mm_10 6.9352 ms 42.9% SingleProcess AUTOTUNE takes 5.6320 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f98800eb760> f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.autoquant.DefaultLinear'> 3.07677136734128 1311.548416 0 <class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0 AUTOTUNE mixed_mm(65536x1280, 1280x3840) fallback_mixed_mm 2.5089 ms 100.0% triton_mm_13 6.4153 ms 39.1% triton_mm_14 6.6832 ms 37.5% triton_mm_12 7.0896 ms 35.4% triton_mm_16 7.5022 ms 33.4% triton_mm_15 7.8426 ms 32.0% triton_mm_19 9.5269 ms 26.3% triton_mm_20 11.2033 ms 22.4% triton_mm_17 13.1675 ms 19.1% triton_mm_18 13.8004 ms 18.2% SingleProcess AUTOTUNE takes 2.4977 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff12050> f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs , 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff27b80> f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.6846738075837493 3.1023880932480097 2144.447488 25 <class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25 AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840) triton_mm_43 2.0319 ms 100.0% triton_mm_35 2.8135 ms 72.2% triton_mm_42 3.1552 ms 64.4% triton_mm_36 3.1754 ms 64.0% triton_mm_44 3.3460 ms 60.7% triton_mm_41 3.4036 ms 59.7% triton_mm_37 3.5030 ms 58.0% triton_mm_34 3.6553 ms 55.6% triton_mm_38 3.9232 ms 51.8% triton_mm_40 9.1934 ms 22.1% SingleProcess AUTOTUNE takes 8.1948 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f40> f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs , 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7f986cfd33a0> f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 3.1286065466701984 2.210085652768612 2144.447488 22 <class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22 AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280) bias_addmm 2.7966 ms 100.0% addmm 3.0447 ms 91.9% triton_mm_57 3.5612 ms 78.5% triton_mm_58 3.6919 ms 75.7% triton_mm_59 4.1908 ms 66.7% triton_mm_60 4.2350 ms 66.0% triton_mm_56 4.7210 ms 59.2% triton_mm_64 4.9001 ms 57.1% triton_mm_63 5.5218 ms 50.6% triton_mm_66 7.1417 ms 39.2% SingleProcess AUTOTUNE takes 6.3734 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f9888dd2b30> f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.autoquant.DefaultLinear'> 3.329739556647837 2228.913664 39 <class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39 AUTOTUNE mixed_mm(65536x3840, 3840x1280) fallback_mixed_mm 2.3987 ms 100.0% triton_mm_70 6.9153 ms 34.7% triton_mm_72 7.1634 ms 33.5% triton_mm_69 7.3164 ms 32.8% triton_mm_68 7.5070 ms 32.0% triton_mm_71 7.5631 ms 31.7% triton_mm_76 10.7759 ms 22.3% triton_mm_75 11.0692 ms 21.7% triton_mm_73 12.8898 ms 18.6% triton_mm_77 13.3715 ms 17.9% SingleProcess AUTOTUNE takes 6.2342 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f9880133fd0> f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs , 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7f988175b610> f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.4762858413159847 3.2240213360637426 2228.913664 38 <class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38 AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280) triton_mm_99 1.4307 ms 100.0% triton_mm_100 1.9041 ms 75.1% triton_mm_91 2.6079 ms 54.9% triton_mm_98 2.6363 ms 54.3% triton_mm_92 2.6691 ms 53.6% triton_mm_93 3.0178 ms 47.4% triton_mm_97 3.0233 ms 47.3% triton_mm_94 3.1872 ms 44.9% triton_mm_90 3.6072 ms 39.7% triton_mm_96 8.4695 ms 16.9% SingleProcess AUTOTUNE takes 8.1095 seconds <torch.utils.benchmark.utils.common.Measurement object at 0x7f9881782f80> f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs , 1 thread <torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f70> f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs , 1 thread <class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 145.37517526187003 143.98446583654732 2230.364672 79 Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent 20c81b9 commit 0823e95

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/quantization/autoquant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ class AQMixin():
151151
@classmethod
152152
def _autoquant_test(cls, act_mat, weight, bias):
153153
w_qtensor = cls.from_float(weight)
154-
func = lambda act_mat, w_qtensor, bias: F.relu(cls._quantized_op(F.relu(act_mat), w_qtensor, bias))
154+
func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
155155
q_c_op = torch.compile(func, mode="max-autotune")
156156
# q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
157157
with torch.no_grad():

0 commit comments

Comments
 (0)