From e1009b79f43e6c6e479c40fdeb3f227ae9fc4cc5 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Mon, 7 Jul 2025 15:27:23 -0700
Subject: [PATCH] Reference representation of dqlinear int4 for xnnpack

Summary:
This diff adds dynamically quantized linear's integer arithmetic
representation. This is quite close to how the arithmetic is done in xnnpack.
Basic tests are added against q/dq to make sure things are sane.

Followups:
- See if such a graph is traceable.
- Optimize implementation if needed.

Test Plan:
added

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../pt2e/reference_representation_rewrite.py  | 227 +++++++++++++
 .../test_reference_representation_rewrite.py  | 300 ++++++++++++++++++
 2 files changed, 527 insertions(+)
 create mode 100644 torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py

diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py
index 6526c6044f..7233a8563d 100644
--- a/torchao/quantization/pt2e/reference_representation_rewrite.py
+++ b/torchao/quantization/pt2e/reference_representation_rewrite.py
@@ -24,6 +24,9 @@
     remove_tensor_overload_for_qdq_ops,
 )
 
+from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS, MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
@@ -32,6 +35,8 @@
 
 __all__ = [
     "reference_representation_rewrite",
+    "_qdq_dynamic_quantized_linear_4bit_groupwise",
+    "_reference_dynamic_quantized_linear_4bit_groupwise",
 ]
 
 
@@ -203,6 +208,185 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_quant_min,
+    x_quant_max,
+    x_eps,
+    x_scales_type,
+    x_zero_points_type,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        x_scales_type,
+        x_zero_points_type,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        weight_quant_min,
+        weight_quant_max,
+        torch.float32,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_quant_min,
+    x_quant_max,
+    x_eps,
+    x_scales_type,
+    x_zero_points_type,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    weight_quant_min,
+    weight_quant_max,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        x_scales_type,
+        x_zero_points_type,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group.
+        # However, we first scale by weight_scale_group, which accounts only for
+        # the weight's scale, and scale by x_scale at the end because x_scale
+        # applies to all groups.
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # We must also subtract x_zero_point * weight_group_col_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -738,6 +922,22 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS = (
+        torch.randn((2, 32), dtype=torch.float),  # x_fp32
+        -128,  # x_quant_min
+        127,  # x_quant_max
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.tensor([-8], dtype=torch.int),  # weight_quant_min (4-bit range)
+        torch.tensor([7], dtype=torch.int),  # weight_quant_max (4-bit range)
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -752,6 +952,33 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    -128: 1,
+                    127: 2,
+                    torch.finfo(torch.float32).eps: 3,
+                    -8: 7,
+                    7: 8,
+                    8: 10,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    -128: 1,
+                    127: 2,
+                    torch.finfo(torch.float32).eps: 3,
+                    -8: 7,
+                    7: 8,
+                    8: 10,
+                },
+            ),
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
diff --git a/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py
new file mode 100644
index 0000000000..494932e3e4
--- /dev/null
+++ b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py
@@ -0,0 +1,300 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from torchao.quantization.pt2e.reference_representation_rewrite import (
+    _qdq_dynamic_quantized_linear_4bit_groupwise,
+    _reference_dynamic_quantized_linear_4bit_groupwise,
+)
+
+
+class TestReferenceRepresentationRewrite(unittest.TestCase):
+    """Test cases for dynamically quantized linear 4-bit groupwise implementations."""
+
+    def _get_default_quantization_params(self):
+        """Get default quantization parameters."""
+        return {
+            "x_quant_min": -128,
+            "x_quant_max": 127,
+            "x_eps": torch.finfo(torch.float32).eps,
+            "x_scales_type": torch.float32,
+            "x_zero_points_type": torch.int8,
+            "weight_quant_min": -8,  # 4-bit range
+            "weight_quant_max": 7,  # 4-bit range
+        }
+
+    def _create_test_tensors(
+        self, batch_size, in_features, out_features, group_size, seed=42, bias=True
+    ):
+        """Create test tensors for the given dimensions."""
+        torch.manual_seed(seed)
+
+        # Create input activation
+        x_fp32 = torch.randn(batch_size, in_features, dtype=torch.float32)
+
+        # Create 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_i4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
+
+        # Create groupwise scales and zero points
+        num_groups = in_features // group_size
+        weight_scale = (
+            torch.randn(out_features, num_groups, dtype=torch.float32).abs() + 0.01
+        )
+        weight_zero_point = torch.zeros(
+            out_features, num_groups, dtype=torch.int8
+        )  # Symmetric quantization
+
+        # Create bias if requested
+        bias_fp32 = torch.randn(out_features, dtype=torch.float32) if bias else None
+
+        return {
+            "x_fp32": x_fp32,
+            "weight_i4": weight_i4,
+            "weight_scale": weight_scale,
+            "weight_zero_point": weight_zero_point,
+            "bias_fp32": bias_fp32,
+        }
+
+    def _run_qdq_implementation(self, tensors, quant_params, group_size):
+        """Run the QDQ implementation with given tensors and parameters."""
+        return _qdq_dynamic_quantized_linear_4bit_groupwise(
+            x_fp32=tensors["x_fp32"],
+            x_quant_min=quant_params["x_quant_min"],
+            x_quant_max=quant_params["x_quant_max"],
+            x_eps=quant_params["x_eps"],
+            x_scales_type=quant_params["x_scales_type"],
+            x_zero_points_type=quant_params["x_zero_points_type"],
+            weight_i4=tensors["weight_i4"],
+            weight_scale=tensors["weight_scale"],
+            weight_zero_point=tensors["weight_zero_point"],
+            weight_quant_min=quant_params["weight_quant_min"],
+            weight_quant_max=quant_params["weight_quant_max"],
+            bias_fp32=tensors["bias_fp32"],
+            group_size=group_size,
+        )
+
+    def _run_reference_implementation(self, tensors, quant_params, group_size):
+        """Run the reference implementation with given tensors and parameters."""
+        return _reference_dynamic_quantized_linear_4bit_groupwise(
+            x_fp32=tensors["x_fp32"],
+            x_quant_min=quant_params["x_quant_min"],
+            x_quant_max=quant_params["x_quant_max"],
+            x_eps=quant_params["x_eps"],
+            x_scales_type=quant_params["x_scales_type"],
+            x_zero_points_type=quant_params["x_zero_points_type"],
+            weight_i4=tensors["weight_i4"],
+            weight_scale=tensors["weight_scale"],
+            weight_zero_point=tensors["weight_zero_point"],
+            weight_quant_min=quant_params["weight_quant_min"],
+            weight_quant_max=quant_params["weight_quant_max"],
+            bias_fp32=tensors["bias_fp32"],
+            group_size=group_size,
+        )
+
+    def _assert_basic_properties(self, result, expected_shape):
+        """Assert basic properties of the result tensor."""
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.dtype, torch.float32)
+
+    def _assert_implementations_close(
+        self, qdq_result, ref_result, atol=5e-2, rtol=5e-2, msg_suffix=""
+    ):
+        """Assert that QDQ and reference implementations produce similar results."""
+        torch.testing.assert_close(
+            qdq_result,
+            ref_result,
+            atol=atol,
+            rtol=rtol,
+            msg=f"QDQ and reference results differ significantly{msg_suffix}",
+        )
+
+    def test_qdq_dynamic_quantized_linear_4bit_groupwise_basic(self):
+        """Test that QDQ implementation runs without errors and produces reasonable output."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 2, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        result = self._run_qdq_implementation(tensors, quant_params, group_size)
+        self._assert_basic_properties(result, (batch_size, out_features))
+
+    def test_reference_dynamic_quantized_linear_4bit_groupwise_basic(self):
+        """Test that reference implementation runs without errors and produces reasonable output."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 2, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        result = self._run_reference_implementation(tensors, quant_params, group_size)
+        self._assert_basic_properties(result, (batch_size, out_features))
+
+    def test_both_implementations_no_bias(self):
+        """Test both implementations without bias."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 1, 16, 4, 8
+        seed = 123
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size, seed=seed, bias=False
+        )
+
+        qdq_result = self._run_qdq_implementation(tensors, quant_params, group_size)
+        ref_result = self._run_reference_implementation(
+            tensors, quant_params, group_size
+        )
+
+        self._assert_basic_properties(qdq_result, (batch_size, out_features))
+        self._assert_basic_properties(ref_result, (batch_size, out_features))
+        self._assert_implementations_close(
+            qdq_result, ref_result, msg_suffix=" for no-bias case"
+        )
+
+    def test_edge_cases_group_size_validation(self):
+        """Test edge cases and error conditions."""
+        # Test-specific parameters
+        batch_size, in_features, out_features = 1, 32, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, 8
+        )  # Valid group size for tensor creation
+
+        # Test with group_size that doesn't divide in_features evenly
+        with self.assertRaises(AssertionError):
+            self._run_qdq_implementation(
+                tensors, quant_params, 7
+            )  # 32 is not divisible by 7
+
+        # Test with zero group_size
+        with self.assertRaises(AssertionError):
+            self._run_qdq_implementation(tensors, quant_params, 0)
+
+    def test_weight_dimension_validation(self):
+        """Test weight dimension validation."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 1, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        # Create 1D weight tensor (should fail)
+        tensors["weight_i4"] = torch.randint(-8, 7, (in_features,), dtype=torch.int8)
+
+        with self.assertRaises((AssertionError, IndexError)):
+            self._run_qdq_implementation(tensors, quant_params, group_size)
+
+    def test_different_group_sizes(self):
+        """Test with different valid group sizes."""
+        # Test-specific parameters
+        batch_size, in_features, out_features = 2, 64, 16
+        seed = 456
+        group_sizes = [8, 16, 32, 64]
+
+        quant_params = self._get_default_quantization_params()
+
+        for group_size in group_sizes:
+            with self.subTest(group_size=group_size):
+                tensors = self._create_test_tensors(
+                    batch_size, in_features, out_features, group_size, seed=seed
+                )
+
+                qdq_result = self._run_qdq_implementation(
+                    tensors, quant_params, group_size
+                )
+                ref_result = self._run_reference_implementation(
+                    tensors, quant_params, group_size
+                )
+
+                self._assert_basic_properties(qdq_result, (batch_size, out_features))
+                self._assert_basic_properties(ref_result, (batch_size, out_features))
+                self._assert_implementations_close(
+                    qdq_result, ref_result, msg_suffix=f" for group_size={group_size}"
+                )
+
+    def test_qdq_vs_reference_implementation_comparison(self):
+        """Test that QDQ and reference implementations produce similar results with various configurations."""
+        # Test-specific parameters
+        test_cases = [
+            (1, 32, 8, 8),
+            (2, 64, 16, 16),
+            (4, 128, 32, 32),
+        ]
+
+        quant_params = self._get_default_quantization_params()
+
+        for batch_size, in_features, out_features, group_size in test_cases:
+            with self.subTest(
+                batch_size=batch_size,
+                in_features=in_features,
+                out_features=out_features,
+                group_size=group_size,
+            ):
+                seed = 42 + batch_size + in_features  # Deterministic but varied seed
+
+                # Test with bias
+                tensors_with_bias = self._create_test_tensors(
+                    batch_size,
+                    in_features,
+                    out_features,
+                    group_size,
+                    seed=seed,
+                    bias=True,
+                )
+
+                qdq_result = self._run_qdq_implementation(
+                    tensors_with_bias, quant_params, group_size
+                )
+                ref_result = self._run_reference_implementation(
+                    tensors_with_bias, quant_params, group_size
+                )
+
+                self.assertEqual(qdq_result.shape, ref_result.shape)
+                self.assertEqual(qdq_result.shape, (batch_size, out_features))
+
+                self._assert_implementations_close(
+                    qdq_result,
+                    ref_result,
+                    msg_suffix=f" for shape ({batch_size}, {in_features}, {out_features}) with group_size={group_size}",
+                )
+
+                # Test without bias
+                tensors_no_bias = self._create_test_tensors(
+                    batch_size,
+                    in_features,
+                    out_features,
+                    group_size,
+                    seed=seed,
+                    bias=False,
+                )
+
+                qdq_result_no_bias = self._run_qdq_implementation(
+                    tensors_no_bias, quant_params, group_size
+                )
+                ref_result_no_bias = self._run_reference_implementation(
+                    tensors_no_bias, quant_params, group_size
+                )
+
+                self._assert_implementations_close(
+                    qdq_result_no_bias,
+                    ref_result_no_bias,
+                    msg_suffix=f" for no-bias case with shape ({batch_size}, {in_features}, {out_features}) and group_size={group_size}",
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()