From de51c5ffdc2c6036b7d0fae7f56354f727dab65a Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Wed, 15 Jan 2025 18:39:15 +0000 Subject: [PATCH 1/4] [Feat]: Add support for kleidiai quantization schemes Description: Allow int8_dynamic_activation_intx_weight to work with aten _dyn_quant_matmul_4bit op Needs : pytorch/pytorch#134124 or Pytorch > 2.6.0 Signed-off-by: Nikhil Gupta --- torchao/dtypes/affine_quantized_tensor.py | 7 +- ...8_dynamic_activation_intx_weight_layout.py | 66 +++++++++++++++++-- torchao/experimental/quant_api.py | 53 +++++++++++---- torchao/quantization/quant_api.py | 7 +- 4 files changed, 112 insertions(+), 21 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 715aaeb9ec..15ab6e54ed 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -204,6 +204,7 @@ def from_hp_to_intx( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, + bias: Optional[torch.Tensor] = None ): """Convert a high precision tensor to an integer affine quantized tensor.""" original_shape = input_float.shape @@ -276,7 +277,11 @@ def from_hp_to_intx( data = _layout.post_process(data) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) + args = [data, scale, zero_point, _layout] + # Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias + if bias is not None: + args.append(bias) + tensor_impl = tensor_impl_ctr(*args) return cls( tensor_impl, block_size, diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 7b2b1da145..fd1d4b7c69 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import logging +from enum import Enum, auto from typing import Optional, Tuple import torch @@ -20,6 +21,9 @@ from torchao.quantization.quant_primitives import ( ZeroPointDomain, ) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -31,17 +35,33 @@ handler.setFormatter(formatter) logger.addHandler(handler) +class Target(Enum): + """Enum that indicates the backend target""" + + NATIVE = auto() + ATEN = auto() + +def target_from_str(target: str) -> Target: + if target.lower() == "native": + return Target.NATIVE + elif target.lower() == "aten": + return Target.ATEN + else: + raise ValueError(f"Invalid target: {target}") class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): bit_width: Optional[int] group_size: Optional[int] has_weight_zeros: Optional[bool] + # The target platform for the layout, 'native' or 'aten' + target: Optional[Target] def __init__( self, bit_width: Optional[int] = None, group_size: Optional[int] = None, has_weight_zeros: Optional[bool] = None, + target: Optional[str] = "native", ): if bit_width is not None: assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" @@ -51,6 +71,7 @@ def __init__( self.bit_width = bit_width self.group_size = group_size self.has_weight_zeros = has_weight_zeros + self.target = target_from_str(target) if not self.has_params_set(): assert ( @@ -60,13 +81,14 @@ def __init__( ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" def extra_repr(self): - return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}" + return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}" def has_params_set(self) -> bool: return ( (self.bit_width is not None) and (self.group_size is not None) and (self.has_weight_zeros is not None) + and (self.target is not None) ) @@ -125,9 +147,11 @@ def from_plain( scale: torch.Tensor, zero_point: Optional[torch.Tensor], layout: Layout, + bias: Optional[torch.Tensor] = None, ): assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" + assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}" # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor # when AOTI supports int @@ -136,6 +160,13 @@ def from_plain( n_tensor = torch.empty(0, n, dtype=torch.int8) k_tensor = torch.empty(0, k, dtype=torch.int8) + if layout.target == Target.ATEN: + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + int_data = int_data.add(8) + int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8) + packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n) + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + if layout.has_weight_zeros: args = [ int_data.to(torch.int8), @@ -211,16 +242,13 @@ def __tensor_unflatten__( def _linear_check(input_tensor, weight_tensor, bias): layout = weight_tensor.tensor_impl.get_layout() return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( - bias is None + bias is None or layout.target == Target.ATEN # Aten target allows bias ) def _linear_impl(input_tensor, weight_tensor, bias): - assert ( - bias is None - ), "bias in linear is 
not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" - def _impl_2d(input_tensor, weight_tensor): + def _impl_2d_native(input_tensor, weight_tensor): assert input_tensor.dim() == 2 assert weight_tensor.dim() == 2 @@ -255,6 +283,31 @@ def _impl_2d(input_tensor, weight_tensor): torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight" )(*args) + def _impl_2d_aten(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + packed_weight = weight_tensor.tensor_impl.packed_weight + return torch.ops.aten._dyn_quant_matmul_4bit( + input_tensor, packed_weight, group_size, k_, n) + + target = weight_tensor.tensor_impl.get_layout().target + + if target == Target.ATEN: + assert ( + TORCH_VERSION_AT_LEAST_2_6 == 1 + ), "Target.ATEN requires torch >= 2.6.0" + _impl_2d = _impl_2d_aten + elif target == Target.NATIVE: + _impl_2d = _impl_2d_native + assert ( + bias is None + ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' " + if input_tensor.dim() == 2: return _impl_2d(input_tensor, weight_tensor) @@ -268,7 +321,6 @@ def _impl_2d(input_tensor, weight_tensor): res = res.reshape(*lead_shape, m, n) return res - register_aqt_quantized_linear_dispatch( _linear_check, _linear_impl, diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 4e0906d0a0..57b2f66089 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import sys import logging from typing import Optional, Union @@ -18,14 +19,18 @@ PerGroup, PerRow, ) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) +from torchao.dtypes import PlainLayout logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -import sys handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -506,6 +511,7 @@ def int8_dynamic_activation_intx_weight( weight_dtype: torch.dtype = torch.int4, granularity: Union[PerRow, PerGroup] = PerGroup(128), has_weight_zeros: bool = False, + target: str = "native", weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow @@ -531,13 +537,28 @@ def int8_dynamic_activation_intx_weight( - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32) - act_mapping_type must be MappingType.ASYMMETRIC """ - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except AttributeError: - raise Exception( - "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." - + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." 
- ) + + if target == "aten": + if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \ + weight_dtype != torch.int4 or \ + has_weight_zeros != True or \ + weight_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError( + f"target 'aten' requires:\n" + f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" + f"- has_weight_zeros to be True,\n" + f"- weight_dtype to be torch.int4,\n" + f"- weight_mapping_type to be MappingType.SYMMETRIC" + ) + elif not isinstance(layout, PlainLayout): + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." + + " You can also set target to 'aten' if you are using ARM CPU." + + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." + ) dtype_to_bit_width = { torch.int1: 1, @@ -556,7 +577,7 @@ def int8_dynamic_activation_intx_weight( bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout - def apply(weight): + def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): group_size = granularity.group_size elif isinstance(granularity, PerRow): @@ -569,6 +590,7 @@ def apply(weight): assert weight.shape[-1] % group_size == 0 layout = layout_arg + scale_dtype = None if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): assert ( weight.device == torch.device("cpu") @@ -584,7 +606,13 @@ def apply(weight): bit_width=bit_width, group_size=group_size, has_weight_zeros=has_weight_zeros, + target=target, ) + if target == "aten": + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + if torch.backends.kleidiai.is_available(): + if isinstance(granularity, PerGroup): + scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype quant_min = -(1 << (bit_width - 1)) quant_max = (1 << (bit_width - 1)) - 1 @@ -596,12 +624,14 @@ def apply(weight): quant_min=quant_min, quant_max=quant_max, eps=torch.finfo(torch.float32).eps, + scale_dtype=scale_dtype, zero_point_dtype=torch.int8, preserve_zero=has_weight_zeros, zero_point_domain=ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, _layout=layout, + bias=bias ) # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused @@ -620,7 +650,8 @@ def apply(weight): weight = to_linear_activation_quantized(weight, activation_quant_func) return weight - return _get_linear_subclass_inserter(apply) + propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten" + return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias) class UIntxWeightOnlyQuantizedLinear(nn.Module): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 02af4ced91..b1f6d7e2bb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -450,15 +450,18 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs): +def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs): """Helper function to apply the constructor that 
quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ def insert_subclass(lin): requires_grad = allow_requires_grad and lin.weight.requires_grad + args = [lin.weight] + if propagate_bias == True: + args.append(lin.bias) lin.weight = torch.nn.Parameter( - constructor(lin.weight, **kwargs), requires_grad=requires_grad + constructor(*args, **kwargs), requires_grad=requires_grad ) lin.extra_repr = types.MethodType(_linear_extra_repr, lin) return lin From 8fdf6a944590b4713ac43a238e7cc62eb9e7a06f Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Fri, 17 Jan 2025 13:14:17 +0000 Subject: [PATCH 2/4] [Fix]: Use Custom Tensor for dyn_quant_matmul_4bit aten op Signed-off-by: Nikhil Gupta --- torchao/dtypes/affine_quantized_tensor.py | 7 +- ...8_dynamic_activation_intx_weight_layout.py | 81 +++++++++++++++++++ torchao/experimental/quant_api.py | 44 +++++----- torchao/quantization/quant_api.py | 5 +- 4 files changed, 107 insertions(+), 30 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 15ab6e54ed..715aaeb9ec 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -204,7 +204,6 @@ def from_hp_to_intx( zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, - bias: Optional[torch.Tensor] = None ): """Convert a high precision tensor to an integer affine quantized tensor.""" original_shape = input_float.shape @@ -277,11 +276,7 @@ def from_hp_to_intx( data = _layout.post_process(data) tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - args = [data, scale, zero_point, _layout] - # Only PackedLinearInt8DynamicActivationIntxWeightLayout() with "aten" target supports bias - if bias is not None: - args.append(bias) - tensor_impl = tensor_impl_ctr(*args) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) return cls( tensor_impl, block_size, diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index fd1d4b7c69..9d42596793 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -12,6 +12,8 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.affine_quantized_tensor_ops import ( @@ -20,7 +22,11 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, + MappingType, + choose_qparams_affine, + quantize_affine, ) + from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, ) @@ -325,3 +331,78 @@ def _impl_2d_aten(input_tensor, weight_tensor): _linear_check, _linear_impl, ) + + +class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): + """ + PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class. 
+ """ + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(), + use_hqq: bool = False, + bias: Optional[torch.Tensor] = None + ): + assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" + assert isinstance( + _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" + assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None + # TODO should probably consolidate ZeroPointDomain.NONE and None + if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + zero_point = None + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + # Note: output will be uint8 tensor for sub byte tensors for now + + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + +to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 57b2f66089..8c63874dc0 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -494,6 +494,7 @@ def quantize(self, model: nn.Module) -> nn.Module: from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( PackedLinearInt8DynamicActivationIntxWeightLayout, + to_packedlinearint8dynamicactivationintxweight_quantized_intx, ) from torchao.quantization.linear_activation_quantized_tensor import ( to_linear_activation_quantized, @@ -576,6 +577,7 @@ def int8_dynamic_activation_intx_weight( ) bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout + propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten" def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): @@ -591,6 +593,10 @@ def apply(weight, bias: Optional[torch.Tensor] = None): layout = layout_arg scale_dtype = None + tensor_quantizer = to_affine_quantized_intx + quant_min = -(1 << (bit_width - 1)) + quant_max = (1 << (bit_width - 1)) - 1 + if isinstance(layout, 
PackedLinearInt8DynamicActivationIntxWeightLayout): assert ( weight.device == torch.device("cpu") @@ -613,26 +619,23 @@ def apply(weight, bias: Optional[torch.Tensor] = None): if torch.backends.kleidiai.is_available(): if isinstance(granularity, PerGroup): scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype - - quant_min = -(1 << (bit_width - 1)) - quant_max = (1 << (bit_width - 1)) - 1 - weight = to_affine_quantized_intx( - weight, - mapping_type=weight_mapping_type, - block_size=(1, group_size), - target_dtype=torch.int32, - quant_min=quant_min, - quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, - scale_dtype=scale_dtype, - zero_point_dtype=torch.int8, - preserve_zero=has_weight_zeros, - zero_point_domain=ZeroPointDomain.INT - if has_weight_zeros - else ZeroPointDomain.NONE, - _layout=layout, - bias=bias - ) + tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx + + quantizer_args = [weight, + weight_mapping_type, + (1, group_size), + torch.int32, + quant_min, + quant_max, + torch.finfo(torch.float32).eps, + scale_dtype, + torch.int8, + has_weight_zeros, + ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, + layout, + False] + ([bias] if propagate_bias else []) + + weight = tensor_quantizer(*quantizer_args) # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused # with the kernel and it should not be applied separately @@ -650,7 +653,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None): weight = to_linear_activation_quantized(weight, activation_quant_func) return weight - propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten" return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b1f6d7e2bb..bbe9b1cb6b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -457,11 +457,10 @@ def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, pro def insert_subclass(lin): requires_grad = allow_requires_grad and lin.weight.requires_grad - args = [lin.weight] if propagate_bias == True: - args.append(lin.bias) + kwargs["bias"] = lin.bias lin.weight = torch.nn.Parameter( - constructor(*args, **kwargs), requires_grad=requires_grad + constructor(lin.weight, **kwargs), requires_grad=requires_grad ) lin.extra_repr = types.MethodType(_linear_extra_repr, lin) return lin From ca8a5f105021bb6ddbd2f6c980c58dcd89097e9d Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Mon, 20 Jan 2025 13:08:15 +0000 Subject: [PATCH 3/4] [Refactor]: Move target attribute to Layout Class & fix target checks Signed-off-by: Nikhil Gupta --- torchao/experimental/quant_api.py | 43 ++++++++++++++++++------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 8c63874dc0..48026c9489 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -495,6 +495,7 @@ def quantize(self, model: nn.Module) -> nn.Module: from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( PackedLinearInt8DynamicActivationIntxWeightLayout, to_packedlinearint8dynamicactivationintxweight_quantized_intx, + Target, ) from torchao.quantization.linear_activation_quantized_tensor import ( to_linear_activation_quantized, @@ -512,10 +513,9 @@ def 
int8_dynamic_activation_intx_weight( weight_dtype: torch.dtype = torch.int4, granularity: Union[PerRow, PerGroup] = PerGroup(128), has_weight_zeros: bool = False, - target: str = "native", weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow ): """ Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. @@ -539,19 +539,16 @@ def int8_dynamic_activation_intx_weight( - act_mapping_type must be MappingType.ASYMMETRIC """ - if target == "aten": - if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \ - weight_dtype != torch.int4 or \ - has_weight_zeros != True or \ - weight_mapping_type != MappingType.SYMMETRIC: - raise NotImplementedError( - f"target 'aten' requires:\n" - f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" - f"- has_weight_zeros to be True,\n" - f"- weight_dtype to be torch.int4,\n" - f"- weight_mapping_type to be MappingType.SYMMETRIC" + def is_torchao_op_skippable(layout): + return ( + isinstance(layout, PlainLayout) or + ( + isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and + layout.target == Target.ATEN ) - elif not isinstance(layout, PlainLayout): + ) + + if not is_torchao_op_skippable(layout): try: torch.ops.torchao._pack_8bit_act_4bit_weight except AttributeError: @@ -577,7 +574,7 @@ def int8_dynamic_activation_intx_weight( ) bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout - propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten" + propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): @@ -612,13 +609,23 @@ def apply(weight, bias: Optional[torch.Tensor] = None): bit_width=bit_width, group_size=group_size, has_weight_zeros=has_weight_zeros, - target=target, + target="aten" if layout.target == Target.ATEN else "native", ) - if target == "aten": + if layout.target == Target.ATEN: + if weight_dtype != torch.int4 or \ + has_weight_zeros != True or \ + weight_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError( + f"target 'aten' requires:\n" + f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" + f"- has_weight_zeros to be True,\n" + f"- weight_dtype to be torch.int4,\n" + f"- weight_mapping_type to be MappingType.SYMMETRIC" + ) assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" if torch.backends.kleidiai.is_available(): if isinstance(granularity, PerGroup): - scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype + scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx quantizer_args = [weight, From eb61a2456537f3cb8d67fd3a913af3d2c10b8869 Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Thu, 30 Jan 2025 15:53:12 +0000 Subject: [PATCH 4/4] [Fix]: Enable SYMMETRIC_NO_CLIPPING_ERR Mapping type and tests Signed-off-by: Nikhil Gupta --- torchao/experimental/docs/readme.md | 31 +++++++ torchao/experimental/quant_api.py | 4 +- ...tivation_intx_weight_layout_target_aten.py | 84 
+++++++++++++++++++
 3 files changed, 117 insertions(+), 2 deletions(-)
 create mode 100644 torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py

diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md
index 7f0970f792..a178c9b328 100644
--- a/torchao/experimental/docs/readme.md
+++ b/torchao/experimental/docs/readme.md
@@ -98,6 +98,37 @@ quantize_(
 )
 ```
 
+KleidiAI int4 kernels can be used on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.quantization.quant_primitives import MappingType
+
+my_model = Model()
+
+quantize_(
+    my_model,
+    int8_dynamic_activation_intx_weight(
+        weight_dtype=torch.int4,
+        granularity=PerGroup(32), # PerRow() is also supported
+        has_weight_zeros=True, # must be True for the aten target
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR, # MappingType.SYMMETRIC also works but increases error
+        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
+    ),
+)
+```
+
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
 for a working example.
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index 48026c9489..e77d09d98b 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -614,13 +614,13 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
         if layout.target == Target.ATEN:
             if weight_dtype != torch.int4 or \
                 has_weight_zeros != True or \
-                weight_mapping_type != MappingType.SYMMETRIC:
+                weight_mapping_type == MappingType.ASYMMETRIC:
                 raise NotImplementedError(
                     f"target 'aten' requires:\n"
                     f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
                     f"- has_weight_zeros to be True,\n"
                     f"- weight_dtype to be torch.int4,\n"
-                    f"- weight_mapping_type to be MappingType.SYMMETRIC"
+                    f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
                 )
             assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
             if torch.backends.kleidiai.is_available():
diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py
new file mode 100644
index 0000000000..c1c5ed771e
--- /dev/null
+++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.utils import unwrap_tensor_subclass
+from torchao.quantization.quant_primitives import MappingType
+
+
+class TestPackedLinearInt8DynamicActivationIntxWeightLayoutAten(unittest.TestCase):
+    def test_accuracy(self):
+        """
+        Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten")
+        by comparing its results to those of a reference model that uses PlainLayout().
+        """
+        granularities = [PerRow()]
+        m = 32
+        n = 128
+        k = 256
+        activations = torch.randn(m, k)
+        weight_mapping_type = MappingType.SYMMETRIC_NO_CLIPPING_ERR
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for weight_dtype in [
+            torch.int4,
+        ]:
+            for has_weight_zeros in [True]:
+                for granularity in granularities:
+                    print(
+                        f"Testing weight_dtype={weight_dtype}, "
+                        f"has_weight_zeros={has_weight_zeros}, granularity={granularity}"
+                    )
+                    quantized_model = copy.deepcopy(model)
+                    quantize_(
+                        quantized_model,
+                        int8_dynamic_activation_intx_weight(
+                            weight_dtype=weight_dtype,
+                            granularity=granularity,
+                            has_weight_zeros=has_weight_zeros,
+                            weight_mapping_type=weight_mapping_type,
+                            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
+                                target="aten"),
+                        ),
+                    )
+
+                    quantized_model_reference = copy.deepcopy(model)
+                    quantize_(
+                        quantized_model_reference,
+                        int8_dynamic_activation_intx_weight(
+                            weight_dtype=weight_dtype,
+                            granularity=granularity,
+                            has_weight_zeros=has_weight_zeros,
+                            layout=PlainLayout(),
+                        ),
+                    )
+
+                    with torch.no_grad():
+                        res = quantized_model(activations)
+                        ref = quantized_model_reference(activations)
+
+                    mean_err = ((res - ref).abs() / ref).mean()
+                    self.assertTrue(mean_err < 0.04)
+
+
+if __name__ == "__main__":
+    unittest.main()
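Usage note (not part of the patch series above): the README example and the new test both quantize bias-free linears, but a distinguishing feature of the `target="aten"` path in this series is that the layer bias is folded into the packed weight via `torch.ops.aten._dyn_quant_pack_4bit_weight`, whereas the native target still rejects bias in `_linear_impl`. The sketch below is a minimal, untested illustration of that path with all four patches applied; it assumes an Arm machine running a PyTorch >= 2.6.0 build where `torch.backends.kleidiai.is_available()` is true, uses only APIs introduced in the diffs above, and picks an arbitrary layer shape and group size.

```python
import torch

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_primitives import MappingType

# A linear layer *with* bias: supported by target="aten" because the bias is
# packed together with the int4 weights, while the native target asserts bias is None.
model = torch.nn.Sequential(torch.nn.Linear(256, 128, bias=True))

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,   # required by the aten target
        granularity=PerGroup(32),  # per-group scales are cast to bfloat16 when KleidiAI is available
        has_weight_zeros=True,     # required by the aten target
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)

with torch.no_grad():
    out = model(torch.randn(8, 256))
print(out.shape)  # expected: torch.Size([8, 128])
```

If the KleidiAI prerequisites are not met, the fallback described by the updated error message in patch 1 applies: use the native target (with bias-free linears) after building the experimental kernels with `USE_CPP=1 pip install .`, or fall back to `PlainLayout()` at reduced performance.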