From 889dca34f36a2821413dd13693eb2985e23d0a59 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 24 Jun 2025 10:59:26 -0700 Subject: [PATCH] Add support for float8 activation for Int4PreshuffledTensor Summary: Added basic op support like linear and bmm, we have both float8 and bf16 in the same Tensor because it's the same dtype, only difference is whether the activation is quantized or not. Although there is some differneces in implementation: bf16 activaton: * group_scale * group_zero fp8 activation * group_scale * row_scale Test Plan: python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: https://github.com/pytorch/ao/pull/2437, branch: jerryzh168/stack/4 --- .../int4/test_int4_preshuffled_tensor.py} | 99 +++++-- torchao/quantization/__init__.py | 6 +- torchao/quantization/quant_api.py | 27 +- torchao/quantization/quantize_/__init__.py | 9 - .../quantize_/workflows/__init__.py | 7 + .../quantize_/workflows/int4/__init__.py | 0 .../int4/int4_preshuffled_tensor.py} | 247 ++++++++++++------ 7 files changed, 271 insertions(+), 124 deletions(-) rename test/quantization/quantize_/{test_int4_groupwise_preshuffle.py => workflows/int4/test_int4_preshuffled_tensor.py} (50%) create mode 100644 torchao/quantization/quantize_/workflows/__init__.py create mode 100644 torchao/quantization/quantize_/workflows/int4/__init__.py rename torchao/quantization/quantize_/{int4_groupwise_preshuffle_tensor.py => workflows/int4/int4_preshuffled_tensor.py} (54%) diff --git a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py similarity index 50% rename from test/quantization/quantize_/test_int4_groupwise_preshuffle.py rename to test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py index 9bfe6dffdb..ba3656995c 100644 --- a/test/quantization/quantize_/test_int4_groupwise_preshuffle.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py @@ -4,14 +4,18 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import tempfile import unittest import torch from torch.testing._internal.common_utils import ( TestCase, + instantiate_parametrized_tests, + parametrize, run_tests, ) +from torchao.float8.config import e4m3_dtype from torchao.quantization import ( FbgemmConfig, quantize_, @@ -23,6 +27,45 @@ is_sm_at_least_90, ) +if TORCH_VERSION_AT_LEAST_2_8: + BF16_ACT_CONFIG = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 128], + preshuffle=True, + ) + + BF16_ACT_BMM_CONFIG = FbgemmConfig( + input_dtype=torch.bfloat16, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 1, 128], + preshuffle=True, + ) + + FP8_ACT_CONFIG = FbgemmConfig( + input_dtype=e4m3_dtype, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 128], + preshuffle=True, + ) + + FP8_ACT_BMM_CONFIG = FbgemmConfig( + input_dtype=e4m3_dtype, + weight_dtype=torch.int4, + output_dtype=torch.bfloat16, + block_size=[1, 1, 128], + preshuffle=True, + ) + +else: + BF16_ACT_CONFIG = None + BF16_ACT_BMM_CONFIG = None + FP8_ACT_CONFIG = None + FP8_ACT_BMM_CONFIG = None + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @@ -30,35 +73,25 @@ @unittest.skipIf( not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) -class TestInt4GroupwisePreshuffleTensor(TestCase): +class TestInt4PreshuffledTensor(TestCase): def setUp(self): - self.config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 128], - preshuffle=True, - ) - self.bmm_config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 1, 128], - preshuffle=True, - ) self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] - def test_linear(self): + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_linear(self, config): dtype = torch.bfloat16 device = "cuda" input = torch.randn(1, 128, dtype=dtype, device=device) linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) original = linear(input) - quantize_(linear, self.config) + quantize_(linear, config) quantized = linear(input) self.assertTrue(compute_error(original, quantized) > 20) - def test_bmm(self): + # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449` + # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG]) + @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG]) + def test_bmm(self, bmm_config): class M(torch.nn.Module): def __init__(self, weight): super().__init__() @@ -74,32 +107,46 @@ def forward(self, x): m = M(weight).eval() original = m(input) m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) + quantize_(m, bmm_config, filter_fn=lambda x, fqn: True) quantized = m(input) self.assertTrue(compute_error(original, quantized) > 18) - def test_to_device(self): + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_to_device(self, config): for device in self.GPU_DEVICES: linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device=device) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) linear.to(device) - def test_module_path(self): + @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG]) + def test_module_path(self, config): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - quantize_(linear, self.config) + quantize_(linear, config) self.assertEqual( str(type(linear.weight)), - "", + "", ) + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +instantiate_parametrized_tests(TestInt4PreshuffledTensor) + if __name__ == "__main__": run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index e75fe5e048..f87d038430 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -87,8 +87,8 @@ dequantize_affine, quantize_affine, ) -from .quantize_ import ( - Int4GroupwisePreshuffleTensor, +from .quantize_.workflows import ( + Int4PreshuffledTensor, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -153,7 +153,7 @@ "ModuleFqnToConfig", "FbgemmConfig", # tensor subclasses - "Int4GroupwisePreshuffleTensor", + "Int4PreshuffledTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4662e20fc9..d692b52bdc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -67,8 +67,8 @@ LinearActivationWeightObservedTensor, ) from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size -from torchao.quantization.quantize_ import ( - Int4GroupwisePreshuffleTensor, +from torchao.quantization.quantize_.workflows import ( + Int4PreshuffledTensor, ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, @@ -2056,6 +2056,7 @@ class FbgemmConfig(AOBaseConfig): weight_dtype (torch.dtype): weight dtype of the kernel output_dtype (torch.dtype): output dtype of the kernel group_size (int): The group size for weight + preshuffle (bool): whether preshuffle the weights or not """ input_dtype: torch.dtype @@ -2073,7 +2074,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: _SUPPORTED_DTYPES = { (torch.bfloat16, torch.int4, torch.bfloat16), - (e4m3_dtype, e4m3_dtype, torch.bfloat16), + (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16), } if ( @@ -2082,8 +2083,10 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: and (config.output_dtype == torch.bfloat16) ): if config.preshuffle: - weight = Int4GroupwisePreshuffleTensor.from_float( - module.weight, config.block_size + weight = Int4PreshuffledTensor.from_float( + module.weight, + config.block_size, + activation_dtype=torch.bfloat16, ) else: weight = to_fbgemm_int4( @@ -2093,6 +2096,20 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: module.weight = torch.nn.Parameter(weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module + if ( + (config.input_dtype == e4m3_dtype) + and (config.weight_dtype == torch.int4) + and (config.output_dtype == torch.bfloat16) + ): + if config.preshuffle: + weight = Int4PreshuffledTensor.from_float( + module.weight, + config.block_size, + activation_dtype=torch.float8_e4m3fn, + ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module elif ( (config.input_dtype == e4m3_dtype) and (config.weight_dtype == e4m3_dtype) diff --git a/torchao/quantization/quantize_/__init__.py b/torchao/quantization/quantize_/__init__.py index 049b71631b..e69de29bb2 100644 --- a/torchao/quantization/quantize_/__init__.py +++ b/torchao/quantization/quantize_/__init__.py @@ -1,9 +0,0 @@ -from .int4_groupwise_preshuffle_tensor import ( - Int4GroupwisePreshuffleTensor, -) - -Int4GroupwisePreshuffleTensor.__module__ = "torchao.quantization" - -__all__ = [ - "Int4GroupwisePreshuffleTensor", -] diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py new file mode 100644 index 0000000000..40548e0e0e --- /dev/null +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -0,0 +1,7 @@ +from .int4.int4_preshuffled_tensor import ( + Int4PreshuffledTensor, +) + +__all__ = [ + "Int4PreshuffledTensor", +] diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py similarity index 54% rename from torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py rename to torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py index 1313be5128..11cd0a145a 100644 --- a/torchao/quantization/quantize_/int4_groupwise_preshuffle_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py @@ -6,7 +6,7 @@ import importlib.util -from typing import List +from typing import List, Optional import torch from torch.utils._python_dispatch import return_and_correct_aliasing @@ -18,7 +18,7 @@ ) __all__ = [ - "Int4GroupwisePreshuffleTensor", + "Int4PreshuffledTensor", ] aten = torch.ops.aten @@ -26,30 +26,39 @@ if importlib.util.find_spec("fbgemm_gpu") is None: quantize_int4_preshuffle = None + quantize_fp8_row = None else: - from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle + from fbgemm_gpu.experimental.gen_ai.quantize import ( + quantize_fp8_row, + quantize_int4_preshuffle, + ) -class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor): +class Int4PreshuffledTensor(TorchAOBaseTensor): """ Groupwise int4 weight only quantization Tensor Attributes: - packed_weight: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed - group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor - dtype is the same as the original Tensor dtype - group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor - dtype is the same as the original Tensor dtype + _data: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed + for bf16 activation: + group_scale: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor + dtype is the same as the original Tensor dtype + group_zero: (K/group_size, N) for 2D Tensor, (B, N, K/group_size) for 3D Tensor + dtype is the same as the original Tensor dtype + for float8 activation: + group_scale: (K/group_size/8, 8, N) for 2D Tensor, (B, K/group_size/8, 8, N) for 3D Tensor + dtype is float8 + row_scale: (N,) for 2D Tensor, (B, N) for 3D Tensor + dtype is the same as the original Tensor dtype Non-Tensor Attributes: group_size: the group size for groupwise quantization - shape_multiplier: is the multipler from packed_weight to the real weight, since + shape_multiplier: is the multipler from _data to the real weight, since we pack the weight for int4, for example, when we pack the last dimension for a 2D tensor, the shape_multiplier will be [1, 2] shape: shape of the original Tensor - Note: - Details for preshuffle for fbgemm kernel: + Note on Details for preshuffle for fbgemm kernel: We use WGMMA instruction for efficient matrix multiplication in H100 Tensor Core. To address a major inefficiency in how WGMMA tiles are loaded into shared memory before @@ -61,58 +70,88 @@ class Int4GroupwisePreshuffleTensor(TorchAOBaseTensor): loads so having to load all four groups is wasteful. We can optimize weight loading by shuffling the order of elements such that all 4 groups are sequential in memory. This allows us to perform a single 64 bit load to move all needed weights for the thread into register memory. + + Note for float8 activation int4 weight kernel: + float8 activation int4 weight kernel doesn't work with zero_point, since it use table lookup approach which + requires symmetric quantization """ - tensor_data_attrs = ["packed_weight", "group_scale", "group_zero"] + tensor_data_attrs = ["_data", "group_scale"] tensor_attributes = ["group_size", "shape_multiplier", "shape"] def __new__( - cls, packed_weight, group_scale, group_zero, group_size, shape_multiplier, shape + cls, + _data, + group_scale, + group_zero, + row_scale, + group_size, + shape_multiplier, + shape, ): kwargs = {} - kwargs["device"] = packed_weight.device + kwargs["device"] = _data.device kwargs["dtype"] = group_scale.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, - packed_weight, - group_scale, - group_zero, - group_size, - shape_multiplier, - shape, + _data: torch.Tensor, + group_scale: torch.Tensor, + group_zero: Optional[torch.Tensor], + row_scale: Optional[torch.Tensor], + group_size: int, + shape_multiplier: List[int], + shape: List[int], ): - self.packed_weight = packed_weight + # one and only one of group_scale and group_zero should be None + assert group_zero is None or row_scale is None + assert not (group_zero is not None and row_scale is not None) + self._data = _data self.group_scale = group_scale self.group_zero = group_zero + self.row_scale = row_scale self.shape_multiplier = shape_multiplier self.group_size = group_size def __tensor_flatten__(self): - return self.tensor_data_attrs, [ - getattr(self, attr) for attr in self.tensor_attributes - ] + if getattr(self, "group_zero") is None: + assert getattr(self, "row_scale") is not None + return self.tensor_data_attrs + ["row_scale"], [ + getattr(self, attr) for attr in self.tensor_attributes + ] + else: + return self.tensor_data_attrs + ["group_zero"], [ + getattr(self, attr) for attr in self.tensor_attributes + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): + tensors = [tensor_data_dict[name] for name in cls.tensor_data_attrs] + tensors.append(tensor_data_dict.get("group_zero", None)) + tensors.append(tensor_data_dict.get("row_scale", None)) return cls( - *[tensor_data_dict[name] for name in cls.tensor_data_attrs], + *tensors, *tensor_attributes, ) def _apply_fn_to_data(self, fn): + tensors = [fn(getattr(self, name)) for name in self.tensor_data_attrs] + t1 = getattr(self, "group_zero") + tensors.append(fn(t1) if t1 is not None else None) + t2 = getattr(self, "row_scale") + tensors.append(fn(t2) if t2 is not None else None) return self.__class__( - *[fn(getattr(self, attr)) for attr in self.tensor_data_attrs], + *tensors, *[getattr(self, attr) for attr in self.tensor_attributes], ) def __repr__(self): return ( - f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " + f"{self.__class__.__name__}(weight={self._data}, group_size={self.group_size}, " f"shape_multiplier={self.shape_multiplier}, shape={self.shape}, device={self.device}, dtype={self.dtype}, " f"requires_grad={self.requires_grad})" ) @@ -124,9 +163,10 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.packed_weight.to(device), + self._data.to(device), self.group_scale.to(device), - self.group_zero.to(device), + self.group_zero.to(device) if self.group_zero is not None else None, + self.row_scale.to(device) if self.row_scale is not None else None, self.group_size, self.shape_multiplier, self.shape, @@ -137,44 +177,72 @@ def from_float( cls, w: torch.Tensor, block_size: List[int], + activation_dtype: torch.dtype = torch.bfloat16, ): assert len(block_size) == w.ndim, ( f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {w.ndim=}" ) + + _SUPPORTED_DTYPE_TO_STR = { + torch.bfloat16: "bf16", + torch.float8_e4m3fn: "fp8", + } + assert activation_dtype in _SUPPORTED_DTYPE_TO_STR, ( + f"activation dtype {activation_dtype} is not supported, supported ones are: {_SUPPORTED_DTYPE_TO_STR.keys()}" + ) + if quantize_int4_preshuffle is None: raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + assert all(x == 1 for x in block_size[:-1]), ( + "Only groupwise quant is supported right now" + ) group_size = block_size[-1] original_shape = w.shape + activation_dtype_str = _SUPPORTED_DTYPE_TO_STR[activation_dtype] + if w.ndim >= 3: wq, scales = zip( - *[quantize_int4_preshuffle(i.cuda(), dtype="bf16") for i in w] + *[ + quantize_int4_preshuffle(i.cuda(), dtype=activation_dtype_str) + for i in w + ] ) wq = torch.stack(wq, dim=0) - group_scale, group_zero = zip(*scales) - group_zero = torch.stack(group_zero, dim=0).contiguous() + group_scale, group_zero_or_row_scale = zip(*scales) + group_zero_or_row_scale = torch.stack( + group_zero_or_row_scale, dim=0 + ).contiguous() group_scale = torch.stack(group_scale, dim=0).contiguous() else: - wq, (group_scale, group_zero) = quantize_int4_preshuffle( - w.cuda(), dtype="bf16" + wq, (group_scale, group_zero_or_row_scale) = quantize_int4_preshuffle( + w.cuda(), dtype=activation_dtype_str ) + if activation_dtype == torch.bfloat16: + group_zero = group_zero_or_row_scale + row_scale = None + else: + group_zero = None + row_scale = group_zero_or_row_scale + shape_multiplier = [1] * wq.ndim shape_multiplier[-1] = 2 del w - return Int4GroupwisePreshuffleTensor( - packed_weight=wq, + return Int4PreshuffledTensor( + _data=wq, group_scale=group_scale, group_zero=group_zero, + row_scale=row_scale, group_size=group_size, shape_multiplier=shape_multiplier, shape=original_shape, ) -implements = Int4GroupwisePreshuffleTensor.implements +implements = Int4PreshuffledTensor.implements @implements([torch.nn.functional.linear, aten.linear.default]) @@ -187,23 +255,21 @@ def _(func, types, args, kwargs): orig_input_size = input_tensor.size() orig_out_features = weight_tensor.shape[-2] - wq = weight_tensor.packed_weight.contiguous() + wq = weight_tensor._data.contiguous() group_scale = weight_tensor.group_scale.contiguous() - group_zero = weight_tensor.group_zero.contiguous() - - if input_tensor.dim() == 3: - B, M, _ = input_tensor.shape - _, N, _ = wq.shape - res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16) - for i in range(B): - res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled( - input_tensor[i], wq[i], group_scale[i], group_zero[i] - ) - else: - # Otherwise run gemm normally. + # bf16 activation + if weight_tensor.group_zero is not None: + group_zero = weight_tensor.group_zero.contiguous() res = torch.ops.fbgemm.bf16i4bf16_shuffled( input_tensor, wq, group_scale, group_zero ) + else: + assert weight_tensor.row_scale is not None + row_scale = weight_tensor.row_scale.contiguous() + xq, x_scale = quantize_fp8_row(input_tensor) + res = torch.ops.fbgemm.f8i4bf16_shuffled( + xq, wq, x_scale, row_scale, group_scale + ) res = res.reshape(*orig_input_size[:-1], orig_out_features) if bias is not None: @@ -221,17 +287,27 @@ def _(func, types, args, kwargs): orig_out_features = weight_tensor.shape[-2] assert weight_tensor.shape_multiplier[-1] == 2 - wq = weight_tensor.packed_weight - group_scale = weight_tensor.group_scale - group_zero = weight_tensor.group_zero - # from https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1715-L1722 - B, M, _ = input_tensor.shape - _, N, _ = wq.shape - res = torch.empty((B, M, N), device=input_tensor.device, dtype=torch.bfloat16) - for i in range(B): - res[i] = torch.ops.fbgemm.bf16i4bf16_shuffled( - input_tensor[i], wq[i], group_scale[i], group_zero[i] + wq = weight_tensor._data.contiguous() + group_scale = weight_tensor.group_scale.contiguous() + if weight_tensor.group_zero is not None: + group_zero = weight_tensor.group_zero.contiguous() + res = torch.ops.fbgemm.bf16i4bf16_shuffled_batched( + input_tensor, wq, group_scale, group_zero ) + else: + assert weight_tensor.row_scale is not None + row_scale = weight_tensor.row_scale.contiguous() + xq, x_scale = quantize_fp8_row(input_tensor) + # From: https://github.com/pytorch/FBGEMM/blob/ba8f2b7adb90e096cff8818716f7cc3587030f70/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py#L1654 + assert xq.dim() == 3 + B, M, _ = xq.shape + _, N, _ = wq.shape + res = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16) + for i in range(B): + res[i] = torch.ops.fbgemm.f8i4bf16_shuffled( + xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i] + ) + res = res.reshape(*orig_input_size[:-1], orig_out_features) return res @@ -250,16 +326,23 @@ def _(func, types, args, kwargs): ) -def _same_metadata( - self: "Int4GroupwisePreshuffleTensor", src: "Int4GroupwisePreshuffleTensor" -) -> bool: +def _same_metadata(self: "Int4PreshuffledTensor", src: "Int4PreshuffledTensor") -> bool: return ( - isinstance(self, Int4GroupwisePreshuffleTensor) - and isinstance(src, Int4GroupwisePreshuffleTensor) + isinstance(self, Int4PreshuffledTensor) + and isinstance(src, Int4PreshuffledTensor) and self.shape == src.shape - and self.packed_weight.shape == src.packed_weight.shape + and self._data.shape == src._data.shape and self.group_scale.shape == src.group_scale.shape - and self.group_zero.shape == src.group_zero.shape + and ( + self.group_zero.shape == src.group_zero.shape + if self.group_zero is not None + else src.group_zero is None + ) + and ( + self.row_scale.shape == src.row_scale.shape + if self.row_scale is not None + else src.row_scale is None + ) and self.group_size == src.group_size and self.shape_multiplier == src.shape_multiplier ) @@ -287,33 +370,33 @@ def _(func, types, args, kwargs): dim = dim + tensor_0.ndim for i in range(1, len(tensors)): - assert tensor_0.packed_weight.ndim == tensors[i].packed_weight.ndim + assert tensor_0._data.ndim == tensors[i]._data.ndim assert tensor_0.group_scale.ndim == tensors[i].group_scale.ndim assert tensor_0.group_zero.ndim == tensors[i].group_zero.ndim assert tensor_0.group_size == tensors[i].group_size assert tensor_0.shape_multiplier == tensors[i].shape_multiplier - packed_weight = [t.packed_weight for t in tensors] + _data = [t._data for t in tensors] group_scale = [t.group_scale for t in tensors] group_zero = [t.group_zero for t in tensors] - # with group wise quantization, dimension of group_scale, packed_weight and + # with group wise quantization, dimension of group_scale, _data and # origianl shape will be the same, so original dim argument applies - # to both packed_weight and group_scale - cat_packed_weight = aten.cat.default(packed_weight, dim) - if cat_packed_weight.ndim == 2: + # to both _data and group_scale + cat_data = aten.cat.default(_data, dim) + if cat_data.ndim == 2: sz_dim = 1 - dim else: sz_dim = dim cat_group_scale = aten.cat.default(group_scale, sz_dim) cat_group_zero = aten.cat.default(group_zero, sz_dim) - new_shape = list(cat_packed_weight.shape) + new_shape = list(cat_data.shape) for i in range(len(tensor_0.shape_multiplier)): new_shape[i] *= tensor_0.shape_multiplier[i] new_shape = tuple(new_shape) new = tensor_0.__class__( - cat_packed_weight, + cat_data, cat_group_scale, cat_group_zero, group_size=tensor_0.group_size, @@ -326,19 +409,19 @@ def _(func, types, args, kwargs): @implements(aten.transpose.int) def _(func, types, args, kwargs): self, dim0, dim1 = args - packed_weight = self.packed_weight.transpose(dim0, dim1).contiguous() + _data = self._data.transpose(dim0, dim1).contiguous() shape_multiplier = self.shape_multiplier.copy() shape_multiplier[dim0], shape_multiplier[dim1] = ( shape_multiplier[dim1], shape_multiplier[dim0], ) - tensor_shape = list(packed_weight.shape) + tensor_shape = list(_data.shape) for i in range(len(shape_multiplier)): tensor_shape[i] *= shape_multiplier[i] tensor_shape = tuple(tensor_shape) new = self.__class__( - packed_weight, + _data, self.group_scale, self.group_zero, self.group_size, @@ -348,6 +431,8 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new) +Int4PreshuffledTensor.__module__ = "torchao.quantization" + if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with Int4GroupwisePreshuffleTensor weights to be loaded with `weights_only=True` - torch.serialization.add_safe_globals([Int4GroupwisePreshuffleTensor]) + # Allow a model with Int4PreshuffledTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([Int4PreshuffledTensor])