From 73ad4a090ac96b234e4280ef884de6172d4c8277 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 22 Feb 2024 11:37:25 -0800 Subject: [PATCH 1/7] Autoquant Summary: currently issue where for multiple linear layers, get very slow dynamic quant results on layer linear layers. unclear why. Test Plan: python test/test.py -k "autoquant" (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0 AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840) bias_addmm 2.9764 ms 100.0% triton_mm_1 3.6858 ms 80.8% triton_mm_2 3.7502 ms 79.4% addmm 3.7887 ms 78.6% triton_mm_3 4.1547 ms 71.6% triton_mm_4 4.2022 ms 70.8% triton_mm_0 4.7970 ms 62.0% triton_mm_8 4.9596 ms 60.0% triton_mm_7 5.4343 ms 54.8% triton_mm_10 6.9352 ms 42.9% SingleProcess AUTOTUNE takes 5.6320 seconds f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs , 1 thread 3.07677136734128 1311.548416 0 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0 AUTOTUNE mixed_mm(65536x1280, 1280x3840) fallback_mixed_mm 2.5089 ms 100.0% triton_mm_13 6.4153 ms 39.1% triton_mm_14 6.6832 ms 37.5% triton_mm_12 7.0896 ms 35.4% triton_mm_16 7.5022 ms 33.4% triton_mm_15 7.8426 ms 32.0% triton_mm_19 9.5269 ms 26.3% triton_mm_20 11.2033 ms 22.4% triton_mm_17 13.1675 ms 19.1% triton_mm_18 13.8004 ms 18.2% SingleProcess AUTOTUNE takes 2.4977 seconds f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs , 1 thread 3.6846738075837493 3.1023880932480097 2144.447488 25 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25 AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840) triton_mm_43 2.0319 ms 100.0% triton_mm_35 2.8135 ms 72.2% triton_mm_42 3.1552 ms 64.4% triton_mm_36 3.1754 ms 64.0% triton_mm_44 3.3460 ms 60.7% triton_mm_41 3.4036 ms 59.7% triton_mm_37 3.5030 ms 58.0% triton_mm_34 3.6553 ms 55.6% triton_mm_38 3.9232 ms 51.8% triton_mm_40 9.1934 ms 22.1% SingleProcess AUTOTUNE takes 8.1948 seconds f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs , 1 thread 3.1286065466701984 2.210085652768612 2144.447488 22 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22 AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280) bias_addmm 2.7966 ms 100.0% addmm 3.0447 ms 91.9% triton_mm_57 3.5612 ms 78.5% triton_mm_58 3.6919 ms 75.7% triton_mm_59 4.1908 ms 66.7% triton_mm_60 4.2350 ms 66.0% triton_mm_56 4.7210 ms 59.2% triton_mm_64 4.9001 ms 57.1% triton_mm_63 5.5218 ms 50.6% triton_mm_66 7.1417 ms 39.2% SingleProcess AUTOTUNE takes 6.3734 seconds f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs , 1 thread 3.329739556647837 2228.913664 39 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39 AUTOTUNE mixed_mm(65536x3840, 3840x1280) fallback_mixed_mm 2.3987 ms 100.0% triton_mm_70 6.9153 ms 34.7% triton_mm_72 7.1634 ms 33.5% triton_mm_69 7.3164 ms 32.8% triton_mm_68 7.5070 ms 32.0% triton_mm_71 7.5631 ms 31.7% triton_mm_76 10.7759 ms 22.3% triton_mm_75 11.0692 ms 21.7% triton_mm_73 12.8898 ms 18.6% triton_mm_77 13.3715 ms 17.9% SingleProcess AUTOTUNE takes 6.2342 seconds f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs , 1 thread 3.4762858413159847 3.2240213360637426 2228.913664 38 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38 AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280) triton_mm_99 1.4307 ms 100.0% triton_mm_100 1.9041 ms 75.1% triton_mm_91 2.6079 ms 54.9% triton_mm_98 2.6363 ms 54.3% triton_mm_92 2.6691 ms 53.6% triton_mm_93 3.0178 ms 47.4% triton_mm_97 3.0233 ms 47.3% triton_mm_94 3.1872 ms 44.9% triton_mm_90 3.6072 ms 39.7% triton_mm_96 8.4695 ms 16.9% SingleProcess AUTOTUNE takes 8.1095 seconds f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs , 1 thread 145.37517526187003 143.98446583654732 2230.364672 79 Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 16 +++ test/test_autoquant.py | 35 ++++++ torchao/quantization/autoquant.py | 200 ++++++++++++++++++++++++++++++ torchao/quantization/quant_api.py | 38 +++++- torchao/quantization/subclass.py | 44 ++++++- torchao/quantization/utils.py | 13 ++ 6 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 test/test_autoquant.py create mode 100644 torchao/quantization/autoquant.py diff --git a/test/test.py b/test/test.py index fe3b3ec8a7..8e8ad04593 100644 --- a/test/test.py +++ b/test/test.py @@ -24,6 +24,7 @@ change_linear_weights_to_int8_woqtensors, change_linear_weights_to_int4_woqtensors, _replace_with_custom_fn_if_matches_filter, + do_autoquant ) from torchao.quantization.quant_primitives import ( dequantize_per_channel, @@ -1195,6 +1196,21 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) +class TestAutoQuant(unittest.TestCase): + def test_auto_quant(self): + model = torch.nn.Sequential( + # torch.nn.Linear(5120,1280), + # torch.nn.ReLU(), + torch.nn.Linear(1280,3840), + torch.nn.ReLU(), + torch.nn.Linear(3840,1280), + ).to("cuda").to(torch.bfloat16) + example_input = torch.randn(65536,1280, device="cuda", dtype=torch.bfloat16) + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._inductor.config.coordinate_descent_tuning = True + do_autoquant(model, example_input) if __name__ == "__main__": unittest.main() diff --git a/test/test_autoquant.py b/test/test_autoquant.py new file mode 100644 index 0000000000..5d354185a7 --- /dev/null +++ b/test/test_autoquant.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# mypy: ignore-errors +import copy +import unittest + +import torch +import torch.nn as nn +from torchao.quantization.quant_api import ( + change_linears_to_autoquantizable, + change_autoquantizable_to_quantized +) +from torchao.quantization.autoquant import do_autoquant +from torch._dynamo import config +torch.manual_seed(0) +config.cache_size_limit = 100 + + +class AutoquantTests(unittest.TestCase): + def test_autoquant_e2e(self): + model = torch.nn.Sequential(torch.nn.Linear(32,32), torch.nn.ReLU(), torch.nn.Linear(32,32)).cuda().to(torch.bfloat16) + print(model, model[0].weight) + example_input = torch.randn((1,64,32), dtype=torch.bfloat16, device=torch.cuda) + out=model(example_input) + print(out.sum()) + do_autoquant(model) + print(model, model[0].weight) + print(model(example_input).sum()) + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py new file mode 100644 index 0000000000..edbc130816 --- /dev/null +++ b/torchao/quantization/autoquant.py @@ -0,0 +1,200 @@ +import torch + +from .subclass import ( # noqa + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from .utils import benchmark + +aten = torch.ops.aten + +AUTOQUANT_CACHE = {} + +def check_cache(shape, cls): + if shape in AUTOQUANT_CACHE: + return AUTOQUANT_CACHE[shape].get(cls, None) + else: + return None + +def update_cache(shape, cls, res): + if not shape in AUTOQUANT_CACHE: + AUTOQUANT_CACHE[shape] = {} + AUTOQUANT_CACHE[shape][cls] = res + +class AutoQuantizableLinearWeight(torch.Tensor): + """ + when run, finds best type of quantization for this tensor and swaps itself with that + """ + @staticmethod + def __new__(cls, weight, qtensor_class_list, *args, **kwargs): + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = ( + kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype + ) + kwargs["requires_grad"] = False + shape = kwargs.pop("shape", weight.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, weight, qtensor_class_list, *args, **kwargs): + self.weight = weight + self.qtensor_class_list = qtensor_class_list + self.cache_shape = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})" + ) + + @staticmethod + def tune_autoquant(act_mat, w_autoquant, bias): + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + cache_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape) + w_autoquant.cache_shape = cache_shape + for cur_cls in w_autoquant.qtensor_class_list: + if check_cache(cache_shape, cur_cls) is None: + with torch.no_grad(): + print(cur_cls, cache_shape) + print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) + res = cur_cls._autoquant_test(act_mat.clone(), w_autoquant.weight.clone(), None if bias is None else bias.clone()) + update_cache(cache_shape, cur_cls, res) + print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) + y = torch.mm(act_mat, w_autoquant.weight.t()) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + + def to_quantized(self): + if self.cache_shape is None or self.cache_shape not in AUTOQUANT_CACHE: + raise RuntimeError("must run module normally to find best quantization option") + best_time = torch.inf + best_cls = None + for cur_cls in self.qtensor_class_list: + cls_res = AUTOQUANT_CACHE[self.cache_shape].get(cur_cls, torch.inf) + if best_time >= cls_res: + best_time = cls_res + best_cls = cur_cls + # need to handle random cls args/kwargs? + self = best_cls.from_float(self.weight) + return self + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), self.qtensor_class_list, dtype=self.dtype + ) + + def __tensor_flatten__(self): + return ["weight"], [self.qtensor_class_list, self.dtype, self.shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + weight = tensor_data_dict["weight"] + qtensor_class_list, dtype, shape = tensor_attributes[0] + return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @classmethod + def from_float(cls, weight, qtensor_class_list): + return cls(weight, qtensor_class_list) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_autoquant, bias = ( + args[0], + args[1], + args[2] if len(args)>2 else None + ) + return cls.tune_autoquant(mat1, w_autoquant, bias) + + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + + +class DefaultLinear(torch.Tensor): + """ + An class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. + """ + def __init__(self): + super().__init__() + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias): + w_qtensor = cls.from_float(weight) + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") + with torch.no_grad(): + res=benchmark(q_c_op, act_mat, w_qtensor, bias) + print(cls, res) + return res + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor, bias) + + @classmethod + def from_float(cls, weight): + return weight + +DEFAULT_CLASS_LIST = [ + DefaultLinear, + Int8WeightOnlyQuantizedLinearWeight, + Int8DynamicallyQuantizedLinearWeight, +] + +if False: + # def _get_to_kwargs(self, *args, **kwargs): + # device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + # device = self.device if device is None else device + # dtype = self.dtype if dtype is None else dtype + # memory_format = ( + # memory_format if memory_format is not None else torch.preserve_format + # ) + # kwargs = { + # "device": device, + # "dtype": dtype, + # "memory_format": memory_format, + # } + # return kwargs + + # def to(self, *args, **kwargs): + # kwargs = self._get_to_kwargs(*args, **kwargs) + # return self.__class__( + # self.int_data.to(kwargs["device"]), + # self.q_scales.to(kwargs["device"]), + # self.transposed, + # self.shape, + # **kwargs, + # ) + + # def _apply_fn_to_data(self, fn): + # return self.__class__( + # fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype + # ) + + # def _change_shape(self, shape): + # return self.__class__( + # self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype + # ) + + # def half(self): + # return self.to(torch.float16) + pass diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8e19014eda..9e92396f5d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,6 +28,7 @@ from .weight_only import ( WeightOnlyInt8QuantLinear, ) +from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST __all__ = [ "apply_weight_only_int8_quant", @@ -95,9 +96,11 @@ def apply_dynamic_quant(model, filter_fn=None): def _get_subclass_inserter(cls, **kwargs): + method = kwargs.pop("method", "from_float") def insert_subclass(lin): lin.weight = torch.nn.Parameter( - cls.from_float(lin.weight, **kwargs), requires_grad=False + # cls.from_float(...) + getattr(cls, method)(lin.weight, **kwargs), requires_grad=False ) return lin @@ -153,6 +156,39 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): filter_fn, ) + +def change_linears_to_autoquantizable(model, **kwargs): + filter_fn = kwargs.pop("filter_fn", _is_linear) + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), + filter_fn if filter_fn is not None else _is_linear, + ) + +def change_autoquantizable_to_quantized(model, **kwargs): + filter_fn = kwargs.pop( + "filter_fn", + lambda mod, *args: + _is_linear(mod, *args) and + isinstance(mod.weight, AutoQuantizableLinearWeight) + ) + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter( + AutoQuantizableLinearWeight, method="to_quantized", **kwargs + ), + filter_fn, + ) + +@torch.no_grad() +def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear): + change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list) + if not isinstance(example_input, (tuple, list)): + example_input = [example_input] + model(*example_input) + change_autoquantizable_to_quantized(model) + return model + def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 44b37c67c1..1b37700894 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -13,8 +13,11 @@ groupwise_affine_quantize_tensor, quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, + quantize_activation_per_token_absmax, + quant_int8_per_token_matmul, + safe_int_mm, ) -from .utils import find_multiple +from .utils import find_multiple, benchmark import warnings @@ -197,6 +200,25 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype ) + @classmethod + def _autoquant_test(cls, act_mat, weight, bias): + w_qtensor = cls.from_float(weight) + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") + with torch.no_grad(): + res=benchmark(q_c_op, act_mat, w_qtensor, bias) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: + safe_int_mm(x_vals_int8, w_vals_int8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") + with torch.no_grad(): + res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(cls, res, res2) + return (res+res2)/2 + def dequantize(self, dtype=None): """ Obtain the dequantized version of the quantized tensor subclass @@ -292,6 +314,26 @@ def _quantized_op(act_mat, w_qtensor, bias): y += bias return y.to(orig_dtype) + @classmethod + def _autoquant_test(cls, act_mat, weight, bias): + w_qtensor = cls.from_float(weight) + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") + with torch.no_grad(): + res=benchmark(q_c_op, act_mat, w_qtensor, bias) + + quantized_matmul = ( + lambda act_mat, w_vals_int8: + torch.mm(act_mat, w_vals_int8.to(act_mat.dtype)) + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") + with torch.no_grad(): + res2=benchmark( + q_c_matmul, + act_mat.reshape(-1, act_mat.shape[-1]), + w_qtensor.int_data) + print(cls, res, res2 + ) + return (res+res2)/2 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): """ diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 73621e6297..0bbe6faa95 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -7,6 +7,7 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.benchmark import Timer __all__ = [ "find_multiple", @@ -86,3 +87,15 @@ def get_model_size_in_bytes(model): for b in model.buffers(): s += b.nelement() * b.element_size() return s + +def benchmark(f, *args, **kwargs): + t0 = Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + # warmup + t0.timeit(10).median + t0.blocked_autorange() + res = t0.timeit(20) + print(res) + + return res.median * 1e3 From 7f4150f7f9183090fa38e5614b35d9f529e7bafe Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 27 Feb 2024 10:44:07 -0800 Subject: [PATCH 2/7] Update on "Autoquant" Summary: currently issue where for multiple linear layers, get very slow dynamic quant results on layer linear layers. unclear why. Test Plan: python test/test.py -k "autoquant" (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0 AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840) bias_addmm 2.9764 ms 100.0% triton_mm_1 3.6858 ms 80.8% triton_mm_2 3.7502 ms 79.4% addmm 3.7887 ms 78.6% triton_mm_3 4.1547 ms 71.6% triton_mm_4 4.2022 ms 70.8% triton_mm_0 4.7970 ms 62.0% triton_mm_8 4.9596 ms 60.0% triton_mm_7 5.4343 ms 54.8% triton_mm_10 6.9352 ms 42.9% SingleProcess AUTOTUNE takes 5.6320 seconds f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs , 1 thread 3.07677136734128 1311.548416 0 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0 AUTOTUNE mixed_mm(65536x1280, 1280x3840) fallback_mixed_mm 2.5089 ms 100.0% triton_mm_13 6.4153 ms 39.1% triton_mm_14 6.6832 ms 37.5% triton_mm_12 7.0896 ms 35.4% triton_mm_16 7.5022 ms 33.4% triton_mm_15 7.8426 ms 32.0% triton_mm_19 9.5269 ms 26.3% triton_mm_20 11.2033 ms 22.4% triton_mm_17 13.1675 ms 19.1% triton_mm_18 13.8004 ms 18.2% SingleProcess AUTOTUNE takes 2.4977 seconds f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs , 1 thread 3.6846738075837493 3.1023880932480097 2144.447488 25 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25 AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840) triton_mm_43 2.0319 ms 100.0% triton_mm_35 2.8135 ms 72.2% triton_mm_42 3.1552 ms 64.4% triton_mm_36 3.1754 ms 64.0% triton_mm_44 3.3460 ms 60.7% triton_mm_41 3.4036 ms 59.7% triton_mm_37 3.5030 ms 58.0% triton_mm_34 3.6553 ms 55.6% triton_mm_38 3.9232 ms 51.8% triton_mm_40 9.1934 ms 22.1% SingleProcess AUTOTUNE takes 8.1948 seconds f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs , 1 thread 3.1286065466701984 2.210085652768612 2144.447488 22 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22 AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280) bias_addmm 2.7966 ms 100.0% addmm 3.0447 ms 91.9% triton_mm_57 3.5612 ms 78.5% triton_mm_58 3.6919 ms 75.7% triton_mm_59 4.1908 ms 66.7% triton_mm_60 4.2350 ms 66.0% triton_mm_56 4.7210 ms 59.2% triton_mm_64 4.9001 ms 57.1% triton_mm_63 5.5218 ms 50.6% triton_mm_66 7.1417 ms 39.2% SingleProcess AUTOTUNE takes 6.3734 seconds f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs , 1 thread 3.329739556647837 2228.913664 39 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39 AUTOTUNE mixed_mm(65536x3840, 3840x1280) fallback_mixed_mm 2.3987 ms 100.0% triton_mm_70 6.9153 ms 34.7% triton_mm_72 7.1634 ms 33.5% triton_mm_69 7.3164 ms 32.8% triton_mm_68 7.5070 ms 32.0% triton_mm_71 7.5631 ms 31.7% triton_mm_76 10.7759 ms 22.3% triton_mm_75 11.0692 ms 21.7% triton_mm_73 12.8898 ms 18.6% triton_mm_77 13.3715 ms 17.9% SingleProcess AUTOTUNE takes 6.2342 seconds f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs , 1 thread 3.4762858413159847 3.2240213360637426 2228.913664 38 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38 AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280) triton_mm_99 1.4307 ms 100.0% triton_mm_100 1.9041 ms 75.1% triton_mm_91 2.6079 ms 54.9% triton_mm_98 2.6363 ms 54.3% triton_mm_92 2.6691 ms 53.6% triton_mm_93 3.0178 ms 47.4% triton_mm_97 3.0233 ms 47.3% triton_mm_94 3.1872 ms 44.9% triton_mm_90 3.6072 ms 39.7% triton_mm_96 8.4695 ms 16.9% SingleProcess AUTOTUNE takes 8.1095 seconds f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs , 1 thread 145.37517526187003 143.98446583654732 2230.364672 79 Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 5 ++- torchao/quantization/autoquant.py | 65 +++++++++++++++++-------------- torchao/quantization/quant_api.py | 1 + torchao/quantization/subclass.py | 6 ++- torchao/quantization/utils.py | 29 +++++++++++--- 5 files changed, 68 insertions(+), 38 deletions(-) diff --git a/test/test.py b/test/test.py index 8e8ad04593..0347139796 100644 --- a/test/test.py +++ b/test/test.py @@ -1204,12 +1204,15 @@ def test_auto_quant(self): torch.nn.Linear(1280,3840), torch.nn.ReLU(), torch.nn.Linear(3840,1280), + torch.nn.ReLU(), ).to("cuda").to(torch.bfloat16) - example_input = torch.randn(65536,1280, device="cuda", dtype=torch.bfloat16) + example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16) torch._inductor.config.epilogue_fusion = False torch._inductor.config.use_mixed_mm = True torch._inductor.config.force_fuse_int_mm_with_mul = True torch._inductor.config.coordinate_descent_tuning = True + torch._dynamo.config.automatic_dynamic_shapes = False + torch._dynamo.reset() # TODO use in autoquantizer do_autoquant(model, example_input) if __name__ == "__main__": diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index edbc130816..4267a86e7d 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -12,16 +12,11 @@ AUTOQUANT_CACHE = {} -def check_cache(shape, cls): - if shape in AUTOQUANT_CACHE: - return AUTOQUANT_CACHE[shape].get(cls, None) - else: - return None +def check_cache(cls, shape, dtype): + return AUTOQUANT_CACHE.get((cls, shape, dtype), None) -def update_cache(shape, cls, res): - if not shape in AUTOQUANT_CACHE: - AUTOQUANT_CACHE[shape] = {} - AUTOQUANT_CACHE[shape][cls] = res +def update_cache(cls, shape, dtype, res): + AUTOQUANT_CACHE[(cls, shape, dtype)] = res class AutoQuantizableLinearWeight(torch.Tensor): """ @@ -43,7 +38,8 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs): def __init__(self, weight, qtensor_class_list, *args, **kwargs): self.weight = weight self.qtensor_class_list = qtensor_class_list - self.cache_shape = None + self.logged_shape = None + self.logged_dtype = None def __repr__(self): return ( @@ -52,36 +48,46 @@ def __repr__(self): ) @staticmethod - def tune_autoquant(act_mat, w_autoquant, bias): + def log_shape(act_mat, w_autoquant, bias): orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1]) - cache_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape) - w_autoquant.cache_shape = cache_shape - for cur_cls in w_autoquant.qtensor_class_list: - if check_cache(cache_shape, cur_cls) is None: - with torch.no_grad(): - print(cur_cls, cache_shape) - print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) - res = cur_cls._autoquant_test(act_mat.clone(), w_autoquant.weight.clone(), None if bias is None else bias.clone()) - update_cache(cache_shape, cur_cls, res) - print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) + logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape) + logged_dtype = act_mat.dtype + w_autoquant.logged_shape = logged_shape + w_autoquant.logged_dtype = logged_dtype + for q_cls in w_autoquant.qtensor_class_list: + if check_cache(q_cls, logged_shape, logged_dtype) is None: + update_cache(q_cls, logged_shape, logged_dtype, None) y = torch.mm(act_mat, w_autoquant.weight.t()) y = y.reshape(*orig_shape[:-1], y.shape[-1]) if bias is not None: y += bias return y + def tune_autoquant(self, q_cls): + act_shape, w_shape, bias_shape = self.logged_shape + if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: + with torch.no_grad(): + act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device) + bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device) + print(q_cls, self.logged_shape, self.logged_dtype) + print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) + res = q_cls._autoquant_test(act_mat, self.weight, bias) + update_cache(q_cls, self.logged_shape, self.logged_dtype, res) + def to_quantized(self): - if self.cache_shape is None or self.cache_shape not in AUTOQUANT_CACHE: - raise RuntimeError("must run module normally to find best quantization option") + if self.logged_shape is None or self.logged_dtype is None: + raise RuntimeError("must run module normally to get shape, dtype info for autoquant") best_time = torch.inf best_cls = None - for cur_cls in self.qtensor_class_list: - cls_res = AUTOQUANT_CACHE[self.cache_shape].get(cur_cls, torch.inf) + for q_cls in self.qtensor_class_list: + if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: + self.tune_autoquant(q_cls) + cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf) if best_time >= cls_res: best_time = cls_res - best_cls = cur_cls - # need to handle random cls args/kwargs? + best_cls = q_cls + # TODO handle random cls args/kwargs? or should they be curried self = best_cls.from_float(self.weight) return self @@ -113,7 +119,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args)>2 else None ) - return cls.tune_autoquant(mat1, w_autoquant, bias) + return cls.log_shape(mat1, w_autoquant, bias) try: with torch._C.DisableTorchFunctionSubclass(): @@ -155,9 +161,10 @@ def from_float(cls, weight): return weight DEFAULT_CLASS_LIST = [ + Int8DynamicallyQuantizedLinearWeight, DefaultLinear, Int8WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, + ] if False: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9e92396f5d..6d569e688a 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -184,6 +184,7 @@ def change_autoquantizable_to_quantized(model, **kwargs): def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear): change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list) if not isinstance(example_input, (tuple, list)): + assert isinstance(example_input, torch.Tensor) example_input = [example_input] model(*example_input) change_autoquantizable_to_quantized(model) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 1b37700894..34f8ee4f01 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -206,6 +206,7 @@ def _autoquant_test(cls, act_mat, weight, bias): q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") with torch.no_grad(): res=benchmark(q_c_op, act_mat, w_qtensor, bias) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( act_mat.reshape(-1, act_mat.shape[-1]) ) @@ -217,6 +218,7 @@ def _autoquant_test(cls, act_mat, weight, bias): with torch.no_grad(): res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) print(cls, res, res2) + breakpoint() return (res+res2)/2 def dequantize(self, dtype=None): @@ -331,8 +333,8 @@ def _autoquant_test(cls, act_mat, weight, bias): q_c_matmul, act_mat.reshape(-1, act_mat.shape[-1]), w_qtensor.int_data) - print(cls, res, res2 - ) + + print(cls, res, res2) return (res+res2)/2 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 0bbe6faa95..c79d83b65e 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -88,14 +88,31 @@ def get_model_size_in_bytes(model): s += b.nelement() * b.element_size() return s +import time + +def benchmark_torch_function(iters, f, *args, **kwargs): + f(*args, **kwargs) + f(*args, **kwargs) + f(*args, **kwargs) + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + else: + t0 = time.time() + for i in range(iters): + f(*args, **kwargs) + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + return start_event.elapsed_time(end_event) + else: + return (time.time() - t0) + def benchmark(f, *args, **kwargs): t0 = Timer( stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} ) # warmup - t0.timeit(10).median - t0.blocked_autorange() - res = t0.timeit(20) - print(res) - - return res.median * 1e3 + return benchmark_torch_function(10, f, *args, **kwargs) From 20c81b963a43be1098f6e17fd53c748d451b7bf5 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 1 Mar 2024 18:34:28 -0800 Subject: [PATCH 3/7] Update on "Autoquant" Summary: currently issue where for multiple linear layers, get very slow dynamic quant results on layer linear layers. unclear why. Test Plan: python test/test.py -k "autoquant" (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0 AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840) bias_addmm 2.9764 ms 100.0% triton_mm_1 3.6858 ms 80.8% triton_mm_2 3.7502 ms 79.4% addmm 3.7887 ms 78.6% triton_mm_3 4.1547 ms 71.6% triton_mm_4 4.2022 ms 70.8% triton_mm_0 4.7970 ms 62.0% triton_mm_8 4.9596 ms 60.0% triton_mm_7 5.4343 ms 54.8% triton_mm_10 6.9352 ms 42.9% SingleProcess AUTOTUNE takes 5.6320 seconds f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs , 1 thread 3.07677136734128 1311.548416 0 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0 AUTOTUNE mixed_mm(65536x1280, 1280x3840) fallback_mixed_mm 2.5089 ms 100.0% triton_mm_13 6.4153 ms 39.1% triton_mm_14 6.6832 ms 37.5% triton_mm_12 7.0896 ms 35.4% triton_mm_16 7.5022 ms 33.4% triton_mm_15 7.8426 ms 32.0% triton_mm_19 9.5269 ms 26.3% triton_mm_20 11.2033 ms 22.4% triton_mm_17 13.1675 ms 19.1% triton_mm_18 13.8004 ms 18.2% SingleProcess AUTOTUNE takes 2.4977 seconds f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs , 1 thread 3.6846738075837493 3.1023880932480097 2144.447488 25 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25 AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840) triton_mm_43 2.0319 ms 100.0% triton_mm_35 2.8135 ms 72.2% triton_mm_42 3.1552 ms 64.4% triton_mm_36 3.1754 ms 64.0% triton_mm_44 3.3460 ms 60.7% triton_mm_41 3.4036 ms 59.7% triton_mm_37 3.5030 ms 58.0% triton_mm_34 3.6553 ms 55.6% triton_mm_38 3.9232 ms 51.8% triton_mm_40 9.1934 ms 22.1% SingleProcess AUTOTUNE takes 8.1948 seconds f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs , 1 thread 3.1286065466701984 2.210085652768612 2144.447488 22 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22 AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280) bias_addmm 2.7966 ms 100.0% addmm 3.0447 ms 91.9% triton_mm_57 3.5612 ms 78.5% triton_mm_58 3.6919 ms 75.7% triton_mm_59 4.1908 ms 66.7% triton_mm_60 4.2350 ms 66.0% triton_mm_56 4.7210 ms 59.2% triton_mm_64 4.9001 ms 57.1% triton_mm_63 5.5218 ms 50.6% triton_mm_66 7.1417 ms 39.2% SingleProcess AUTOTUNE takes 6.3734 seconds f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs , 1 thread 3.329739556647837 2228.913664 39 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39 AUTOTUNE mixed_mm(65536x3840, 3840x1280) fallback_mixed_mm 2.3987 ms 100.0% triton_mm_70 6.9153 ms 34.7% triton_mm_72 7.1634 ms 33.5% triton_mm_69 7.3164 ms 32.8% triton_mm_68 7.5070 ms 32.0% triton_mm_71 7.5631 ms 31.7% triton_mm_76 10.7759 ms 22.3% triton_mm_75 11.0692 ms 21.7% triton_mm_73 12.8898 ms 18.6% triton_mm_77 13.3715 ms 17.9% SingleProcess AUTOTUNE takes 6.2342 seconds f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs , 1 thread 3.4762858413159847 3.2240213360637426 2228.913664 38 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38 AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280) triton_mm_99 1.4307 ms 100.0% triton_mm_100 1.9041 ms 75.1% triton_mm_91 2.6079 ms 54.9% triton_mm_98 2.6363 ms 54.3% triton_mm_92 2.6691 ms 53.6% triton_mm_93 3.0178 ms 47.4% triton_mm_97 3.0233 ms 47.3% triton_mm_94 3.1872 ms 44.9% triton_mm_90 3.6072 ms 39.7% triton_mm_96 8.4695 ms 16.9% SingleProcess AUTOTUNE takes 8.1095 seconds f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs , 1 thread 145.37517526187003 143.98446583654732 2230.364672 79 Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 39 +++++++--- torchao/quantization/__init__.py | 3 + torchao/quantization/autoquant.py | 118 +++++++++++++++++++++++++----- torchao/quantization/quant_api.py | 13 +++- torchao/quantization/subclass.py | 43 +---------- torchao/quantization/utils.py | 27 ++----- 6 files changed, 147 insertions(+), 96 deletions(-) diff --git a/test/test.py b/test/test.py index 0347139796..f4e6c09b95 100644 --- a/test/test.py +++ b/test/test.py @@ -54,6 +54,7 @@ compute_error as SQNR, _fqn_to_op_to_shape_to_count, LoggingTensorMode, + benchmark ) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os @@ -1198,22 +1199,38 @@ def test_on_dummy_distilbert(self): class TestAutoQuant(unittest.TestCase): def test_auto_quant(self): - model = torch.nn.Sequential( - # torch.nn.Linear(5120,1280), - # torch.nn.ReLU(), - torch.nn.Linear(1280,3840), - torch.nn.ReLU(), - torch.nn.Linear(3840,1280), - torch.nn.ReLU(), - ).to("cuda").to(torch.bfloat16) - example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16) torch._inductor.config.epilogue_fusion = False torch._inductor.config.use_mixed_mm = True torch._inductor.config.force_fuse_int_mm_with_mul = True torch._inductor.config.coordinate_descent_tuning = True torch._dynamo.config.automatic_dynamic_shapes = False - torch._dynamo.reset() # TODO use in autoquantizer - do_autoquant(model, example_input) + + for m,k,n in [ + (1, 1024, 1024), + (64, 1024, 1024), + (4096, 1024, 1024), + (1, 1024, 4096), + (64, 1024, 4096), + (1, 4096, 1024), + (64, 4096, 1024), + (4096, 4096, 1024), + ]: + print("testing", m, k, n) + example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + model = torch.nn.Sequential( + # torch.nn.ReLU(), + torch.nn.Linear(k,n), + # torch.nn.ReLU(), + # torch.nn.Linear(1280,3840), + # torch.nn.ReLU(), + # torch.nn.Linear(3840,1280), + # torch.nn.ReLU(), + # torch.nn.Linear(1280,1024), + # torch.nn.ReLU(), + # torch.nn.Linear(1024,4096), + # torch.nn.ReLU(), + ).to("cuda").to(torch.bfloat16) + do_autoquant(model, example_input) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 80599cb71c..525008a77d 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -25,6 +25,9 @@ "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", + "do_autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", "quant_int8_dynamic_linear", "quant_int8_matmul", "quant_int8_dynamic_per_token_linear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 4267a86e7d..8a624b6b92 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -7,6 +7,11 @@ ) from torch.utils._python_dispatch import return_and_correct_aliasing from .utils import benchmark +from .quant_primitives import ( + quantize_activation_per_token_absmax, + safe_int_mm, +) +import torch.nn.functional as F aten = torch.ops.aten @@ -70,23 +75,30 @@ def tune_autoquant(self, q_cls): with torch.no_grad(): act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device) bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device) - print(q_cls, self.logged_shape, self.logged_dtype) - print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage()) res = q_cls._autoquant_test(act_mat, self.weight, bias) update_cache(q_cls, self.logged_shape, self.logged_dtype, res) - def to_quantized(self): - if self.logged_shape is None or self.logged_dtype is None: + def to_quantized(self, error_on_unseen, **kwargs): + if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None): raise RuntimeError("must run module normally to get shape, dtype info for autoquant") + elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen: + # default back to non-quantized weight if not seen + self = AQFloatLinearWeight.from_float(self.weight) + return self best_time = torch.inf best_cls = None + do_print=False for q_cls in self.qtensor_class_list: if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: + do_print=True self.tune_autoquant(q_cls) + torch._dynamo.reset() cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf) if best_time >= cls_res: best_time = cls_res best_cls = q_cls + if do_print: + print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}") # TODO handle random cls args/kwargs? or should they be curried self = best_cls.from_float(self.weight) return self @@ -132,26 +144,93 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.detach.default: return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) - -class DefaultLinear(torch.Tensor): +class AQMixin(): """ - An class to be used in concert with AutoQuantizableLinearWeight to provide a - default/non-quantized option. Only implements the bare minimum needed to work with the - AutoQuantizableLinearWeight class using the same interfaces that would normally be - used by QTensor subclasses but for a default linear op instead. + Mixin to turn normal quantized subclasses into autoquantizable ones """ - def __init__(self): - super().__init__() - @classmethod def _autoquant_test(cls, act_mat, weight, bias): w_qtensor = cls.from_float(weight) - q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") + func = lambda act_mat, w_qtensor, bias: F.relu(cls._quantized_op(F.relu(act_mat), w_qtensor, bias)) + q_c_op = torch.compile(func, mode="max-autotune") + # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") with torch.no_grad(): - res=benchmark(q_c_op, act_mat, w_qtensor, bias) + torch.cuda.synchronize() + res = benchmark(q_c_op, act_mat, w_qtensor, bias) print(cls, res) return res +class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): + """ + AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias): + res = super()._autoquant_test(act_mat, weight, bias) + w_qtensor = cls.from_float(weight) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: + safe_int_mm(x_vals_int8, w_vals_int8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") + with torch.no_grad(): + res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(cls, "matmul", res2) + # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930 + return res + + +class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight + """ + +class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) + y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y.to(orig_dtype) + + @classmethod + def _autoquant_test(cls, act_mat, weight, bias): + # if act_mat has batchsize>2 don't use this kernel + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2: + return torch.inf + return super()._autoquant_test(act_mat, weight, bias) + +class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + def _quantized_op(act_mat, w_qtensor, bias): + orig_shape = act_mat.shape + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) + y=y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + + +class AQFloatLinearWeight(torch.Tensor, AQMixin): + """ + A class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. + """ + def __init__(self): + super().__init__() + @staticmethod def _quantized_op(act_mat, w_qtensor, bias): return torch.nn.functional.linear(act_mat, w_qtensor, bias) @@ -161,10 +240,11 @@ def from_float(cls, weight): return weight DEFAULT_CLASS_LIST = [ - Int8DynamicallyQuantizedLinearWeight, - DefaultLinear, - Int8WeightOnlyQuantizedLinearWeight, - + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + AQWeightOnlyQuantizedLinearWeight3, ] if False: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6d569e688a..9b1e0dc9c9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -36,7 +36,10 @@ "change_linear_weights_to_int8_dqtensors", "change_linear_weights_to_int8_woqtensors", "change_linear_weights_to_int4_woqtensors", - "swap_conv2d_1x1_to_linear" + "swap_conv2d_1x1_to_linear", + "do_autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", ] @@ -159,6 +162,7 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): def change_linears_to_autoquantizable(model, **kwargs): filter_fn = kwargs.pop("filter_fn", _is_linear) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) _replace_with_custom_fn_if_matches_filter( model, _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), @@ -172,22 +176,27 @@ def change_autoquantizable_to_quantized(model, **kwargs): _is_linear(mod, *args) and isinstance(mod.weight, AutoQuantizableLinearWeight) ) + error_on_unseen=kwargs.pop("error_on_unseen", True) _replace_with_custom_fn_if_matches_filter( model, _get_subclass_inserter( - AutoQuantizableLinearWeight, method="to_quantized", **kwargs + AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs ), filter_fn, ) @torch.no_grad() def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear): + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list) if not isinstance(example_input, (tuple, list)): assert isinstance(example_input, torch.Tensor) example_input = [example_input] model(*example_input) change_autoquantizable_to_quantized(model) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() return model def swap_conv2d_1x1_to_linear(model, filter_fn=None): diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 34f8ee4f01..6acdff235d 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -200,27 +200,6 @@ def _quantized_op(act_mat, w_qtensor, bias): act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype ) - @classmethod - def _autoquant_test(cls, act_mat, weight, bias): - w_qtensor = cls.from_float(weight) - q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") - with torch.no_grad(): - res=benchmark(q_c_op, act_mat, w_qtensor, bias) - - x_vals_int8, x_scales = quantize_activation_per_token_absmax( - act_mat.reshape(-1, act_mat.shape[-1]) - ) - quantized_matmul = ( - lambda x_vals_int8, x_scales, w_vals_int8: - safe_int_mm(x_vals_int8, w_vals_int8) * x_scales - ) - q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") - with torch.no_grad(): - res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) - print(cls, res, res2) - breakpoint() - return (res+res2)/2 - def dequantize(self, dtype=None): """ Obtain the dequantized version of the quantized tensor subclass @@ -293,7 +272,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): # however the external representation of our tensor will maintain the correct # shape attribute which needs to be tracked directly. int_data = w_int_repr.contiguous().t() - if cls is not Int8DynamicallyQuantizedLinearWeight: + if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): int_data = int_data.contiguous() return cls( int_data, w_scales, False, input_float.shape, dtype=input_float.dtype @@ -316,26 +295,6 @@ def _quantized_op(act_mat, w_qtensor, bias): y += bias return y.to(orig_dtype) - @classmethod - def _autoquant_test(cls, act_mat, weight, bias): - w_qtensor = cls.from_float(weight) - q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") - with torch.no_grad(): - res=benchmark(q_c_op, act_mat, w_qtensor, bias) - - quantized_matmul = ( - lambda act_mat, w_vals_int8: - torch.mm(act_mat, w_vals_int8.to(act_mat.dtype)) - ) - q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") - with torch.no_grad(): - res2=benchmark( - q_c_matmul, - act_mat.reshape(-1, act_mat.shape[-1]), - w_qtensor.int_data) - - print(cls, res, res2) - return (res+res2)/2 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): """ diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index c79d83b65e..2cadab4644 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -88,31 +88,14 @@ def get_model_size_in_bytes(model): s += b.nelement() * b.element_size() return s -import time - -def benchmark_torch_function(iters, f, *args, **kwargs): - f(*args, **kwargs) - f(*args, **kwargs) - f(*args, **kwargs) - if torch.cuda.is_available(): - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - else: - t0 = time.time() - for i in range(iters): - f(*args, **kwargs) - if torch.cuda.is_available(): - end_event.record() - torch.cuda.synchronize() - return start_event.elapsed_time(end_event) - else: - return (time.time() - t0) def benchmark(f, *args, **kwargs): t0 = Timer( stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} ) + # warmup - return benchmark_torch_function(10, f, *args, **kwargs) + t0.timeit(10) + + res=t0.blocked_autorange(min_run_time=.5) + return res.median * 1e3 From 0823e95920d0eab0fd858e22a9a48f3eb1607cea Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 1 Mar 2024 18:35:45 -0800 Subject: [PATCH 4/7] Update on "Autoquant" Summary: currently issue where for multiple linear layers, get very slow dynamic quant results on layer linear layers. unclear why. Test Plan: python test/test.py -k "autoquant" (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 187.4432 0 AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840) bias_addmm 2.9764 ms 100.0% triton_mm_1 3.6858 ms 80.8% triton_mm_2 3.7502 ms 79.4% addmm 3.7887 ms 78.6% triton_mm_3 4.1547 ms 71.6% triton_mm_4 4.2022 ms 70.8% triton_mm_0 4.7970 ms 62.0% triton_mm_8 4.9596 ms 60.0% triton_mm_7 5.4343 ms 54.8% triton_mm_10 6.9352 ms 42.9% SingleProcess AUTOTUNE takes 5.6320 seconds f(*args, **kwargs) 3.08 ms 1 measurement, 20 runs , 1 thread 3.07677136734128 1311.548416 0 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 1311.548416 0 AUTOTUNE mixed_mm(65536x1280, 1280x3840) fallback_mixed_mm 2.5089 ms 100.0% triton_mm_13 6.4153 ms 39.1% triton_mm_14 6.6832 ms 37.5% triton_mm_12 7.0896 ms 35.4% triton_mm_16 7.5022 ms 33.4% triton_mm_15 7.8426 ms 32.0% triton_mm_19 9.5269 ms 26.3% triton_mm_20 11.2033 ms 22.4% triton_mm_17 13.1675 ms 19.1% triton_mm_18 13.8004 ms 18.2% SingleProcess AUTOTUNE takes 2.4977 seconds f(*args, **kwargs) 3.68 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.10 ms 1 measurement, 20 runs , 1 thread 3.6846738075837493 3.1023880932480097 2144.447488 25 (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840])) 2144.447488 25 AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840) triton_mm_43 2.0319 ms 100.0% triton_mm_35 2.8135 ms 72.2% triton_mm_42 3.1552 ms 64.4% triton_mm_36 3.1754 ms 64.0% triton_mm_44 3.3460 ms 60.7% triton_mm_41 3.4036 ms 59.7% triton_mm_37 3.5030 ms 58.0% triton_mm_34 3.6553 ms 55.6% triton_mm_38 3.9232 ms 51.8% triton_mm_40 9.1934 ms 22.1% SingleProcess AUTOTUNE takes 8.1948 seconds f(*args, **kwargs) 3.13 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 2.21 ms 1 measurement, 20 runs , 1 thread 3.1286065466701984 2.210085652768612 2144.447488 22 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2144.447488 22 AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280) bias_addmm 2.7966 ms 100.0% addmm 3.0447 ms 91.9% triton_mm_57 3.5612 ms 78.5% triton_mm_58 3.6919 ms 75.7% triton_mm_59 4.1908 ms 66.7% triton_mm_60 4.2350 ms 66.0% triton_mm_56 4.7210 ms 59.2% triton_mm_64 4.9001 ms 57.1% triton_mm_63 5.5218 ms 50.6% triton_mm_66 7.1417 ms 39.2% SingleProcess AUTOTUNE takes 6.3734 seconds f(*args, **kwargs) 3.33 ms 1 measurement, 20 runs , 1 thread 3.329739556647837 2228.913664 39 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 39 AUTOTUNE mixed_mm(65536x3840, 3840x1280) fallback_mixed_mm 2.3987 ms 100.0% triton_mm_70 6.9153 ms 34.7% triton_mm_72 7.1634 ms 33.5% triton_mm_69 7.3164 ms 32.8% triton_mm_68 7.5070 ms 32.0% triton_mm_71 7.5631 ms 31.7% triton_mm_76 10.7759 ms 22.3% triton_mm_75 11.0692 ms 21.7% triton_mm_73 12.8898 ms 18.6% triton_mm_77 13.3715 ms 17.9% SingleProcess AUTOTUNE takes 6.2342 seconds f(*args, **kwargs) 3.48 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 3.22 ms 1 measurement, 20 runs , 1 thread 3.4762858413159847 3.2240213360637426 2228.913664 38 (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280])) 2228.913664 38 AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280) triton_mm_99 1.4307 ms 100.0% triton_mm_100 1.9041 ms 75.1% triton_mm_91 2.6079 ms 54.9% triton_mm_98 2.6363 ms 54.3% triton_mm_92 2.6691 ms 53.6% triton_mm_93 3.0178 ms 47.4% triton_mm_97 3.0233 ms 47.3% triton_mm_94 3.1872 ms 44.9% triton_mm_90 3.6072 ms 39.7% triton_mm_96 8.4695 ms 16.9% SingleProcess AUTOTUNE takes 8.1095 seconds f(*args, **kwargs) 145.38 ms 1 measurement, 20 runs , 1 thread f(*args, **kwargs) 143.98 ms 1 measurement, 20 runs , 1 thread 145.37517526187003 143.98446583654732 2230.364672 79 Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchao/quantization/autoquant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 8a624b6b92..510b27ada0 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -151,7 +151,7 @@ class AQMixin(): @classmethod def _autoquant_test(cls, act_mat, weight, bias): w_qtensor = cls.from_float(weight) - func = lambda act_mat, w_qtensor, bias: F.relu(cls._quantized_op(F.relu(act_mat), w_qtensor, bias)) + func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c)) q_c_op = torch.compile(func, mode="max-autotune") # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") with torch.no_grad(): From c6d59e5192e4e8b675ebd97e428f2a6944c53918 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 5 Mar 2024 15:51:14 -0800 Subject: [PATCH 5/7] Update on "Autoquant" Summary: Adding autoquantization functionality, using hte do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer. Test Plan: python test/test.py -k "autoquant" also tested on SAM and SDXL (https://github.com/pytorch-labs/segment-anything-fast/pull/114, https://github.com/huggingface/diffusion-fast/commit/176e85f9afdbd61df9f035c84313c5f85c3c597a) Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 14 ++---- test/test_autoquant.py | 35 -------------- torchao/quantization/autoquant.py | 78 +++++++++---------------------- torchao/quantization/subclass.py | 5 +- torchao/quantization/utils.py | 12 ++++- 5 files changed, 37 insertions(+), 107 deletions(-) delete mode 100644 test/test_autoquant.py diff --git a/test/test.py b/test/test.py index f4e6c09b95..7bb7dc90c5 100644 --- a/test/test.py +++ b/test/test.py @@ -1197,6 +1197,7 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) +# TODO FINISH TEST CODE class TestAutoQuant(unittest.TestCase): def test_auto_quant(self): torch._inductor.config.epilogue_fusion = False @@ -1215,20 +1216,11 @@ def test_auto_quant(self): (64, 4096, 1024), (4096, 4096, 1024), ]: - print("testing", m, k, n) example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) model = torch.nn.Sequential( - # torch.nn.ReLU(), + torch.nn.ReLU(), torch.nn.Linear(k,n), - # torch.nn.ReLU(), - # torch.nn.Linear(1280,3840), - # torch.nn.ReLU(), - # torch.nn.Linear(3840,1280), - # torch.nn.ReLU(), - # torch.nn.Linear(1280,1024), - # torch.nn.ReLU(), - # torch.nn.Linear(1024,4096), - # torch.nn.ReLU(), + torch.nn.ReLU(), ).to("cuda").to(torch.bfloat16) do_autoquant(model, example_input) diff --git a/test/test_autoquant.py b/test/test_autoquant.py deleted file mode 100644 index 5d354185a7..0000000000 --- a/test/test_autoquant.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# mypy: ignore-errors -import copy -import unittest - -import torch -import torch.nn as nn -from torchao.quantization.quant_api import ( - change_linears_to_autoquantizable, - change_autoquantizable_to_quantized -) -from torchao.quantization.autoquant import do_autoquant -from torch._dynamo import config -torch.manual_seed(0) -config.cache_size_limit = 100 - - -class AutoquantTests(unittest.TestCase): - def test_autoquant_e2e(self): - model = torch.nn.Sequential(torch.nn.Linear(32,32), torch.nn.ReLU(), torch.nn.Linear(32,32)).cuda().to(torch.bfloat16) - print(model, model[0].weight) - example_input = torch.randn((1,64,32), dtype=torch.bfloat16, device=torch.cuda) - out=model(example_input) - print(out.sum()) - do_autoquant(model) - print(model, model[0].weight) - print(model(example_input).sum()) - -if __name__ == "__main__": - unittest.main() diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 510b27ada0..8392234314 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -69,13 +69,13 @@ def log_shape(act_mat, w_autoquant, bias): y += bias return y - def tune_autoquant(self, q_cls): + def tune_autoquant(self, q_cls, best_time): act_shape, w_shape, bias_shape = self.logged_shape if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: with torch.no_grad(): act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device) bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device) - res = q_cls._autoquant_test(act_mat, self.weight, bias) + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time) update_cache(q_cls, self.logged_shape, self.logged_dtype, res) def to_quantized(self, error_on_unseen, **kwargs): @@ -91,7 +91,7 @@ def to_quantized(self, error_on_unseen, **kwargs): for q_cls in self.qtensor_class_list: if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: do_print=True - self.tune_autoquant(q_cls) + self.tune_autoquant(q_cls, best_time) torch._dynamo.reset() cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf) if best_time >= cls_res: @@ -149,14 +149,12 @@ class AQMixin(): Mixin to turn normal quantized subclasses into autoquantizable ones """ @classmethod - def _autoquant_test(cls, act_mat, weight, bias): + def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs): w_qtensor = cls.from_float(weight) - func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c)) - q_c_op = torch.compile(func, mode="max-autotune") - # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") with torch.no_grad(): torch.cuda.synchronize() - res = benchmark(q_c_op, act_mat, w_qtensor, bias) + res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time) print(cls, res) return res @@ -165,8 +163,9 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLi AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ @classmethod - def _autoquant_test(cls, act_mat, weight, bias): - res = super()._autoquant_test(act_mat, weight, bias) + def _autoquant_test(cls, act_mat, weight, bias, best_time): + # SAM best is between .51 to .60, SDXL also performs best in this range + INTERPOLATION_CONSTANT=.55 w_qtensor = cls.from_float(weight) x_vals_int8, x_scales = quantize_activation_per_token_absmax( act_mat.reshape(-1, act_mat.shape[-1]) @@ -177,10 +176,18 @@ def _autoquant_test(cls, act_mat, weight, bias): ) q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") with torch.no_grad(): - res2=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) - print(cls, "matmul", res2) - # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930 - return res + res_matmul=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time) + print(cls, "matmul", res_matmul) + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul>=best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul) + return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): @@ -205,11 +212,11 @@ def _quantized_op(act_mat, w_qtensor, bias): return y.to(orig_dtype) @classmethod - def _autoquant_test(cls, act_mat, weight, bias): + def _autoquant_test(cls, act_mat, weight, bias, best_time): # if act_mat has batchsize>2 don't use this kernel if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2: return torch.inf - return super()._autoquant_test(act_mat, weight, bias) + return super()._autoquant_test(act_mat, weight, bias, best_time) class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): def _quantized_op(act_mat, w_qtensor, bias): @@ -246,42 +253,3 @@ def from_float(cls, weight): AQWeightOnlyQuantizedLinearWeight2, AQWeightOnlyQuantizedLinearWeight3, ] - -if False: - # def _get_to_kwargs(self, *args, **kwargs): - # device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - # device = self.device if device is None else device - # dtype = self.dtype if dtype is None else dtype - # memory_format = ( - # memory_format if memory_format is not None else torch.preserve_format - # ) - # kwargs = { - # "device": device, - # "dtype": dtype, - # "memory_format": memory_format, - # } - # return kwargs - - # def to(self, *args, **kwargs): - # kwargs = self._get_to_kwargs(*args, **kwargs) - # return self.__class__( - # self.int_data.to(kwargs["device"]), - # self.q_scales.to(kwargs["device"]), - # self.transposed, - # self.shape, - # **kwargs, - # ) - - # def _apply_fn_to_data(self, fn): - # return self.__class__( - # fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype - # ) - - # def _change_shape(self, shape): - # return self.__class__( - # self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype - # ) - - # def half(self): - # return self.to(torch.float16) - pass diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 6acdff235d..221c760eab 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -13,11 +13,8 @@ groupwise_affine_quantize_tensor, quant_int8_dynamic_per_token_linear, unpack_tinygemm_scales_and_zeros, - quantize_activation_per_token_absmax, - quant_int8_per_token_matmul, - safe_int_mm, ) -from .utils import find_multiple, benchmark +from .utils import find_multiple import warnings diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 2cadab4644..e973ab8ca9 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -90,12 +90,20 @@ def get_model_size_in_bytes(model): def benchmark(f, *args, **kwargs): + if "best_time" in kwargs: + best_time = kwargs.pop("best_time") + else: + best_time = torch.inf t0 = Timer( stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} ) # warmup t0.timeit(10) - - res=t0.blocked_autorange(min_run_time=.5) + res=t0.adaptive_autorange(min_run_time=.1) + # run more if median vs median minus iqr (interpolated based on number of runs left) is lower than best_time, + # stop if good res.iqr/res.median or have 20 samples + while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20): + res2 = t0.adaptive_autorange(min_run_time=.5) + res=res.merge([res2, res])[0] return res.median * 1e3 From 97733c29fddbe9d4c6a766b00e4d1aef3350d545 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 19 Mar 2024 15:14:58 -0700 Subject: [PATCH 6/7] Update on "Autoquant" Summary: Adding autoquantization functionality, using hte do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer. Test Plan: python test/test.py -k "autoquant" also tested on SAM and SDXL https://github.com/pytorch-labs/segment-anything-fast/pull/114 https://github.com/HDCharles/sdxl-fast/commit/8d9942ab05a552f25f5bfe09da02719ce255467f Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/test.py | 48 ++++++++- torchao/quantization/autoquant.py | 169 ++++++++++++++++++------------ torchao/quantization/quant_api.py | 26 ++++- 3 files changed, 165 insertions(+), 78 deletions(-) diff --git a/test/test.py b/test/test.py index 7bb7dc90c5..e1317c9a5a 100644 --- a/test/test.py +++ b/test/test.py @@ -54,7 +54,13 @@ compute_error as SQNR, _fqn_to_op_to_shape_to_count, LoggingTensorMode, - benchmark +) +from torchao.quantization.autoquant import ( + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + AQWeightOnlyQuantizedLinearWeight3 + ) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os @@ -882,6 +888,36 @@ def test_int8_weight_only_quant_subclass(self): Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype ) + def test_aq_int8_dynamic_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_2_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_3_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype + ) + def test_int4_weight_only_quant_subclass(self): self._test_lin_weight_subclass_impl( Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8] @@ -1197,19 +1233,17 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) -# TODO FINISH TEST CODE class TestAutoQuant(unittest.TestCase): - def test_auto_quant(self): + def test_autoquant(self): torch._inductor.config.epilogue_fusion = False torch._inductor.config.use_mixed_mm = True torch._inductor.config.force_fuse_int_mm_with_mul = True - torch._inductor.config.coordinate_descent_tuning = True torch._dynamo.config.automatic_dynamic_shapes = False for m,k,n in [ (1, 1024, 1024), (64, 1024, 1024), - (4096, 1024, 1024), + (2**15, 1024, 1024), (1, 1024, 4096), (64, 1024, 4096), (1, 4096, 1024), @@ -1222,7 +1256,11 @@ def test_auto_quant(self): torch.nn.Linear(k,n), torch.nn.ReLU(), ).to("cuda").to(torch.bfloat16) + out = model(example_input) do_autoquant(model, example_input) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 8392234314..60eb29127c 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -1,34 +1,34 @@ import torch - +import os +from subprocess import check_output from .subclass import ( # noqa Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, ) from torch.utils._python_dispatch import return_and_correct_aliasing -from .utils import benchmark from .quant_primitives import ( quantize_activation_per_token_absmax, safe_int_mm, ) import torch.nn.functional as F - +from torch._inductor.utils import do_bench aten = torch.ops.aten AUTOQUANT_CACHE = {} -def check_cache(cls, shape, dtype): - return AUTOQUANT_CACHE.get((cls, shape, dtype), None) +def check_cache(cls, shapes_and_dtype): + return AUTOQUANT_CACHE.get((cls,)+shapes_and_dtype, None) -def update_cache(cls, shape, dtype, res): - AUTOQUANT_CACHE[(cls, shape, dtype)] = res +def update_cache(cls, shapes_and_dtype, res): + AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res class AutoQuantizableLinearWeight(torch.Tensor): """ when run, finds best type of quantization for this tensor and swaps itself with that """ @staticmethod - def __new__(cls, weight, qtensor_class_list, *args, **kwargs): + def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): kwargs["device"] = weight.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else weight.layout @@ -40,11 +40,11 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs): shape = kwargs.pop("shape", weight.shape) return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, weight, qtensor_class_list, *args, **kwargs): + def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): self.weight = weight self.qtensor_class_list = qtensor_class_list - self.logged_shape = None - self.logged_dtype = None + self.logged_data = {} + self.mode = mode def __repr__(self): return ( @@ -54,72 +54,72 @@ def __repr__(self): @staticmethod def log_shape(act_mat, w_autoquant, bias): - orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1]) - logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape) logged_dtype = act_mat.dtype - w_autoquant.logged_shape = logged_shape - w_autoquant.logged_dtype = logged_dtype + logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,) + shapes_and_dtype = logged_shapes + (logged_dtype,) + w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0) for q_cls in w_autoquant.qtensor_class_list: - if check_cache(q_cls, logged_shape, logged_dtype) is None: - update_cache(q_cls, logged_shape, logged_dtype, None) - y = torch.mm(act_mat, w_autoquant.weight.t()) - y = y.reshape(*orig_shape[:-1], y.shape[-1]) - if bias is not None: - y += bias - return y + if check_cache(q_cls, shapes_and_dtype) is None: + update_cache(q_cls, shapes_and_dtype, None) - def tune_autoquant(self, q_cls, best_time): - act_shape, w_shape, bias_shape = self.logged_shape - if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: + def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): + act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype + if check_cache(q_cls, shapes_and_dtype) is None: with torch.no_grad(): - act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device) - bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device) - res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time) - update_cache(q_cls, self.logged_shape, self.logged_dtype, res) + act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) + bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + update_cache(q_cls, shapes_and_dtype, res) def to_quantized(self, error_on_unseen, **kwargs): - if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None): + if error_on_unseen and self.logged_data == {}: raise RuntimeError("must run module normally to get shape, dtype info for autoquant") - elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen: + elif (self.logged_data == {}) and not error_on_unseen: # default back to non-quantized weight if not seen self = AQFloatLinearWeight.from_float(self.weight) - return self + return self best_time = torch.inf best_cls = None do_print=False + # check each class for q_cls in self.qtensor_class_list: - if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None: - do_print=True - self.tune_autoquant(q_cls, best_time) - torch._dynamo.reset() - cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf) + # for each logged shape+dtype, benchmark + cls_res=0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + if check_cache(q_cls, shapes_and_dtype) is None: + do_print=True + self.tune_autoquant(q_cls, shapes_and_dtype, best_time) + torch._dynamo.reset() + cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen if best_time >= cls_res: best_time = cls_res best_cls = q_cls + # only print if this is the first time seeing some cls+shape combo, + # otherwise we will print the same thing for every layer. if do_print: - print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}") - # TODO handle random cls args/kwargs? or should they be curried + print(f"for {self.logged_data}, best_cls={best_cls}") + # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) return self def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.weight), self.qtensor_class_list, dtype=self.dtype + fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode ) def __tensor_flatten__(self): - return ["weight"], [self.qtensor_class_list, self.dtype, self.shape] + return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] @classmethod def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): weight = tensor_data_dict["weight"] - qtensor_class_list, dtype, shape = tensor_attributes[0] - return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) @classmethod - def from_float(cls, weight, qtensor_class_list): - return cls(weight, qtensor_class_list) + def from_float(cls, weight, qtensor_class_list, **kwargs): + return cls(weight, qtensor_class_list, **kwargs) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -131,8 +131,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args)>2 else None ) - return cls.log_shape(mat1, w_autoquant, bias) - + cls.log_shape(mat1, w_autoquant, bias) + return func(mat1, w_autoquant.weight, bias) try: with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -144,18 +144,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.detach.default: return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) +def do_autoquant_bench(op, *args, **kwargs): + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args) + res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + return res + +def _is_interpolate_mode(mode): + if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float): + return True + return False + class AQMixin(): """ Mixin to turn normal quantized subclasses into autoquantizable ones """ @classmethod - def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs): + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): w_qtensor = cls.from_float(weight) - q_c_op = torch.compile(cls._quantized_op, mode="max-autotune") - with torch.no_grad(): - torch.cuda.synchronize() - res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time) - print(cls, res) + if _is_interpolate_mode(mode): + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs") + else: + func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) + q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias) + if res < best_time*1.1: + res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) + res=(res2*.9+res*.1) + print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") return res class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): @@ -163,9 +192,12 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLi AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ @classmethod - def _autoquant_test(cls, act_mat, weight, bias, best_time): - # SAM best is between .51 to .60, SDXL also performs best in this range - INTERPOLATION_CONSTANT=.55 + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between .8 to 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) x_vals_int8, x_scales = quantize_activation_per_token_absmax( act_mat.reshape(-1, act_mat.shape[-1]) @@ -174,10 +206,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time): lambda x_vals_int8, x_scales, w_vals_int8: safe_int_mm(x_vals_int8, w_vals_int8) * x_scales ) - q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune") + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): - res_matmul=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time) - print(cls, "matmul", res_matmul) + res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op if res_matmul>=best_time: @@ -186,9 +218,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time): # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) res = super()._autoquant_test(act_mat, weight, bias, to_beat) - print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul) - return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul - + max_int_const_win = (best_time-res_matmul)/(res-res_matmul) + res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul + print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + return res_f class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): """ @@ -206,17 +239,17 @@ def _quantized_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2) - y = y.reshape(*orig_shape[:-1], y.shape[-1]) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales if bias is not None: y += bias return y.to(orig_dtype) @classmethod - def _autoquant_test(cls, act_mat, weight, bias, best_time): + def _autoquant_test(cls, act_mat, *args): # if act_mat has batchsize>2 don't use this kernel - if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2: + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32: return torch.inf - return super()._autoquant_test(act_mat, weight, bias, best_time) + return super()._autoquant_test(act_mat, *args) class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): def _quantized_op(act_mat, w_qtensor, bias): @@ -227,7 +260,6 @@ def _quantized_op(act_mat, w_qtensor, bias): y += bias return y - class AQFloatLinearWeight(torch.Tensor, AQMixin): """ A class to be used in concert with AutoQuantizableLinearWeight to provide a @@ -251,5 +283,6 @@ def from_float(cls, weight): AQInt8DynamicallyQuantizedLinearWeight, AQWeightOnlyQuantizedLinearWeight, AQWeightOnlyQuantizedLinearWeight2, - AQWeightOnlyQuantizedLinearWeight3, + # AQWeightOnlyQuantizedLinearWeight3, + # 3rd version gets picked in situations where it is slower for the interpolation mode ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9b1e0dc9c9..942bf043e9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -161,8 +161,14 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): def change_linears_to_autoquantizable(model, **kwargs): + """ + Converts all linear weight tensors to the + AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed + by running the model and then calling change_autoquantizable_to_quantized + """ filter_fn = kwargs.pop("filter_fn", _is_linear) kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["mode"] = kwargs.get("mode", ["relu", None]) _replace_with_custom_fn_if_matches_filter( model, _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), @@ -170,11 +176,16 @@ def change_linears_to_autoquantizable(model, **kwargs): ) def change_autoquantizable_to_quantized(model, **kwargs): + """ + Converts AutoQuantizableLinearWeight tensor subclasses + to various quantized/non-quantized tensor subclasses depending + on benchmark results. Expectation is that these modules are + torch.compiled afterwards. + """ filter_fn = kwargs.pop( "filter_fn", lambda mod, *args: - _is_linear(mod, *args) and - isinstance(mod.weight, AutoQuantizableLinearWeight) + hasattr(mod, "weight") and isinstance(mod.weight, AutoQuantizableLinearWeight) ) error_on_unseen=kwargs.pop("error_on_unseen", True) _replace_with_custom_fn_if_matches_filter( @@ -186,15 +197,20 @@ def change_autoquantizable_to_quantized(model, **kwargs): ) @torch.no_grad() -def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear): +def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): + """ + Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape + across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer + and applies that type of quantization. + """ hold = torch._dynamo.config.automatic_dynamic_shapes torch._dynamo.config.automatic_dynamic_shapes = False - change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list) + change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) if not isinstance(example_input, (tuple, list)): assert isinstance(example_input, torch.Tensor) example_input = [example_input] model(*example_input) - change_autoquantizable_to_quantized(model) + change_autoquantizable_to_quantized(model, **kwargs) torch._dynamo.config.automatic_dynamic_shapes = hold torch._dynamo.reset() return model From 29214a9efa750f2e1159bd700ecc2c226699d741 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 19 Mar 2024 15:27:37 -0700 Subject: [PATCH 7/7] Update on "Autoquant" Summary: Adding autoquantization functionality, using hte do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer. Test Plan: python test/test.py -k "autoquant" also tested on SAM and SDXL https://github.com/pytorch-labs/segment-anything-fast/pull/114 https://github.com/HDCharles/sdxl-fast/commit/8d9942ab05a552f25f5bfe09da02719ce255467f Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- README.md | 41 ++++++++++++---- __init__.py | 0 torchao/__init__.py | 24 ++++++++++ torchao/quantization/__init__.py | 2 +- torchao/quantization/autoquant.py | 78 ++++++++++++++++++++++--------- torchao/quantization/quant_api.py | 13 +++--- torchao/quantization/utils.py | 21 --------- 7 files changed, 120 insertions(+), 59 deletions(-) create mode 100644 __init__.py diff --git a/README.md b/README.md index cd34c0d8ac..45e51c828f 100644 --- a/README.md +++ b/README.md @@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible. -### A8W8 Dynamic Quantization +### Autoquantization -The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this -converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul. - -Example +The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes +of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer. ``` import torch -from torchao.quantization import quant_api +import torchao + +# inductor settings which improve torch.compile runtime for quantized modules +torch._inductor.config.force_fuse_int_mm_with_mul +torch._inductor.config.use_mixed_mm # some user model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int8_dqtensors(model) +# perform autoquantization +torchao.autoquant(model, (input)) # compile the model to improve performance model = torch.compile(model, mode='max-autotune') model(input) ``` + +### A8W8 Dynamic Quantization + +The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this +converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul. + +Example + +``` +# some user model and example input +... + +# convert linear modules to quantized linear modules +torchao.change_linear_weights_to_int8_dqtensors(model) + +# compile the model to improve performance +... +``` + This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor. @@ -81,7 +102,7 @@ Example ... # convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int8_woqtensors(model) +torchao.change_linear_weights_to_int8_woqtensors(model) # compile the model to improve performance ... @@ -102,7 +123,7 @@ Example ... # convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int4_woqtensors(model) +torchao.change_linear_weights_to_int4_woqtensors(model) # compile the model to improve performance ... diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/__init__.py b/torchao/__init__.py index e69de29bb2..c2634c5365 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -0,0 +1,24 @@ +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + change_linear_weights_to_int4_woqtensors, + swap_conv2d_1x1_to_linear, + autoquant, + change_linears_to_autoquantizable, + change_autoquantizable_to_quantized, +) + +__all__ = [ + "apply_weight_only_int8_quant", + "apply_dynamic_quant", + "change_linear_weights_to_int8_dqtensors", + "change_linear_weights_to_int8_woqtensors", + "change_linear_weights_to_int4_woqtensors", + "swap_conv2d_1x1_to_linear" + "safe_int_mm", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", +] diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 525008a77d..1b421ab8e4 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -25,7 +25,7 @@ "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", - "do_autoquant", + "autoquant", "change_linears_to_autoquantizable", "change_autoquantizable_to_quantized", "quant_int8_dynamic_linear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 60eb29127c..f05958c84c 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -1,6 +1,4 @@ import torch -import os -from subprocess import check_output from .subclass import ( # noqa Int8DynamicallyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, @@ -79,26 +77,56 @@ def to_quantized(self, error_on_unseen, **kwargs): # default back to non-quantized weight if not seen self = AQFloatLinearWeight.from_float(self.weight) return self + + + # only want to do shape+final print a single time if multiple layers + # see/have same shapes so we gate on check_cache being empty for + # at least one of the class/shape combinations. + do_final_print = False + print_once = True + + def count_shapes(self, do_print=True): + differe_shape_count=0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + differe_shape_count += 1 + if do_print: + act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype + print(f"activation_shapes: {act_shape}, times_seen: {times_seen}") + if do_print: + print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}") + return differe_shape_count + + # check each class best_time = torch.inf best_cls = None - do_print=False - # check each class for q_cls in self.qtensor_class_list: # for each logged shape+dtype, benchmark - cls_res=0 + cur_time=0 + shape_count = count_shapes(self, do_print=False) for shapes_and_dtype, times_seen in self.logged_data.items(): if check_cache(q_cls, shapes_and_dtype) is None: - do_print=True - self.tune_autoquant(q_cls, shapes_and_dtype, best_time) + # only do final print if we have to autotune at least one cls/shape pair + do_final_print=True + + # only print shapes once + if print_once == True: + print_once = False + count_shapes(self, do_print=True) + + time_for_best_shape = check_cache(best_cls, shapes_and_dtype) + time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape + self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) torch._dynamo.reset() - cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen - if best_time >= cls_res: - best_time = cls_res + cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen + if shape_count is not None and shape_count > 1: + print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") + if best_time >= cur_time: + best_time = cur_time best_cls = q_cls # only print if this is the first time seeing some cls+shape combo, # otherwise we will print the same thing for every layer. - if do_print: - print(f"for {self.logged_data}, best_cls={best_cls}") + if do_final_print: + print(f"best_cls={best_cls}\n") # TODO handle random cls args/kwargs? or should they be curried? self = best_cls.from_float(self.weight) return self @@ -145,6 +173,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ rep = kwargs.pop("rep", 100) warmup = kwargs.pop("warmup", 25) with torch.no_grad(): @@ -152,14 +183,14 @@ def do_autoquant_bench(op, *args, **kwargs): stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): - op(*args) + op(*args, **kwargs) stream.synchronize() torch.cuda.current_stream().wait_stream(stream) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - op(*args) + op(*args, **kwargs) res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") return res @@ -180,11 +211,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): else: func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") - res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias) + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) if res < best_time*1.1: res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) res=(res2*.9+res*.1) - print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") return res class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): @@ -196,7 +227,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): if not _is_interpolate_mode(mode): return super()._autoquant_test(act_mat, weight, bias, best_time, mode) - # SAM best is between .8 to 1, SDXL also performs best in this range + # SAM best is between .8 and 1, SDXL also performs best in this range INTERPOLATION_CONSTANT = mode[1] w_qtensor = cls.from_float(weight) x_vals_int8, x_scales = quantize_activation_per_token_absmax( @@ -209,7 +240,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) - print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op if res_matmul>=best_time: @@ -220,7 +251,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): res = super()._autoquant_test(act_mat, weight, bias, to_beat) max_int_const_win = (best_time-res_matmul)/(res-res_matmul) res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul - print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") return res_f class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): @@ -252,6 +283,10 @@ def _autoquant_test(cls, act_mat, *args): return super()._autoquant_test(act_mat, *args) class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ def _quantized_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) @@ -265,7 +300,8 @@ class AQFloatLinearWeight(torch.Tensor, AQMixin): A class to be used in concert with AutoQuantizableLinearWeight to provide a default/non-quantized option. Only implements the bare minimum needed to work with the AutoQuantizableLinearWeight class using the same interfaces that would normally be - used by QTensor subclasses but for a default linear op instead. + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. """ def __init__(self): super().__init__() @@ -284,5 +320,5 @@ def from_float(cls, weight): AQWeightOnlyQuantizedLinearWeight, AQWeightOnlyQuantizedLinearWeight2, # AQWeightOnlyQuantizedLinearWeight3, - # 3rd version gets picked in situations where it is slower for the interpolation mode + # TODO this gets picked in places where it makes perf worse, why? ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 942bf043e9..06ffe21dcb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -37,7 +37,7 @@ "change_linear_weights_to_int8_woqtensors", "change_linear_weights_to_int4_woqtensors", "swap_conv2d_1x1_to_linear", - "do_autoquant", + "autoquant", "change_linears_to_autoquantizable", "change_autoquantizable_to_quantized", ] @@ -182,6 +182,9 @@ def change_autoquantizable_to_quantized(model, **kwargs): on benchmark results. Expectation is that these modules are torch.compiled afterwards. """ + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + filter_fn = kwargs.pop( "filter_fn", lambda mod, *args: @@ -195,24 +198,22 @@ def change_autoquantizable_to_quantized(model, **kwargs): ), filter_fn, ) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() @torch.no_grad() -def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): +def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): """ Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer and applies that type of quantization. """ - hold = torch._dynamo.config.automatic_dynamic_shapes - torch._dynamo.config.automatic_dynamic_shapes = False change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) if not isinstance(example_input, (tuple, list)): assert isinstance(example_input, torch.Tensor) example_input = [example_input] model(*example_input) change_autoquantizable_to_quantized(model, **kwargs) - torch._dynamo.config.automatic_dynamic_shapes = hold - torch._dynamo.reset() return model def swap_conv2d_1x1_to_linear(model, filter_fn=None): diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index e973ab8ca9..73621e6297 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -7,7 +7,6 @@ import torch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.benchmark import Timer __all__ = [ "find_multiple", @@ -87,23 +86,3 @@ def get_model_size_in_bytes(model): for b in model.buffers(): s += b.nelement() * b.element_size() return s - - -def benchmark(f, *args, **kwargs): - if "best_time" in kwargs: - best_time = kwargs.pop("best_time") - else: - best_time = torch.inf - t0 = Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - - # warmup - t0.timeit(10) - res=t0.adaptive_autorange(min_run_time=.1) - # run more if median vs median minus iqr (interpolated based on number of runs left) is lower than best_time, - # stop if good res.iqr/res.median or have 20 samples - while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20): - res2 = t0.adaptive_autorange(min_run_time=.5) - res=res.merge([res2, res])[0] - return res.median * 1e3