diff --git a/torchao/quantization/pt2e/reference_representation_rewrite.py b/torchao/quantization/pt2e/reference_representation_rewrite.py
index 8d1875bfd9..5dad11613b 100644
--- a/torchao/quantization/pt2e/reference_representation_rewrite.py
+++ b/torchao/quantization/pt2e/reference_representation_rewrite.py
@@ -8,12 +8,13 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torch.fx import GraphModule
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
 from torchao.quantization.pt2e.export_utils import WrapperModule
@@ -23,12 +24,17 @@
     _replace_literals_with_new_placeholders,
     remove_tensor_overload_for_qdq_ops,
 )
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
 
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
@@ -203,6 +209,252 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert (
+        weight_i4.shape[1] % group_size == 0
+    ), "Weight in_features must be divisible by group_size"
+    assert weight_i4.dim() == 2, "Weight must be a 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used since weight is assumed to be symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates the actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Block size for groupwise quantization; group_size[1] is the group width
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # Scales in XNNPACK are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    x_orig_shape = x_i8.shape
+    k_dim = x_i8.shape[-1]
+    x_i8 = x_i8.view(-1, k_dim)
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+    out_shape = list(x_orig_shape)
+    out_shape[-1] = out_features
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # The output has to be scaled by x_scale * weight_scale_group.
+        # Scale by weight_scale_group here (it is per group), and apply x_scale
+        # once at the end, since x_scale is shared across all groups.
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # We must also subtract x_zero_point * weight_group_col_sum,
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32.view(out_shape)
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used since weight is assumed to be symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
+def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
+    match,
+    original_graph,
+    pattern_graph,
+) -> bool:
+    weight_is_int4 = False
+    act_quant_is_int8 = False
+    for node in match.nodes_map.values():
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.dequantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 7:
+                weight_is_int4 = args[5] == -8 and args[6] == 7
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.quantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 5:
+                act_quant_is_int8 = args[4] == torch.int8
+    return weight_is_int4 and act_quant_is_int8
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -627,6 +879,9 @@ class _RewriteInfo:
     # post transformation on the exported pattern and replacement GraphModule
     pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+    filter_fn: Optional[
+        list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
+    ] = None
     ignore_literals: bool = False
@@ -739,6 +994,31 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1 = (
+        torch.randn((1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
+    # A second example input with a 3D activation is needed so the pattern also matches inputs with more than 2 dims.
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 = (
+        torch.randn((1, 1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_dynamic_quantized_linear),
             WrapperModule(_reference_dynamic_quantized_linear),
             partial(
                 _replace_literals_with_existing_placeholders,
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
@@ -753,6 +1033,48 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
@@ -835,7 +1157,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
             model,
             pattern,
             replacement,
-            match_filters=None,
+            match_filters=rewrite_info.filter_fn,
             ignore_literals=rewrite_info.ignore_literals,
         )  # type: ignore[arg-type]
 
diff --git a/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py
new file mode 100644
index 0000000000..91b13144b5
--- /dev/null
+++ b/torchao/quantization/pt2e/tests/test_reference_representation_rewrite.py
@@ -0,0 +1,438 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+
+from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+from torchao.quantization.pt2e.reference_representation_rewrite import (
+    _qdq_dynamic_quantized_linear_4bit_groupwise,
+    _reference_dynamic_quantized_linear_4bit_groupwise,
+    reference_representation_rewrite,
+)
+from torchao.utils import unwrap_tensor_subclass
+
+
+class TestReferenceRepresentationRewrite(unittest.TestCase):
+    """Test cases for dynamically quantized linear 4-bit groupwise implementations."""
+
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        # Fix the seed so results are deterministic. The purpose of these tests
+        # is to catch gross regressions rather than exact numerics, so a
+        # tolerance of 1e-1 is acceptable for now.
+        torch.manual_seed(78)
+
+    def _get_default_quantization_params(self):
+        """Get default quantization parameters."""
+        return {
+            "x_eps": torch.finfo(torch.float32).eps,
+        }
+
+    def _create_test_tensors(
+        self, batch_size, in_features, out_features, group_size, bias=True
+    ):
+        """Create test tensors for the given dimensions."""
+
+        # Create input activation
+        x_fp32 = torch.randn(batch_size, in_features, dtype=torch.float32)
+
+        # Create 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_i4 = torch.randint(-8, 7, (out_features, in_features), dtype=torch.int8)
+
+        # Create groupwise scales and zero points
+        num_groups = in_features // group_size
+        weight_scale = (
+            torch.randn(out_features, num_groups, dtype=torch.float32).abs() + 0.01
+        )
+        weight_zero_point = torch.zeros(
+            out_features, num_groups, dtype=torch.int8
+        )  # Symmetric quantization
+
+        # Create bias if requested
+        bias_fp32 = torch.randn(out_features, dtype=torch.float32) if bias else None
+
+        return {
+            "x_fp32": x_fp32,
+            "weight_i4": weight_i4,
+            "weight_scale": weight_scale,
+            "weight_zero_point": weight_zero_point,
+            "bias_fp32": bias_fp32,
+        }
+
+    def _run_qdq_implementation(self, tensors, quant_params, group_size):
+        """Run the QDQ implementation with given tensors and parameters."""
+        return _qdq_dynamic_quantized_linear_4bit_groupwise(
+            x_fp32=tensors["x_fp32"],
+            x_eps=quant_params["x_eps"],
+            weight_i4=tensors["weight_i4"],
+            weight_scale=tensors["weight_scale"],
+            weight_zero_point=tensors["weight_zero_point"],
+            bias_fp32=tensors["bias_fp32"],
+            group_size=group_size,
+        )
+
+    def _run_reference_implementation(self, tensors, quant_params, group_size):
+        """Run the reference implementation with given tensors and parameters."""
+        return _reference_dynamic_quantized_linear_4bit_groupwise(
+            x_fp32=tensors["x_fp32"],
+            x_eps=quant_params["x_eps"],
+            weight_i4=tensors["weight_i4"],
+            weight_scale=tensors["weight_scale"],
+            weight_zero_point=tensors["weight_zero_point"],
+            bias_fp32=tensors["bias_fp32"],
+            group_size=group_size,
+        )
+
+    def _assert_basic_properties(self, result, expected_shape):
+        """Assert basic properties of the result tensor."""
+        self.assertEqual(result.shape, expected_shape)
+        self.assertEqual(result.dtype, torch.float32)
+
+    def _assert_implementations_close(
+        self, qdq_result, ref_result, atol=1e-1, rtol=1e-1, msg_suffix=""
+    ):
+        """Assert that QDQ and reference implementations produce similar results."""
+        torch.testing.assert_close(
+            qdq_result,
+            ref_result,
+            atol=atol,
+            rtol=rtol,
+            msg=f"QDQ and reference results differ significantly{msg_suffix}",
+        )
+
+    def test_qdq_dynamic_quantized_linear_4bit_groupwise_basic(self):
+        """Test that QDQ implementation runs without errors and produces
+        reasonable output."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 2, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        result = self._run_qdq_implementation(tensors, quant_params, group_size)
+        self._assert_basic_properties(result, (batch_size, out_features))
+
+    def test_reference_dynamic_quantized_linear_4bit_groupwise_basic(self):
+        """Test that reference implementation runs without errors and produces reasonable output."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 2, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        result = self._run_reference_implementation(tensors, quant_params, group_size)
+        self._assert_basic_properties(result, (batch_size, out_features))
+
+    def test_both_implementations_no_bias(self):
+        """Test both implementations without bias."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 1, 16, 4, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size, bias=False
+        )
+
+        qdq_result = self._run_qdq_implementation(tensors, quant_params, group_size)
+        ref_result = self._run_reference_implementation(
+            tensors, quant_params, group_size
+        )
+
+        self._assert_basic_properties(qdq_result, (batch_size, out_features))
+        self._assert_basic_properties(ref_result, (batch_size, out_features))
+        self._assert_implementations_close(
+            qdq_result, ref_result, msg_suffix=" for no-bias case"
+        )
+
+    def test_edge_cases_group_size_validation(self):
+        """Test edge cases and error conditions."""
+        # Test-specific parameters
+        batch_size, in_features, out_features = 1, 32, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, 8
+        )  # Valid group size for tensor creation
+
+        # Test with group_size that doesn't divide in_features evenly
+        with self.assertRaises(AssertionError):
+            self._run_qdq_implementation(
+                tensors, quant_params, 7
+            )  # 32 is not divisible by 7
+
+        # Test with zero group_size
+        with self.assertRaises(AssertionError):
+            self._run_qdq_implementation(tensors, quant_params, 0)
+
+    def test_weight_dimension_validation(self):
+        """Test weight dimension validation."""
+        # Test-specific parameters
+        batch_size, in_features, out_features, group_size = 1, 32, 8, 8
+
+        quant_params = self._get_default_quantization_params()
+        tensors = self._create_test_tensors(
+            batch_size, in_features, out_features, group_size
+        )
+
+        # Create 1D weight tensor (should fail)
+        tensors["weight_i4"] = torch.randint(-8, 7, (in_features,), dtype=torch.int8)
+
+        with self.assertRaises((AssertionError, IndexError)):
+            self._run_qdq_implementation(tensors, quant_params, group_size)
+
+    def test_different_group_sizes(self):
+        """Test with different valid group sizes."""
+        # Test-specific parameters
+        batch_size, in_features, out_features = 2, 64, 16
+        group_sizes = [8, 16, 32]
+
+        quant_params = self._get_default_quantization_params()
+
+        for group_size in group_sizes:
+            with self.subTest(group_size=group_size):
+                tensors = self._create_test_tensors(
+                    batch_size, in_features, out_features, group_size
+                )
+
+                qdq_result = self._run_qdq_implementation(
+                    tensors, quant_params, group_size
+                )
+                ref_result = self._run_reference_implementation(
+                    tensors, quant_params, group_size
+                )
+
+                self._assert_basic_properties(qdq_result, (batch_size, out_features))
+                self._assert_basic_properties(ref_result, (batch_size, out_features))
+                self._assert_implementations_close(
+                    qdq_result, ref_result, msg_suffix=f" for group_size={group_size}"
+                )
+
+    def test_qdq_vs_reference_implementation_comparison(self):
+        """Test that QDQ and reference implementations produce similar results with various configurations."""
+        # Test-specific parameters
+        test_cases = [
+            (1, 32, 8, 8),
+            (2, 64, 16, 16),
+            (4, 128, 32, 32),
+        ]
+
+        quant_params = self._get_default_quantization_params()
+
+        for batch_size, in_features, out_features, group_size in test_cases:
+            with self.subTest(
+                batch_size=batch_size,
+                in_features=in_features,
+                out_features=out_features,
+                group_size=group_size,
+            ):
+                # Test with bias
+                tensors_with_bias = self._create_test_tensors(
+                    batch_size,
+                    in_features,
+                    out_features,
+                    group_size,
+                    bias=True,
+                )
+
+                qdq_result = self._run_qdq_implementation(
+                    tensors_with_bias, quant_params, group_size
+                )
+                ref_result = self._run_reference_implementation(
+                    tensors_with_bias, quant_params, group_size
+                )
+
+                self.assertEqual(qdq_result.shape, ref_result.shape)
+                self.assertEqual(qdq_result.shape, (batch_size, out_features))
+
+                self._assert_implementations_close(
+                    qdq_result,
+                    ref_result,
+                    msg_suffix=f" for shape ({batch_size}, {in_features}, {out_features}) with group_size={group_size}",
+                )
+
+                # Test without bias
+                tensors_no_bias = self._create_test_tensors(
+                    batch_size,
+                    in_features,
+                    out_features,
+                    group_size,
+                    bias=False,
+                )
+
+                qdq_result_no_bias = self._run_qdq_implementation(
+                    tensors_no_bias, quant_params, group_size
+                )
+                ref_result_no_bias = self._run_reference_implementation(
+                    tensors_no_bias, quant_params, group_size
+                )
+
+                self._assert_implementations_close(
+                    qdq_result_no_bias,
+                    ref_result_no_bias,
+                    msg_suffix=f" for no-bias case with shape ({batch_size}, {in_features}, {out_features}) and group_size={group_size}",
+                )
+
+
+class SimpleLinearModel(nn.Module):
+    """Simple model with linear layers for testing model rewrite functionality."""
+
+    def __init__(self, input_size=128, hidden_size=64, output_size=32):
+        super().__init__()
+        self.linear1 = nn.Linear(input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+
+class TestModelRewrite(unittest.TestCase):
+    """Test cases for model rewrite functionality with 8da4w quantization."""
+
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        torch.manual_seed(42)
+
+    def test_export_and_rewrite_workflow(self):
+        """Test the complete export and rewrite workflow."""
+        # Create model
+        model = SimpleLinearModel(input_size=64, hidden_size=32, output_size=16)
+        example_input = torch.randn(1, 64)
+
+        # Apply 8da4w quantization
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+
+        # Unwrap tensor subclasses for export compatibility
+        model = unwrap_tensor_subclass(model)
+
+        # Export model
+        exported_model = torch.export.export(model, (example_input,))
+
+        # Check that export was successful
+        self.assertIsNotNone(exported_model)
+        self.assertTrue(hasattr(exported_model, "graph_module"))
+
+        # Test the exported model
+        with torch.no_grad():
+            original_output = exported_model.module()(example_input)
+
+        # Create a copy for rewriting
+        rewritten_model = copy.deepcopy(exported_model)
+
+        # Apply reference representation rewrite
+        reference_representation_rewrite(rewritten_model.graph_module)
+
+        # Test the rewritten model
+        with torch.no_grad():
+            rewritten_output = rewritten_model.module()(example_input)
+
+        # Check that outputs are close
+        self.assertEqual(original_output.shape, rewritten_output.shape)
+        self.assertEqual(original_output.dtype, rewritten_output.dtype)
+
+        # The outputs should be close (allowing for some numerical differences)
+        torch.testing.assert_close(
+            original_output, rewritten_output, atol=5e-2, rtol=5e-2
+        )
+
+    def test_different_group_sizes_rewrite(self):
+        """Test rewrite functionality with different group sizes."""
+        group_sizes = [16, 32, 64]
+
+        for group_size in group_sizes:
+            with self.subTest(group_size=group_size):
+                # Create model
+                model = SimpleLinearModel(input_size=64, hidden_size=32, output_size=16)
+                example_input = torch.randn(1, 2, 64)
+
+                # Apply quantization with specific group size
+                quantize_(
+                    model, int8_dynamic_activation_int4_weight(group_size=group_size)
+                )
+
+                # Unwrap tensor subclasses for export compatibility
+                model = unwrap_tensor_subclass(model)
+
+                # Export and test rewrite
+                exported_model = torch.export.export(model, (example_input,))
+
+                # Test the exported model
+                with torch.no_grad():
+                    original_output = exported_model.module()(example_input)
+
+                # Create a copy for rewriting
+                rewritten_model = copy.deepcopy(exported_model)
+
+                # Apply reference representation rewrite
+                reference_representation_rewrite(rewritten_model.graph_module)
+
+                # Test the rewritten model
+                with torch.no_grad():
+                    rewritten_output = rewritten_model.module()(example_input)
+
+                # The outputs should be close (allowing for some numerical differences)
+                torch.testing.assert_close(
+                    original_output,
+                    rewritten_output,
+                    atol=5e-2,
+                    rtol=5e-2,
+                    msg=f"Rewrite failed for group_size={group_size}",
+                )
+
+    def test_model_without_bias_rewrite(self):
+        """Test rewrite functionality with linear layers that have no bias."""
+        # Create model without bias
+        model = SimpleLinearModel(input_size=32, hidden_size=16, output_size=8)
+        model.linear1.bias = None
+        model.linear2.bias = None
+
+        example_input = torch.randn(1, 32)
+
+        # Apply quantization
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=16))
+
+        # Unwrap tensor subclasses for export compatibility
+        model = unwrap_tensor_subclass(model)
+
+        # Export and test rewrite
+        exported_model = torch.export.export(model, (example_input,))
+
+        # Test the exported model
+        with torch.no_grad():
+            original_output = exported_model.module()(example_input)
+
+        # Create a copy for rewriting
+        rewritten_model = copy.deepcopy(exported_model)
+
+        # Apply reference representation rewrite
+        reference_representation_rewrite(rewritten_model.graph_module)
+
+        # Test the rewritten model
+        with torch.no_grad():
+            rewritten_output = rewritten_model.module()(example_input)
+
+        # The outputs should be close (allowing for some numerical differences)
+        torch.testing.assert_close(
+            original_output,
+            rewritten_output,
+            atol=5e-2,
+            rtol=5e-2,
+            msg="Rewrite failed for model without bias",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
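
A small standalone sketch, not part of the patch above, that numerically checks the zero-point correction identity the reference kernel's comments rely on, (X - zp) @ W.T == X @ W.T - zp * col_sum(W); all names and shapes here are illustrative only.

import torch

# Illustrative shapes: 4 tokens, 16 input features, 8 output features.
x = torch.randint(-128, 128, (4, 16), dtype=torch.int32)  # quantized activations
zp = torch.randint(-10, 10, (4, 1), dtype=torch.int32)    # per-token zero points
w = torch.randint(-8, 8, (8, 16), dtype=torch.int32)      # int4-range weights stored as int32

lhs = (x - zp) @ w.t()                                     # subtract the zero point first
rhs = x @ w.t() - zp * w.sum(dim=1).view(1, -1)            # correct with weight column sums instead
assert torch.equal(lhs, rhs)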