Autoquant #38

Merged · 13 commits · Mar 25, 2024
Changes from 5 commits
28 changes: 28 additions & 0 deletions test/test.py
@@ -24,6 +24,7 @@
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
_replace_with_custom_fn_if_matches_filter,
do_autoquant
)
from torchao.quantization.quant_primitives import (
dequantize_per_channel,
@@ -53,6 +54,7 @@
compute_error as SQNR,
_fqn_to_op_to_shape_to_count,
LoggingTensorMode,
benchmark
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -1195,6 +1197,32 @@ def test_on_dummy_distilbert(self):
print("sqnr_pt_quant", sqnr_pt_quant)
self.assertTrue(sqnr_sq >= 8.0)

# TODO FINISH TEST CODE
class TestAutoQuant(unittest.TestCase):
def test_auto_quant(self):
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.coordinate_descent_tuning = True
torch._dynamo.config.automatic_dynamic_shapes = False

for m,k,n in [
(1, 1024, 1024),
(64, 1024, 1024),
(4096, 1024, 1024),
(1, 1024, 4096),
(64, 1024, 4096),
(1, 4096, 1024),
(64, 4096, 1024),
(4096, 4096, 1024),
]:
example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
do_autoquant(model, example_input)

if __name__ == "__main__":
unittest.main()
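
TestAutoQuant above is still marked "TODO FINISH TEST CODE" and only verifies that do_autoquant runs to completion. A minimal sanity check that could be appended inside the loop body is sketched here (an illustration, not part of this diff):

    # hypothetical follow-up assertion: after autoquant the swapped model should
    # still run in eager mode and keep the original output shape
    out = model(example_input)
    self.assertEqual(out.shape, (m, n))
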
3 changes: 3 additions & 0 deletions torchao/quantization/__init__.py
@@ -25,6 +25,9 @@
"dynamically_quantize_per_channel",
"dequantize_per_tensor",
"dequantize_per_channel",
"do_autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
"quant_int8_dynamic_linear",
"quant_int8_matmul",
"quant_int8_dynamic_per_token_linear",
255 changes: 255 additions & 0 deletions torchao/quantization/autoquant.py
@@ -0,0 +1,255 @@
import torch

from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .utils import benchmark
from .quant_primitives import (
quantize_activation_per_token_absmax,
safe_int_mm,
)
import torch.nn.functional as F

aten = torch.ops.aten

AUTOQUANT_CACHE = {}

def check_cache(cls, shape, dtype):
return AUTOQUANT_CACHE.get((cls, shape, dtype), None)

def update_cache(cls, shape, dtype, res):
AUTOQUANT_CACHE[(cls, shape, dtype)] = res

class AutoQuantizableLinearWeight(torch.Tensor):
"""
when run, finds best type of quantization for this tensor and swaps itself with that
"""
@staticmethod
def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
kwargs["device"] = weight.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
)
kwargs["dtype"] = (
kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
)
kwargs["requires_grad"] = False
shape = kwargs.pop("shape", weight.shape)
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, weight, qtensor_class_list, *args, **kwargs):
self.weight = weight
self.qtensor_class_list = qtensor_class_list
self.logged_shape = None
self.logged_dtype = None

def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
)

@staticmethod
def log_shape(act_mat, w_autoquant, bias):
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1])
logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
logged_dtype = act_mat.dtype
w_autoquant.logged_shape = logged_shape
w_autoquant.logged_dtype = logged_dtype
for q_cls in w_autoquant.qtensor_class_list:
if check_cache(q_cls, logged_shape, logged_dtype) is None:
update_cache(q_cls, logged_shape, logged_dtype, None)
y = torch.mm(act_mat, w_autoquant.weight.t())
y = y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y

def tune_autoquant(self, q_cls, best_time):
act_shape, w_shape, bias_shape = self.logged_shape
if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
with torch.no_grad():
act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

def to_quantized(self, error_on_unseen, **kwargs):
if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
# default back to non-quantized weight if not seen
self = AQFloatLinearWeight.from_float(self.weight)
return self
best_time = torch.inf
best_cls = None
do_print=False
for q_cls in self.qtensor_class_list:
if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
do_print=True
self.tune_autoquant(q_cls, best_time)
torch._dynamo.reset()
cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
if best_time >= cls_res:
best_time = cls_res
best_cls = q_cls
if do_print:
print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
# TODO handle random cls args/kwargs? or should they be curried
self = best_cls.from_float(self.weight)
return self

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.weight), self.qtensor_class_list, dtype=self.dtype
)

def __tensor_flatten__(self):
return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]

@classmethod
def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
weight = tensor_data_dict["weight"]
qtensor_class_list, dtype, shape = tensor_attributes[0]
return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

@classmethod
def from_float(cls, weight, qtensor_class_list):
return cls(weight, qtensor_class_list)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func is torch.nn.functional.linear:
mat1, w_autoquant, bias = (
args[0],
args[1],
args[2] if len(args)>2 else None
)
return cls.log_shape(mat1, w_autoquant, bias)

try:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
except:
print(f"ERR: subclass doesn't implement {func}")

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.detach.default:
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

class AQMixin():
"""
Mixin to turn normal quantized subclasses into autoquantizable ones
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
w_qtensor = cls.from_float(weight)
q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
with torch.no_grad():
torch.cuda.synchronize()
res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
print(cls, res)
return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
"""
AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
"""
@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time):
# SAM best is between .51 to .60, SDXL also performs best in this range
INTERPOLATION_CONSTANT=.55
w_qtensor = cls.from_float(weight)
x_vals_int8, x_scales = quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1])
)
quantized_matmul = (
lambda x_vals_int8, x_scales, w_vals_int8:
safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune")
with torch.no_grad():
res_matmul=benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
print(cls, "matmul", res_matmul)

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
if res_matmul>=best_time:
return res_matmul

# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
res = super()._autoquant_test(act_mat, weight, bias, to_beat)
print(cls, "full", INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul)
return INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul


class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
"""

class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
orig_dtype = act_mat.dtype
orig_shape = act_mat.shape
act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
y = y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y.to(orig_dtype)

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time):
# if act_mat has batchsize>2 don't use this kernel
if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2:
return torch.inf
return super()._autoquant_test(act_mat, weight, bias, best_time)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
def _quantized_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
y=y.reshape(*orig_shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y


class AQFloatLinearWeight(torch.Tensor, AQMixin):
"""
A class to be used in concert with AutoQuantizableLinearWeight to provide a
default/non-quantized option. Only implements the bare minimum needed to work with the
AutoQuantizableLinearWeight class using the same interfaces that would normally be
used by QTensor subclasses but for a default linear op instead.
"""
def __init__(self):
super().__init__()

@staticmethod
def _quantized_op(act_mat, w_qtensor, bias):
return torch.nn.functional.linear(act_mat, w_qtensor, bias)

@classmethod
def from_float(cls, weight):
return weight

DEFAULT_CLASS_LIST = [
AQFloatLinearWeight,
AQInt8DynamicallyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
AQWeightOnlyQuantizedLinearWeight3,
]
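
The file above works in two phases: AutoQuantizableLinearWeight first logs the activation/weight/bias shapes and dtype it sees, and to_quantized later benchmarks every entry in its qtensor_class_list for those shapes (memoizing results in AUTOQUANT_CACHE) before swapping in the fastest one. A minimal sketch of that flow using the helpers added to quant_api.py below (CUDA and bfloat16 assumed; do_autoquant wraps the same steps and additionally toggles torch._dynamo.config.automatic_dynamic_shapes off while it runs):

    import torch
    from torchao.quantization import (
        change_linears_to_autoquantizable,
        change_autoquantizable_to_quantized,
    )

    model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).to("cuda").to(torch.bfloat16)
    x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    # phase 1: wrap each Linear weight in AutoQuantizableLinearWeight, then run the
    # model once so log_shape records the shapes/dtype for every interposed weight
    change_linears_to_autoquantizable(model)
    model(x)

    # phase 2: to_quantized benchmarks each candidate in DEFAULT_CLASS_LIST for the
    # logged shapes and replaces the weight with the fastest subclass (or with the
    # non-quantized fallback AQFloatLinearWeight if that wins)
    change_autoquantizable_to_quantized(model)
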
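AQInt8DynamicallyQuantizedLinearWeight._autoquant_test blends two measurements through INTERPOLATION_CONSTANT rather than timing the full op alone. A worked example of the arithmetic the method performs, with illustrative numbers only (units are whatever benchmark returns):

    INTERPOLATION_CONSTANT = 0.55   # value used in the class above
    best_time = 1.00                # fastest result from previously tested classes
    res_matmul = 0.40               # measured time of the bare int8 matmul kernel

    # threshold the full dynamic-quant op must beat for this class to stay competitive
    to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
    # 1.00 + (0.55 / 0.45) * 0.60 ≈ 1.73

    res = 1.20                      # suppose the compiled full op measures 1.20
    final = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
    # 0.55 * 1.20 + 0.45 * 0.40 = 0.84, the value cached and reported for this class
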
50 changes: 48 additions & 2 deletions torchao/quantization/quant_api.py
@@ -28,14 +28,18 @@
from .weight_only import (
WeightOnlyInt8QuantLinear,
)
from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST

__all__ = [
"apply_weight_only_int8_quant",
"apply_dynamic_quant",
"change_linear_weights_to_int8_dqtensors",
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear"
"swap_conv2d_1x1_to_linear",
"do_autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
]


@@ -95,9 +99,11 @@ def apply_dynamic_quant(model, filter_fn=None):


def _get_subclass_inserter(cls, **kwargs):
method = kwargs.pop("method", "from_float")
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(
cls.from_float(lin.weight, **kwargs), requires_grad=False
# cls.from_float(...)
getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
)
return lin

@@ -153,6 +159,46 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
filter_fn,
)


def change_linears_to_autoquantizable(model, **kwargs):
filter_fn = kwargs.pop("filter_fn", _is_linear)
kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
filter_fn if filter_fn is not None else _is_linear,
)

def change_autoquantizable_to_quantized(model, **kwargs):
filter_fn = kwargs.pop(
"filter_fn",
lambda mod, *args:
_is_linear(mod, *args) and
isinstance(mod.weight, AutoQuantizableLinearWeight)
)
error_on_unseen=kwargs.pop("error_on_unseen", True)
_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs
),
filter_fn,
)

@torch.no_grad()
def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
hold = torch._dynamo.config.automatic_dynamic_shapes
torch._dynamo.config.automatic_dynamic_shapes = False
change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
if not isinstance(example_input, (tuple, list)):
assert isinstance(example_input, torch.Tensor)
example_input = [example_input]
model(*example_input)
change_autoquantizable_to_quantized(model)
torch._dynamo.config.automatic_dynamic_shapes = hold
torch._dynamo.reset()
return model

def swap_conv2d_1x1_to_linear(model, filter_fn=None):
"""
Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
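
The new public entry point is do_autoquant, which chains change_linears_to_autoquantizable and change_autoquantizable_to_quantized around a single forward pass on the example input. A usage sketch (CUDA and bfloat16 assumed, mirroring the new TestAutoQuant case; compiling afterwards is optional):

    import torch
    from torchao.quantization import do_autoquant

    model = torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 4096),
        torch.nn.ReLU(),
    ).to("cuda").to(torch.bfloat16)
    example_input = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)

    # benchmarks every candidate subclass for the shapes seen with example_input and
    # leaves the fastest quantized (or float fallback) weight in each Linear module
    model = do_autoquant(model, example_input)

    # the returned model runs in eager mode and can be compiled as usual
    model = torch.compile(model, mode="max-autotune")
    out = model(example_input)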