diff --git a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
similarity index 50%
rename from test/quantization/quantize_/test_int4_groupwise_preshuffle.py
rename to test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
index 9bfe6dffdb..ba3656995c 100644
--- a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py
+++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
@@ -4,14 +4,18 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 import unittest
 
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
+from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -23,6 +27,45 @@
     is_sm_at_least_90,
 )
 
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -30,35 +73,25 @@
 @unittest.skipIf(
     not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
 )
-class TestInt4GroupwisePreshuffleTensor(TestCase):
+class TestInt4PreshuffledTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +107,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            "<class 'torchao.quantization.Int4PreshuffledTensor'>",
         )
 
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PreshuffledTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4PreshuffledTensor)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index e75fe5e048..f87d038430 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -87,8 +87,8 @@
     dequantize_affine,
     quantize_affine,
 )
-from .quantize_ import (
-    Int4GroupwisePreshuffleTensor,
+from .quantize_.workflows import (
+    Int4PreshuffledTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -153,7 +153,7 @@
     "ModuleFqnToConfig",
     "FbgemmConfig",
     # tensor subclasses
-    "Int4GroupwisePreshuffleTensor",
+    "Int4PreshuffledTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 4662e20fc9..d692b52bdc 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -67,8 +67,8 @@
     LinearActivationWeightObservedTensor,
 )
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
-from torchao.quantization.quantize_ import (
-    Int4GroupwisePreshuffleTensor,
+from torchao.quantization.quantize_.workflows import (
+    Int4PreshuffledTensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -2056,6 +2056,7 @@ class FbgemmConfig(AOBaseConfig):
         weight_dtype (torch.dtype): weight dtype of the kernel
         output_dtype (torch.dtype): output dtype of the kernel
         group_size (int): The group size for weight
+        preshuffle (bool): whether to preshuffle the weights or not
     """
 
     input_dtype: torch.dtype
@@ -2073,7 +2074,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
     _SUPPORTED_DTYPES = {
         (torch.bfloat16, torch.int4, torch.bfloat16),
-        (e4m3_dtype, e4m3_dtype, torch.bfloat16),
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16),
     }
 
     if (
@@ -2082,8 +2083,10 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         and (config.output_dtype == torch.bfloat16)
     ):
         if config.preshuffle:
-            weight = Int4GroupwisePreshuffleTensor.from_float(
-                module.weight, config.block_size
+            weight = Int4PreshuffledTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype=torch.bfloat16,
             )
         else:
             weight = to_fbgemm_int4(
@@ -2093,6 +2096,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
         module.weight = torch.nn.Parameter(weight, requires_grad=False)
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
         return module
+    if (
+        (config.input_dtype == e4m3_dtype)
+        and (config.weight_dtype == torch.int4)
+        and (config.output_dtype == torch.bfloat16)
+    ):
+        if config.preshuffle:
+            weight = Int4PreshuffledTensor.from_float(
+                module.weight,
+                config.block_size,
+                activation_dtype=torch.float8_e4m3fn,
+            )
+            module.weight = torch.nn.Parameter(weight, requires_grad=False)
+            module.extra_repr = types.MethodType(_linear_extra_repr, module)
+            return module
     elif (
         (config.input_dtype == e4m3_dtype)
         and (config.weight_dtype == e4m3_dtype)
diff --git a/torchao/quantization/quantize_/__init__.py b/torchao/quantization/quantize_/__init__.py
index 049b71631b..e69de29bb2 100644
--- a/torchao/quantization/quantize_/__init__.py
+++ b/torchao/quantization/quantize_/__init__.py
@@ -1,9 +0,0 @@
-from .int4_groupwise_preshuffle_tensor import (
-    Int4GroupwisePreshuffleTensor,
-)
-
-Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization"
-
-__all__ = [
-    "Int4GroupwisePreshuffleTensor",
-]
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
new file mode 100644
index 0000000000..40548e0e0e
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -0,0 +1,7 @@
+from .int4.int4_preshuffled_tensor import (
+    Int4PreshuffledTensor,
+)
+
+__all__ = [
+    "Int4PreshuffledTensor",
+]
diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
similarity index 54%
rename from torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py
rename to torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
index 1313be5128..11cd0a145a 100644
--- a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -6,7 +6,7 @@
 
 import importlib.util
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -18,7 +18,7 @@
 )
 
 __all__ = [
-    "Int4GroupwisePreshuffleTensor",
+    "Int4PreshuffledTensor",
 ]
 
 aten = torch.ops.aten
@@ -26,30 +26,39 @@
 
 if importlib.util.find_spec("fbgemm_gpu") is None:
     quantize_int4_preshuffle = None
+    quantize_fp8_row = None
 else:
-    from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
+    from fbgemm_gpu.experimental.gen_ai.quantize import (
+        quantize_fp8_row,
+        quantize_int4_preshuffle,
+    )
 
 
-class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor):
+class Int4PreshuffledTensor(TorchAOBaseTensor):
     """
     Groupwise int4 weight only quantization Tensor
 
     Tensor Attributes:
-      packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
-      group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
-                   dtype is the same as the original Tensor dtype
-      group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
-                  dtype is the same as the original Tensor dtype
+      _data: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
+      for bf16 activation:
+        group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
+                     dtype is the same as the original Tensor dtype
+        group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor
+                    dtype is the same as the original Tensor dtype
+      for float8 activation:
+        group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor
+                     dtype is float8
+        row_scale: (N,) for 2D Tensor, (B, N) for 3D Tensor
+                   dtype is the same as the original Tensor dtype
 
     Non-Tensor Attributes:
       group_size: the group size for groupwise quantization
-      shape_multiplier: is the multipler from packed_weight to the real weight, since
+      shape_multiplier: is the multiplier from _data to the real weight, since
         we pack the weight for int4, for example, when we pack the last dimension for
         a 2D tensor, the shape_multiplier will be [1, 2]
       shape: shape of the original Tensor
 
-    Note:
-      Details for preshuffle for fbgemm kernel:
+    Note on preshuffle details for the fbgemm kernel:
       We use WGMMA instruction for efficient matrix multiplication in H100 Tensor Core. To address
       a major inefficiency in how WGMMA tiles are loaded into shared memory before
@@ -61,58 +70,88 @@ class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor):
       loads so having to load all four groups is wasteful. We can optimize weight loading by
      shuffling the order of elements such that all 4 groups are sequential in memory. This allows
       us to perform a single 64 bit load to move all needed weights for the thread into register memory.
+
+    Note for float8 activation int4 weight kernel:
+      float8 activation int4 weight kernel doesn't work with zero_point, since it uses a table lookup approach which
+      requires symmetric quantization
     """
 
-    tensor_data_attrs = ["packed_weight", "group_scale", "group_zero"]
+    tensor_data_attrs = ["_data", "group_scale"]
     tensor_attributes = ["group_size", "shape_multiplier", "shape"]
 
     def __new__(
-        cls, packed_weight, group_scale, group_zero, group_size, shape_multiplier, shape
+        cls,
+        _data,
+        group_scale,
+        group_zero,
+        row_scale,
+        group_size,
+        shape_multiplier,
+        shape,
     ):
         kwargs = {}
-        kwargs["device"] = packed_weight.device
+        kwargs["device"] = _data.device
         kwargs["dtype"] = group_scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(
         self,
-        packed_weight,
-        group_scale,
-        group_zero,
-        group_size,
-        shape_multiplier,
-        shape,
+        _data: torch.Tensor,
+        group_scale: torch.Tensor,
+        group_zero: Optional[torch.Tensor],
+        row_scale: Optional[torch.Tensor],
+        group_size: int,
+        shape_multiplier: List[int],
+        shape: List[int],
     ):
-        self.packed_weight = packed_weight
+        # exactly one of group_zero and row_scale should be set, the other should be None
+        assert group_zero is None or row_scale is None
+        assert not (group_zero is not None and row_scale is not None)
+        self._data = _data
         self.group_scale = group_scale
         self.group_zero = group_zero
+        self.row_scale = row_scale
         self.shape_multiplier = shape_multiplier
         self.group_size = group_size
 
     def __tensor_flatten__(self):
-        return self.tensor_data_attrs, [
-            getattr(self, attr) for attr in self.tensor_attributes
-        ]
+        if getattr(self, "group_zero") is None:
+            assert getattr(self, "row_scale") is not None
+            return self.tensor_data_attrs + ["row_scale"], [
+                getattr(self, attr) for attr in self.tensor_attributes
+            ]
+        else:
+            return self.tensor_data_attrs + ["group_zero"], [
+                getattr(self, attr) for attr in self.tensor_attributes
+            ]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
+        tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs]
+        tensors.append(tensor_data_dict.get("group_zero", None))
+        tensors.append(tensor_data_dict.get("row_scale", None))
         return cls(
-            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
+            *tensors,
             *tensor_attributes,
         )
 
     def _apply_fn_to_data(self, fn):
+        tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs]
+        t1 = getattr(self, "group_zero")
+        tensors.append(fn(t1) if t1 is not None else None)
+        t2 = getattr(self, "row_scale")
+        tensors.append(fn(t2) if t2 is not None else None)
        return self.__class__(
-            *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs],
+            *tensors,
             *[getattr(self, attr) for attr in self.tensor_attributes],
         )
 
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
+            f"{self.__class__.__name__}(weight={self._data}, group_size={self.group_size}, "
             f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, "
             f"requires_grad={self.requires_grad})"
         )
@@ -124,9 +163,10 @@ def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs.pop("device")
         return self.__class__(
-            self.packed_weight.to(device),
+            self._data.to(device),
             self.group_scale.to(device),
-            self.group_zero.to(device),
+            self.group_zero.to(device) if self.group_zero is not None else None,
+            self.row_scale.to(device) if self.row_scale is not None else None,
             self.group_size,
             self.shape_multiplier,
             self.shape,
@@ -137,44 +177,72 @@ def from_float(
         cls,
         w: torch.Tensor,
         block_size: List[int],
+        activation_dtype: torch.dtype = torch.bfloat16,
     ):
         assert len(block_size) == w.ndim, (
             f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}"
         )
+
+        _SUPPORTED_DTYPE_TO_STR = {
+            torch.bfloat16: "bf16",
+            torch.float8_e4m3fn: "fp8",
+        }
+        assert activation_dtype in _SUPPORTED_DTYPE_TO_STR, (
+            f"activation dtype {activation_dtype} is not supported, supported ones are: {_SUPPORTED_DTYPE_TO_STR.keys()}"
+        )
+
         if quantize_int4_preshuffle is None:
             raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0")
 
+        assert all(x == 1 for x in block_size[:-1]), (
+            "Only groupwise quant is supported right now"
+        )
         group_size = block_size[-1]
         original_shape = w.shape
 
+        activation_dtype_str = _SUPPORTED_DTYPE_TO_STR[activation_dtype]
+
         if w.ndim >= 3:
             wq, scales = zip(
-                *[quantize_int4_preshuffle(i.cuda(), dtype="bf16") for i in w]
+                *[
+                    quantize_int4_preshuffle(i.cuda(), dtype=activation_dtype_str)
+                    for i in w
+                ]
             )
             wq = torch.stack(wq, dim=0)
-            group_scale, group_zero = zip(*scales)
-            group_zero = torch.stack(group_zero, dim=0).contiguous()
+            group_scale, group_zero_or_row_scale = zip(*scales)
+            group_zero_or_row_scale = torch.stack(
+                group_zero_or_row_scale, dim=0
+            ).contiguous()
             group_scale = torch.stack(group_scale, dim=0).contiguous()
         else:
-            wq, (group_scale, group_zero) = quantize_int4_preshuffle(
-                w.cuda(), dtype="bf16"
+            wq, (group_scale, group_zero_or_row_scale) = quantize_int4_preshuffle(
+                w.cuda(), dtype=activation_dtype_str
             )
 
+        if activation_dtype == torch.bfloat16:
+            group_zero = group_zero_or_row_scale
+            row_scale = None
+        else:
+            group_zero = None
+            row_scale = group_zero_or_row_scale
+
         shape_multiplier = [1] * wq.ndim
         shape_multiplier[-1] = 2
 
         del w
-        return Int4GroupwisePreshuffleTensor(
-            packed_weight=wq,
+        return Int4PreshuffledTensor(
+            _data=wq,
             group_scale=group_scale,
             group_zero=group_zero,
+            row_scale=row_scale,
             group_size=group_size,
             shape_multiplier=shape_multiplier,
             shape=original_shape,
         )
 
 
-implements = Int4GroupwisePreshuffleTensor.implements
+implements = Int4PreshuffledTensor.implements
 
 
 @implements([torch.nn.functional.linear, aten.linear.default])
@@ -187,23 +255,21 @@ def _(func, types, args, kwargs):
     orig_input_size = input_tensor.size()
     orig_out_features = weight_tensor.shape[-2]
 
-    wq = weight_tensor.packed_weight.contiguous()
+    wq = weight_tensor._data.contiguous()
     group_scale = weight_tensor.group_scale.contiguous()
-    group_zero = weight_tensor.group_zero.contiguous()
-
-    if input_tensor.dim() == 3:
-        B, M, _ = input_tensor.shape
-        _, N, _ = wq.shape
-        res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16)
-        for i in range(B):
-            res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
-                input_tensor[i], wq[i], group_scale[i], group_zero[i]
-            )
-    else:
-        # Otherwise run gemm normally.
+    # bf16 activation
+    if weight_tensor.group_zero is not None:
+        group_zero = weight_tensor.group_zero.contiguous()
         res = torch.ops.fbgemm.bf16i4bf16_shuffled(
             input_tensor, wq, group_scale, group_zero
         )
+    else:
+        assert weight_tensor.row_scale is not None
+        row_scale = weight_tensor.row_scale.contiguous()
+        xq, x_scale = quantize_fp8_row(input_tensor)
+        res = torch.ops.fbgemm.f8i4bf16_shuffled(
+            xq, wq, x_scale, row_scale, group_scale
+        )
 
     res = res.reshape(*orig_input_size[:-1], orig_out_features)
     if bias is not None:
@@ -221,17 +287,27 @@ def _(func, types, args, kwargs):
     orig_out_features = weight_tensor.shape[-2]
     assert weight_tensor.shape_multiplier[-1] == 2
 
-    wq = weight_tensor.packed_weight
-    group_scale = weight_tensor.group_scale
-    group_zero = weight_tensor.group_zero
-    # from https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1715-L1722
-    B, M, _ = input_tensor.shape
-    _, N, _ = wq.shape
-    res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16)
-    for i in range(B):
-        res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
-            input_tensor[i], wq[i], group_scale[i], group_zero[i]
+    wq = weight_tensor._data.contiguous()
+    group_scale = weight_tensor.group_scale.contiguous()
+    if weight_tensor.group_zero is not None:
+        group_zero = weight_tensor.group_zero.contiguous()
+        res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched(
+            input_tensor, wq, group_scale, group_zero
         )
+    else:
+        assert weight_tensor.row_scale is not None
+        row_scale = weight_tensor.row_scale.contiguous()
+        xq, x_scale = quantize_fp8_row(input_tensor)
+        # From: https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1654
+        assert xq.dim() == 3
+        B, M, _ = xq.shape
+        _, N, _ = wq.shape
+        res = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+        for i in range(B):
+            res[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+            )
+
+    res = res.reshape(*orig_input_size[:-1], orig_out_features)
     return res
@@ -250,16 +326,23 @@ def _(func, types, args, kwargs):
     )
 
 
-def _same_metadata(
-    self: "Int4GroupwisePreshuffleTensor", src: "Int4GroupwisePreshuffleTensor"
-) -> bool:
+def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool:
     return (
-        isinstance(self, Int4GroupwisePreshuffleTensor)
-        and isinstance(src, Int4GroupwisePreshuffleTensor)
+        isinstance(self, Int4PreshuffledTensor)
+        and isinstance(src, Int4PreshuffledTensor)
         and self.shape == src.shape
-        and self.packed_weight.shape == src.packed_weight.shape
+        and self._data.shape == src._data.shape
         and self.group_scale.shape == src.group_scale.shape
-        and self.group_zero.shape == src.group_zero.shape
+        and (
+            self.group_zero.shape == src.group_zero.shape
+            if self.group_zero is not None
+            else src.group_zero is None
+        )
+        and (
+            self.row_scale.shape == src.row_scale.shape
+            if self.row_scale is not None
+            else src.row_scale is None
+        )
         and self.group_size == src.group_size
         and self.shape_multiplier == src.shape_multiplier
     )
@@ -287,33 +370,33 @@ def _(func, types, args, kwargs):
         dim = dim + tensor_0.ndim
 
     for i in range(1, len(tensors)):
-        assert tensor_0.packed_weight.ndim == tensors[i].packed_weight.ndim
+        assert tensor_0._data.ndim == tensors[i]._data.ndim
         assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim
         assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim
         assert tensor_0.group_size == tensors[i].group_size
         assert tensor_0.shape_multiplier == tensors[i].shape_multiplier
 
-    packed_weight = [t.packed_weight for t in tensors]
+    _data = [t._data for t in tensors]
     group_scale = [t.group_scale for t in tensors]
     group_zero = [t.group_zero for t in tensors]
 
-    # with group wise quantization, dimension of group_scale, packed_weight and
+    # with group wise quantization, dimension of group_scale, _data and
     # origianl shape will be the same, so original dim argument applies
-    # to both packed_weight and group_scale
-    cat_packed_weight = aten.cat.default(packed_weight, dim)
-    if cat_packed_weight.ndim == 2:
+    # to both _data and group_scale
+    cat_data = aten.cat.default(_data, dim)
+    if cat_data.ndim == 2:
         sz_dim = 1 - dim
     else:
         sz_dim = dim
     cat_group_scale = aten.cat.default(group_scale, sz_dim)
     cat_group_zero = aten.cat.default(group_zero, sz_dim)
-    new_shape = list(cat_packed_weight.shape)
+    new_shape = list(cat_data.shape)
     for i in range(len(tensor_0.shape_multiplier)):
         new_shape[i] *= tensor_0.shape_multiplier[i]
     new_shape = tuple(new_shape)
     new = tensor_0.__class__(
-        cat_packed_weight,
+        cat_data,
         cat_group_scale,
         cat_group_zero,
         group_size=tensor_0.group_size,
@@ -326,19 +409,19 @@ def _(func, types, args, kwargs):
 @implements(aten.transpose.int)
 def _(func, types, args, kwargs):
     self, dim0, dim1 = args
-    packed_weight = self.packed_weight.transpose(dim0, dim1).contiguous()
+    _data = self._data.transpose(dim0, dim1).contiguous()
     shape_multiplier = self.shape_multiplier.copy()
     shape_multiplier[dim0], shape_multiplier[dim1] = (
         shape_multiplier[dim1],
         shape_multiplier[dim0],
     )
 
-    tensor_shape = list(packed_weight.shape)
+    tensor_shape = list(_data.shape)
     for i in range(len(shape_multiplier)):
         tensor_shape[i] *= shape_multiplier[i]
     tensor_shape = tuple(tensor_shape)
     new = self.__class__(
-        packed_weight,
+        _data,
         self.group_scale,
         self.group_zero,
         self.group_size,
@@ -348,6 +431,8 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)
 
 
+Int4PreshuffledTensor.__module__ = "torchao.quantization"
+
 if TORCH_VERSION_AT_LEAST_2_5:
-    # Allow a model with Int4GroupwisePreshuffleTensor weights to be loaded with `weights_only=True`
-    torch.serialization.add_safe_globals([Int4GroupwisePreshuffleTensor])
+    # Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([Int4PreshuffledTensor])