diff --git a/test/integration/test_serialization_bc.py b/test/integration/test_serialization_bc.py index eb15be3c8c..dc7a31bcf8 100644 --- a/test/integration/test_serialization_bc.py +++ b/test/integration/test_serialization_bc.py @@ -17,13 +17,13 @@ from torchao.utils import _is_fbgemm_genai_gpu_available, is_sm_at_least_90 _MODEL_NAMES = [ - "torchao-testing/opt-125m-float8dq-row-fbgemm", + "torchao-testing/opt-125m-float8dq-row-0.13-dev", + "torchao-testing/opt-125m-int4wo-preshuffled-0.13-dev", ] @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") -@unittest.skip("temporary skip since we have some refactor next") class TestSerializationBC(TestCase): """Test we can still load and run serialized model in previous AO versions we commit to have BC for 3 pytorch releases diff --git a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py index ba3656995c..26ff7313b9 100644 --- a/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py @@ -15,9 +15,9 @@ run_tests, ) -from torchao.float8.config import e4m3_dtype from torchao.quantization import ( - FbgemmConfig, + Float8ActivationInt4WeightConfig, + Int4WeightOnlyConfig, quantize_, ) from torchao.quantization.utils import compute_error @@ -27,44 +27,15 @@ is_sm_at_least_90, ) -if TORCH_VERSION_AT_LEAST_2_8: - BF16_ACT_CONFIG = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 128], - preshuffle=True, - ) - - BF16_ACT_BMM_CONFIG = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 1, 128], - preshuffle=True, - ) - - FP8_ACT_CONFIG = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 128], - preshuffle=True, - ) - - FP8_ACT_BMM_CONFIG = FbgemmConfig( - input_dtype=e4m3_dtype, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 1, 128], - preshuffle=True, - ) - -else: - BF16_ACT_CONFIG = None - BF16_ACT_BMM_CONFIG = None - FP8_ACT_CONFIG = None - FP8_ACT_BMM_CONFIG = None +BF16_ACT_CONFIG = Int4WeightOnlyConfig( + group_size=128, + packing_format="preshuffled", +) + +FP8_ACT_CONFIG = Float8ActivationInt4WeightConfig( + group_size=128, + packing_format="preshuffled", +) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") @@ -90,7 +61,7 @@ def test_linear(self, config): # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449` # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG]) - @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG]) + @parametrize("bmm_config", [FP8_ACT_CONFIG, BF16_ACT_CONFIG]) def test_bmm(self, bmm_config): class M(torch.nn.Module): def __init__(self, weight): diff --git a/test/dtypes/test_fbgemm_int4.py b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py similarity index 83% rename from test/dtypes/test_fbgemm_int4.py rename to test/quantization/quantize_/workflows/int4/test_int4_tensor.py index eb1f059775..b8700ce91a 100644 --- a/test/dtypes/test_fbgemm_int4.py +++ b/test/quantization/quantize_/workflows/int4/test_int4_tensor.py @@ -13,7 +13,7 @@ ) from torchao.quantization import ( - FbgemmConfig, + Int4WeightOnlyConfig, 
quantize_, ) from torchao.quantization.utils import compute_error @@ -26,19 +26,11 @@ @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") -class TestFbgemmInt4Tensor(TestCase): +class TestInt4Tensor(TestCase): def setUp(self): - self.config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 128], - ) - self.bmm_config = FbgemmConfig( - input_dtype=torch.bfloat16, - weight_dtype=torch.int4, - output_dtype=torch.bfloat16, - block_size=[1, 1, 128], + self.config = Int4WeightOnlyConfig( + group_size=128, + packing_format="plain", ) self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else [] @@ -68,13 +60,9 @@ def test_slice(self): quantize_(dummy, self.config) weight1 = dummy.weight.narrow(0, 0, 64) weight2 = dummy.weight.narrow(1, 0, 128) - self.assertEqual( - weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64) - ) + self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64)) self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64)) - self.assertEqual( - weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64) - ) + self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64)) self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1)) # check for sliced weight, before and after float8 quantization @@ -100,12 +88,10 @@ def test_slice_and_copy_(self): param = l.weight param_data = param.data param_data = param_data.narrow(0, 0, 512) - assert ( - param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr() - ) + assert param.data._data.data_ptr() == param_data._data.data_ptr() assert param.data.scale.data_ptr() == param_data.scale.data_ptr() assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr() - orig_value = param.data.packed_weight[0][0].item() + orig_value = param.data._data[0][0].item() # dummy_l has random input (shouldn't be 0) dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) @@ -116,7 +102,7 @@ def test_slice_and_copy_(self): param_data.copy_(quantized) # making sure param.data is updated - assert param.data.packed_weight[0][0] != orig_value + assert param.data._data[0][0] != orig_value def test_bmm(self): class M(torch.nn.Module): @@ -135,7 +121,7 @@ def forward(self, x): original = m(input) # we need to transpose the weight first for bmm m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous()) - quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True) + quantize_(m, self.config, filter_fn=lambda x, fqn: True) quantized = m(input) self.assertTrue(compute_error(original, quantized) > 18) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d6b1b9c440..575e154091 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -9,7 +9,6 @@ to_affine_quantized_intx_static, ) from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8 -from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, Float8Layout, @@ -64,8 +63,6 @@ "PackedLinearInt8DynamicActivationIntxWeightLayout", "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight", "Int4XPULayout", - "to_fbgemm_int4", - "FbgemmInt4Tensor", "to_fbgemm_fp8", "FbgemmFp8Tensor", "Int8DynamicActInt4WeightCPULayout", diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py 
index 2a56f9cbcb..732dff94d9 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -44,6 +44,7 @@ from .quant_api import ( CutlassInt4PackedLayout, FbgemmConfig, + Float8ActivationInt4WeightConfig, Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, Float8MMConfig, @@ -90,6 +91,7 @@ from .quantize_.workflows import ( Float8Tensor, Int4PreshuffledTensor, + Int4Tensor, ) from .smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -141,6 +143,7 @@ "Int8DynamicActivationInt8WeightConfig", "Int8DynamicActivationIntxWeightConfig", "Int4WeightOnlyConfig", + "Float8ActivationInt4WeightConfig", "Int8WeightOnlyConfig", "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", @@ -154,6 +157,7 @@ "ModuleFqnToConfig", "FbgemmConfig", # tensor subclasses + "Int4Tensor", "Int4PreshuffledTensor", "Float8Tensor", # smooth quant - subject to change diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9c521f368d..189d2b1569 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -49,7 +49,6 @@ to_affine_quantized_floatx_static, to_affine_quantized_intx, to_fbgemm_fp8, - to_fbgemm_int4, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -72,6 +71,8 @@ from torchao.quantization.quantize_.workflows import ( Float8Tensor, Int4PreshuffledTensor, + Int4Tensor, + PackingFormat, ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, @@ -1108,6 +1109,8 @@ class Int4WeightOnlyConfig(AOBaseConfig): `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. `preserve_zero`: whether to preserve zero, default is None. 
Will be set to True if zero_point_domain is ZeroPointDomain.INT
+        `packing_format`: the packing format for the int4 tensor, this is the new structure we plan to use;
+            the existing layout-based structure will use `_legacy`
     """
 
     group_size: int = 128
@@ -1116,6 +1119,9 @@
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    # since not all tensors are migrated to the new structure yet,
+    # we use `_legacy` to represent the previous layout
+    packing_format: PackingFormat = "_legacy"
 
 
 # for BC
@@ -1133,6 +1139,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    packing_format = config.packing_format
 
     if weight.shape[-1] % group_size != 0:
         logger.info(
@@ -1140,8 +1147,25 @@
         )
         return weight
 
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if packing_format == "preshuffled":
+        new_weight = Int4PreshuffledTensor.from_float(
+            weight,
+            block_size,
+            activation_dtype=torch.bfloat16,
+        )
+        return new_weight
+    elif packing_format == "plain":
+        new_weight = Int4Tensor.from_float(
+            weight,
+            block_size,
+        )
+        return new_weight
+
+    assert packing_format == "_legacy", f"Unsupported packing_format: {packing_format}"
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
@@ -1213,6 +1237,38 @@
     return module
 
 
+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    group_size: int = 128
+    packing_format: PackingFormat = "preshuffled"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _float8_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying float8 activation int4 weight quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    group_size = config.group_size
+    packing_format = config.packing_format
+
+    assert packing_format == "preshuffled", (
+        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    )
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype=torch.float8_e4m3fn,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
@@ -1484,6 +1540,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     """
 
     weight_dtype: torch.dtype = e4m3_dtype
+    packing_format: PackingFormat = "plain"
     set_inductor_config: bool = True
@@ -1493,6 +1550,10 @@ def _float8_weight_only_quant_tensor(weight, config):
     block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
+    packing_format = config.packing_format
+    assert packing_format == "plain", (
+        f"Only plain packing_format is supported, got: {packing_format}"
+    )
     new_weight = Float8Tensor.from_float(
         weight,
         config.weight_dtype,
@@ -1610,6 +1671,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     granularity: Optional[Union[FP8Granularity,
List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None activation_scale_ub: float = 1200.0 + packing_format: PackingFormat = "plain" set_inductor_config: bool = True def __post_init__(self): @@ -1631,8 +1693,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): weight_dtype = config.weight_dtype granularity = config.granularity mm_config = config.mm_config + packing_format = config.packing_format activation_scale_ub = config.activation_scale_ub + assert packing_format == "plain", ( + f"Only plain packing_format is supported, got {packing_format}" + ) + # Ensure works on device _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity @@ -1641,6 +1708,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight + if isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" @@ -2083,7 +2151,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module: activation_dtype=torch.bfloat16, ) else: - weight = to_fbgemm_int4( + weight = Int4Tensor.from_float( module.weight, config.block_size, ) diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 4190689989..c238e1bbda 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -4,8 +4,16 @@ from .int4.int4_preshuffled_tensor import ( Int4PreshuffledTensor, ) +from .int4.int4_tensor import ( + Int4Tensor, +) +from .packing_format import ( + PackingFormat, +) __all__ = [ + "Int4Tensor", "Int4PreshuffledTensor", "Float8Tensor", + "PackingFormat", ] diff --git a/torchao/quantization/quantize_/workflows/int4/__init__.py b/torchao/quantization/quantize_/workflows/int4/__init__.py index e69de29bb2..3394822214 100644 --- a/torchao/quantization/quantize_/workflows/int4/__init__.py +++ b/torchao/quantization/quantize_/workflows/int4/__init__.py @@ -0,0 +1,7 @@ +from .int4_preshuffled_tensor import Int4PreshuffledTensor +from .int4_tensor import Int4Tensor + +__all__ = [ + "Int4PreshuffledTensor", + "Int4Tensor", +] diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py index 11cd0a145a..ceeda7a6b4 100644 --- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py @@ -194,7 +194,7 @@ def from_float( if quantize_int4_preshuffle is None: raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") - assert all(x == 1 for x in block_size[:-1]), ( + assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, ( "Only groupwise quant is supported right now" ) group_size = block_size[-1] diff --git a/torchao/dtypes/fbgemm_int4_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py similarity index 74% rename from torchao/dtypes/fbgemm_int4_tensor.py rename to torchao/quantization/quantize_/workflows/int4/int4_tensor.py index 385f70e3bb..f0ea6d87ad 100644 --- a/torchao/dtypes/fbgemm_int4_tensor.py +++ b/torchao/quantization/quantize_/workflows/int4/int4_tensor.py @@ -17,8 +17,7 @@ ) __all__ = [ - "to_fbgemm_int4", - "FbgemmInt4Tensor", + "Int4Tensor", ] aten = torch.ops.aten 
@@ -31,22 +30,22 @@ pack_int4 = None -class FbgemmInt4Tensor(TorchAOBaseTensor): - tensor_data_attrs = ["packed_weight", "scale", "zero_point"] - tensor_attributes = ["group_size", "shape"] +class Int4Tensor(TorchAOBaseTensor): + tensor_data_attrs = ["_data", "scale", "zero_point"] + tensor_attributes = ["block_size", "shape"] - def __new__(cls, packed_weight, scale, zero_point, group_size, shape): + def __new__(cls, _data, scale, zero_point, block_size, shape): kwargs = {} - kwargs["device"] = packed_weight.device + kwargs["device"] = _data.device kwargs["dtype"] = scale.dtype kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - def __init__(self, packed_weight, scale, zero_point, group_size, shape): - self.packed_weight = packed_weight + def __init__(self, _data, scale, zero_point, block_size, shape): + self._data = _data self.scale = scale self.zero_point = zero_point - self.group_size = group_size + self.block_size = block_size def __tensor_flatten__(self): return self.tensor_data_attrs, [ @@ -70,21 +69,21 @@ def _apply_fn_to_data(self, fn): def __repr__(self): return ( - f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, " + f"{self.__class__.__name__}(weight={self._data}, block_size={self.block_size}, " f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): - return f"shape={self.shape}, group_size={self.group_size}, device={self.device}" + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.packed_weight.to(device), + self._data.to(device), self.scale.to(device), self.zero_point.to(device), - self.group_size, + self.block_size, self.shape, ) @@ -100,6 +99,10 @@ def from_float( if int4_row_quantize_zp is None: raise ImportError("Requires fbgemm-gpu-genai >= 1.2.0") + assert all(x == 1 for x in block_size[:-1]) and block_size[-1] != 1, ( + "Only groupwise quant is supported right now" + ) + group_size = block_size[-1] original_shape = w.shape @@ -118,16 +121,16 @@ def from_float( zero_point = zero_point.to(w.dtype) del w - return FbgemmInt4Tensor( - packed_weight=wq, + return Int4Tensor( + _data=wq, scale=scale, zero_point=zero_point, - group_size=group_size, + block_size=block_size, shape=original_shape, ) -implements = FbgemmInt4Tensor.implements +implements = Int4Tensor.implements @implements([torch.nn.functional.linear, aten.linear.default]) @@ -142,9 +145,9 @@ def _(func, types, args, kwargs): res = torch.ops.fbgemm.bf16i4bf16_rowwise( input_tensor, - weight_tensor.packed_weight.contiguous(), - weight_tensor.scale, - weight_tensor.zero_point, + weight_tensor._data.contiguous(), + weight_tensor.scale.contiguous(), + weight_tensor.zero_point.contiguous(), ) res = res.reshape(*orig_act_size[:-1], orig_out_features) if bias is not None: @@ -163,7 +166,7 @@ def _(func, types, args, kwargs): res = torch.ops.fbgemm.bf16i4bf16_rowwise_batched( input_tensor, - weight_tensor.packed_weight.contiguous(), + weight_tensor._data.contiguous(), weight_tensor.scale, weight_tensor.zero_point, ) @@ -185,15 +188,15 @@ def _(func, types, args, kwargs): ) -def _same_metadata(self: "FbgemmInt4Tensor", src: "FbgemmInt4Tensor") -> bool: +def _same_metadata(self: "Int4Tensor", src: "Int4Tensor") -> bool: return ( - isinstance(self, FbgemmInt4Tensor) - and 
isinstance(src, FbgemmInt4Tensor) + isinstance(self, Int4Tensor) + and isinstance(src, Int4Tensor) and self.shape == src.shape - and self.packed_weight.shape == src.packed_weight.shape + and self._data.shape == src._data.shape and self.scale.shape == src.scale.shape and self.zero_point.shape == src.zero_point.shape - and self.group_size == src.group_size + and self.block_size == src.block_size ) @@ -214,21 +217,21 @@ def _(func, types, args, kwargs): @implements(aten.slice.Tensor) def _(func, types, args, kwargs): """Only supports slicing for dim == 1 and dim == 2 - packed_weight has dimension: (N, K/2) + _data has dimension: (N, K/2) scale and zero_point has dimension: (K/groups, N) dim, start, end, step are args that's referring to the original tensor shape - which is (N, K), and we need to map that to the transformed weight shape of packed_weight, + which is (N, K), and we need to map that to the transformed weight shape of _data, scale and zero_point - when dim == 0: we do a slice on packed_weight dim 0, and on dim 1 of scale and zero_point, + when dim == 0: we do a slice on _data dim 0, and on dim 1 of scale and zero_point, also adjust the start and end indexes based on the ratio between original shape and the shape - of packed_weight and scale/zero_point + of _data and scale/zero_point - when dim == 1: we do a slice on packed_weight dim 1 and dim 0 of scale and zero_point and do the + when dim == 1: we do a slice on _data dim 1 and dim 0 of scale and zero_point and do the same adjustment based on ratio - Note that we need to call slice on the packed_weight, scale and zero_point directly because slice + Note that we need to call slice on the _data, scale and zero_point directly because slice is an operation that need to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_int4` for """ @@ -238,10 +241,10 @@ def _(func, types, args, kwargs): if end >= self.shape[dim]: end = self.shape[dim] - assert self.packed_weight.ndim == 2, ( - f"Expected packed weight to have dim 2, got {self.packed_weight.dim}" + assert self._data.ndim == 2, ( + f"Expected packed weight to have dim 2, got {self._data.dim}" ) - N, K_by_2 = self.packed_weight.shape + N, K_by_2 = self._data.shape sz_dim0, sz_dim1 = self.scale.shape data_len = self.shape[dim] @@ -260,10 +263,10 @@ def _(func, types, args, kwargs): args, kwargs, self.__class__( - self.packed_weight, + self._data, self.scale, self.zero_point, - group_size=self.group_size, + block_size=self.block_size, shape=self.shape, ), ) @@ -276,20 +279,19 @@ def _(func, types, args, kwargs): start_sz = int(start / sz_ratio) end_sz = int(end / sz_ratio) - packed_weight = aten.slice.Tensor(self.packed_weight, dim, start_pw, end_pw, step) + _data = aten.slice.Tensor(self._data, dim, start_pw, end_pw, step) scale = aten.slice.Tensor(self.scale, sz_dim, start_sz, end_sz, step) zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step) - packed_shape0, packed_shape1 = packed_weight.shape + packed_shape0, packed_shape1 = _data.shape new_shape = (packed_shape0, packed_shape1 * 2) new = self.__class__( - packed_weight, scale, zero_point, group_size=self.group_size, shape=new_shape + _data, scale, zero_point, block_size=self.block_size, shape=new_shape ) return return_and_correct_aliasing(func, args, kwargs, new) -to_fbgemm_int4 = FbgemmInt4Tensor.from_float - +Int4Tensor.__module__ = "torchao.quantization" if TORCH_VERSION_AT_LEAST_2_5: - # Allow a model with FbgemmInt4Tensor weights to be loaded with `weights_only=True` - 
torch.serialization.add_safe_globals([FbgemmInt4Tensor])
+    # Allow a model with Int4Tensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([Int4Tensor])
diff --git a/torchao/quantization/quantize_/workflows/packing_format.py b/torchao/quantization/quantize_/workflows/packing_format.py
new file mode 100644
index 0000000000..77817a82d4
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/packing_format.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+
+
+class PackingFormat(str, Enum):
+    PLAIN = "plain"
+    PRESHUFFLED = "preshuffled"
+    _LEGACY = "_legacy"
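
A minimal usage sketch for the configs introduced in this patch. The toy model, sizes, and device handling below are illustrative assumptions (the tests in this diff run on CUDA sm90+ with fbgemm-gpu-genai installed); only the config names and arguments come from the code above.

    import torch
    from torchao.quantization import (
        Float8ActivationInt4WeightConfig,
        Int4WeightOnlyConfig,
        quantize_,
    )

    # bfloat16 activation + int4 weight, using the new "plain" packing format
    m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")
    quantize_(m, Int4WeightOnlyConfig(group_size=128, packing_format="plain"))

    # float8 activation + int4 weight; only the "preshuffled" packing format is supported
    m_fp8 = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")
    quantize_(m_fp8, Float8ActivationInt4WeightConfig(group_size=128))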
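
Continuing the sketch above: because the patch registers Int4Tensor via torch.serialization.add_safe_globals, a quantized state dict should round-trip with weights_only=True, which is what test_serialization_bc.py exercises. The checkpoint path here is an arbitrary placeholder.

    # save and reload the quantized weights; "int4_checkpoint.pt" is a placeholder path
    torch.save(m.state_dict(), "int4_checkpoint.pt")
    state_dict = torch.load("int4_checkpoint.pt", weights_only=True)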
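
One design note on the new PackingFormat enum: since it subclasses str, its members compare equal to the plain string literals ("plain", "preshuffled", "_legacy") used in quant_api.py and in the configs, so either spelling can be passed as packing_format. A small check, assuming the import path added to workflows/__init__.py in this diff:

    from torchao.quantization.quantize_.workflows import PackingFormat

    # str-subclass enum members compare equal to the raw strings used in the configs
    assert PackingFormat.PLAIN == "plain"
    assert PackingFormat.PRESHUFFLED == "preshuffled"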