From a0f18c874566d8f00edd60f870db47e2ba1a4093 Mon Sep 17 00:00:00 2001
From: morelos
Date: Wed, 16 Jul 2025 17:43:06 -0700
Subject: [PATCH] [ET-VK] linear_qta8a_qga4w graph pass

# Changes
* Introduce the `linear_qta8a_qga4w` custom operator in `custom_ops_lib.py` to handle dynamic activation + grouped weight quantized linear operations
* Add pattern matching and fusion logic to `FuseQuantizedOpsTransform` to detect dequant + dequant + linear sequences and replace them with the new fused operator
* Add test coverage in `test_vulkan_passes.py` that validates the QTA8A_QGA4W fusion pattern
* Add 4-bit weight packing utilities and grouped quantization support for efficient memory usage

# Motivation
The existing quantization workflow in the Vulkan backend processes dynamic activation + grouped weight quantized linear operations as separate quantize/dequantize/linear steps, which creates performance overhead through:
* Multiple kernel dispatches instead of a single fused operation
* Intermediate tensor allocations for dequantized weights and activations
* Suboptimal memory bandwidth utilization

The new `linear_qta8a_qga4w` operator fuses the entire sequence into a single operation that:
* Directly processes 8-bit quantized activations with per-token scales/zero-points
* Handles 4-bit grouped quantized weights with configurable group sizes
* Eliminates intermediate dequantization steps by performing dequantization inline
* Reduces memory footprint through packed 4-bit weight storage

This aligns with the broader goal of optimizing quantized model inference in the Vulkan backend by leveraging graph-level transformations to improve computational efficiency while maintaining numerical accuracy.

Differential Revision: [D78291269](https://our.internmc.facebook.com/intern/diff/D78291269/)

[ghstack-poisoned]
---
 backends/vulkan/_passes/fuse_quantized_ops.py | 205 ++++++++++++++++++
 backends/vulkan/custom_ops_lib.py             |  90 ++++++++
 backends/vulkan/op_registry.py                |   7 +-
 backends/vulkan/test/TARGETS                  |   1 +
 backends/vulkan/test/test_vulkan_passes.py    |  54 +++++
 backends/vulkan/utils.py                      |  15 ++
 6 files changed, 371 insertions(+), 1 deletion(-)

diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py
index 805a5c1f744..94fe454ff21 100644
--- a/backends/vulkan/_passes/fuse_quantized_ops.py
+++ b/backends/vulkan/_passes/fuse_quantized_ops.py
@@ -210,6 +210,200 @@ def fuse_into_linear_qcnw_node(
     graph_module.graph.erase_node(dq_weight_node)
 
 
+#########################
+## linear_qta8a_qga4w ##
+#########################
+
+
+def matches_linear_qta8a_qga4w_pattern(
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[int, int]]:
+    """
+    Checks whether the nodes surrounding a linear node match the pattern for dynamic
+    activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    This pattern involves:
+    1. Dynamic quantization of input activations (8-bit)
+    2. Grouped quantization of weights (4-bit with group size)
+
+    The expected pattern from Int8DynActInt4WeightQuantizer is:
+        scale, zero_point = choose_qparams_affine(input)
+        quantized_input = quantize_affine(input, scale, zero_point)
+        dequantized_input = dequantize_affine(quantized_input, ...)
+        dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros)
+        output = linear(dequantized_input, dequantized_weight)
+
+    If the pattern matches, return (group_size, weight_bits); otherwise return None.
+ """ + if not utils.is_linear_node(node): + return None + + input_node = node.args[0] + weight_node = node.args[1] + + # Type checking - ensure we have torch.fx.Node objects + if not isinstance(weight_node, torch.fx.Node): + return None + if not isinstance(input_node, torch.fx.Node): + return None + + # Check if input is dequantized with dequantize_affine (from dynamic quantization) + if not ( + input_node.op == "call_function" + and input_node.target is not None + and hasattr(input_node.target, "__name__") + and "dequantize_affine" in getattr(input_node.target, "__name__", "") + ): + return None + + # Check if weight is dequantized with dequantize_affine + if not ( + weight_node.op == "call_function" + and weight_node.target is not None + and hasattr(weight_node.target, "__name__") + and "dequantize_affine" in getattr(weight_node.target, "__name__", "") + ): + return None + + # Get the original quantized weight and quantization parameters + if len(weight_node.args) < 4: + return None + + orig_weight = weight_node.args[0] + weight_scales = weight_node.args[2] + weight_zeros = weight_node.args[3] + + # Type checking + if not isinstance(orig_weight, torch.fx.Node): + return None + if not is_param_node(program, orig_weight): + return None + if not isinstance(weight_scales, torch.fx.Node): + return None + if not is_param_node(program, weight_scales): + return None + if not isinstance(weight_zeros, torch.fx.Node): + return None + if not is_param_node(program, weight_zeros): + return None + + # Get tensors to analyze the quantization scheme + orig_weight_tensor = get_param_tensor(program, orig_weight) + weight_scales_tensor = get_param_tensor(program, weight_scales) + weight_zeros_tensor = get_param_tensor(program, weight_zeros) + + if not isinstance(orig_weight_tensor, torch.Tensor): + return None + if not isinstance(weight_scales_tensor, torch.Tensor): + return None + if not isinstance(weight_zeros_tensor, torch.Tensor): + return None + + # Check if weight is quantized to 4 bits (values should be in [-8, 7] range) + quant_min = orig_weight_tensor.min().item() + quant_max = orig_weight_tensor.max().item() + + if not (quant_min >= -8 and quant_max <= 7): + return None + + # Determine group size from the scales tensor shape + # For grouped quantization, scales shape should be [out_features, in_features // group_size] + out_features, in_features = orig_weight_tensor.shape + + if len(weight_scales_tensor.shape) != 2: + return None + + scales_out_features, num_groups = weight_scales_tensor.shape + + if scales_out_features != out_features: + return None + + group_size = in_features // num_groups + if in_features % group_size != 0: + return None + + # Verify this is 4-bit grouped quantization + weight_bits = 4 + + return group_size, weight_bits + + +def fuse_into_linear_qta8a_qga4w_node( + program: ExportedProgram, + graph_module: torch.fx.GraphModule, + linear_node: torch.fx.Node, + group_size: int, + weight_bits: int, +) -> None: + """ + Fuse the dynamic activation + grouped weight quantized linear pattern into + a single linear_qta8a_qga4w operator. + + The pattern: + dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...) + dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...) 
+        output = linear(dequantized_input, dequantized_weight)
+
+    Becomes:
+        output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point,
+                                    weight, group_size, weight_scales, weight_zeros)
+    """
+    dq_input_node = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+
+    assert isinstance(dq_input_node, torch.fx.Node)
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    # Get the quantized input and quantization parameters from the input dequantize_affine node
+    # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype)
+    quantized_input = dq_input_node.args[0]
+    input_scale = dq_input_node.args[2]  # scale is the 3rd argument
+    input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None
+
+    # Get the weight and its quantization parameters from dequantize_affine
+    # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype)
+    orig_weight = dq_weight_node.args[0]
+    weight_scales = dq_weight_node.args[2]
+    weight_zeros = dq_weight_node.args[3]
+
+    # Pack the 4-bit weight tensor for efficient storage
+    assert isinstance(orig_weight, torch.fx.Node)
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    assert isinstance(orig_weight_tensor, torch.Tensor)
+    packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
+    utils.update_program_state_dict(
+        program,
+        orig_weight.name,
+        packed_weight_tensor,
+    )
+    # Update the metadata to reflect the new packed shape
+    orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
+
+    # Create the linear_qta8a_qga4w node
+    with graph_module.graph.inserting_before(linear_node):
+        linear_qta8a_qga4w_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+            (
+                quantized_input,  # quantized input (int8)
+                input_scale,  # mat1_scale
+                input_zero_point,  # mat1_zero_point
+                orig_weight,  # mat2_data (packed 4-bit weights)
+                group_size,  # group_size (int)
+                weight_scales,  # weight_scales
+                weight_zeros,  # weight_zeros
+            ),
+        )
+
+    # Replace the linear node with the new fused node
+    linear_node.replace_all_uses_with(linear_qta8a_qga4w_node)
+
+    # Erase nodes in the correct order (users first, then dependencies)
+    graph_module.graph.erase_node(linear_node)
+    graph_module.graph.erase_node(dq_weight_node)
+    graph_module.graph.erase_node(dq_input_node)
+
+
 class FuseQuantizedOpsTransform(ExportPass):
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
@@ -217,12 +411,23 @@ def __init__(self, exported_program: ExportedProgram) -> None:
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
+            # Check for linear_qcnw pattern (weight-only quantization)
             qcnw_details = matches_linear_qcnw_pattern(self.program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
                     self.program, graph_module, node, qcnw_method, qcnw_nbits
                 )
+                continue
+
+            # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization)
+            qta8a_qga4w_details = matches_linear_qta8a_qga4w_pattern(self.program, node)
+            if qta8a_qga4w_details is not None:
+                group_size, weight_bits = qta8a_qga4w_details
+                fuse_into_linear_qta8a_qga4w_node(
+                    self.program, graph_module, node, group_size, weight_bits
+                )
+                continue
 
         graph_module.recompile()
         dead_code_elimination_pass(graph_module)
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index af6fcbfbb14..ebbd2bdc6a7 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -231,6 +231,96 @@ def linear_qcs4w(
 lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
 linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)
 
+########################
+## linear_qta8a_qga4w ##
+########################
+
+
+def linear_qta8a_qga4w(
+    x_quantized: torch.Tensor,
+    input_scale: torch.Tensor,
+    input_zero_point: torch.Tensor,
+    weights_4bit: torch.Tensor,
+    group_size: int,
+    weight_scales: torch.Tensor,
+    weight_zeros: torch.Tensor,
+):
+    """
+    Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).
+
+    Args:
+        x_quantized: Already quantized input tensor (int8, per-token quantized)
+        input_scale: Scale for per-token quantization of input (shape: [batch_size])
+        input_zero_point: Zero point for per-token quantization of input (shape: [batch_size])
+        weights_4bit: Packed 4-bit quantized weights
+        group_size: Group size for weight quantization (int)
+        weight_scales: Per-group scales for weights
+        weight_zeros: Per-group zero points for weights
+    """
+    original_x_shape = x_quantized.shape
+    batch_size = original_x_shape[0]
+    feature_dim = original_x_shape[-1]
+
+    # Reshape for processing
+    x_quantized_2d = x_quantized.reshape(-1, feature_dim)
+
+    # Unpack 4-bit weights
+    unpacked_weights_shape = weights_4bit.shape
+    out_features = unpacked_weights_shape[0]
+    in_features = unpacked_weights_shape[1]
+
+    weights_unpacked = torch.empty(
+        (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device
+    )
+
+    weights_unpacked[:, ::2] = weights_4bit >> 4
+    weights_unpacked[:, 1::2] = weights_4bit & 0x0F
+
+    # Convert to signed 4-bit range [-8, 7]
+    weights_unpacked = torch.where(
+        weights_unpacked > 7, weights_unpacked - 16, weights_unpacked
+    )
+
+    # Dequantize weights using grouped quantization
+    actual_in_features = in_features * 2
+    num_groups = actual_in_features // group_size
+
+    # Reshape weights for grouped dequantization
+    weights_grouped = weights_unpacked.view(out_features, num_groups, group_size)
+
+    # Expand scales and zeros to match grouped weights
+    scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size)
+    zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size)
+
+    # Dequantize: (quantized - zero_point) * scale
+    dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded
+    dq_weights = dq_weights_grouped.view(out_features, actual_in_features)
+
+    # Dequantize input (per-token)
+    # For per-token quantization, each token (row) has its own scale and zero_point
+    x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token(
+        x_quantized_2d,
+        input_scale,
+        input_zero_point,
+        -128,
+        127,
+        torch.int8,
+        torch.float32,
+    )
+
+    # Perform linear operation
+    out = torch.nn.functional.linear(x_dequantized, dq_weights)
+    out_shape = original_x_shape[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_qta8a_qga4w"
+lib.define(
+    f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor"
+)
+lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd")
+linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 619b5f24b82..badb5588519 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -487,7 +487,12 @@ def register_int8_mm_op(features: OpFeatures):
     return features
 
 
-@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
+@update_features(
+    [
+        exir_ops.edge.et_vk.linear_weight_int4.default,
+        exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
+    ]
+)
 def register_int4_mm_op(features: OpFeatures):
     features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS
index 8f07040d586..6693d15f344 100644
--- a/backends/vulkan/test/TARGETS
+++ b/backends/vulkan/test/TARGETS
@@ -34,6 +34,7 @@ python_unittest(
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
         "//executorch/backends/vulkan:vulkan_preprocess",
+        "//pytorch/ao:torchao",  # @manual
     ]
 )
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py
index 0de1fd5d69a..4f54bc638ba 100644
--- a/backends/vulkan/test/test_vulkan_passes.py
+++ b/backends/vulkan/test/test_vulkan_passes.py
@@ -16,6 +16,7 @@
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
+from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightQuantizer
 
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import Quantizer
@@ -153,3 +154,56 @@ def test_fuse_linear_qcs4w(self):
 
         self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1)
         self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
+
+    def test_fuse_linear_qta8a_qga4w(self):
+        """Test fusion of dynamic activation + grouped weight quantized linear (QTA8A_QGA4W)."""
+        K = 256
+        N = 256
+        model = SingleLinearModule(K, N)
+        sample_inputs = model.get_sample_inputs()
+
+        # Use source transform quantizer for dynamic activation + grouped weight quantization
+        quantizer = Int8DynActInt4WeightQuantizer(
+            groupsize=128,  # Group size for 4-bit weights
+            padding_allowed=False,
+            precision=torch.float32,
+            scales_precision=torch.float32,
+            device=torch.device("cpu"),
+        )
+
+        # Apply source transform quantization
+        quantized_model = quantizer.quantize(model)
+
+        # Export the quantized model
+        edge_compile_config = EdgeCompileConfig(
+            _skip_dim_order=False,
+            _check_ir_validity=False,
+        )
+
+        program = torch.export.export_for_training(
+            quantized_model, sample_inputs, strict=True
+        ).module()
+
+        program = torch.export.export(program, sample_inputs)
+
+        edge_manager = to_edge(
+            program,
+            compile_config=edge_compile_config,
+        )
+
+        ep = edge_manager._edge_programs["forward"]
+        edge_manager.transform(
+            [
+                AddmmToLinearTransform(),
+                FuseQuantizedOpsTransform(ep),
+            ]
+        )
+
+        gm = ep.graph_module
+
+        # Check that the linear_qta8a_qga4w operator was created
+        self.assertEqual(op_node_count(gm, "linear_qta8a_qga4w.default"), 1)
+        # Check that the original quantization/dequantization nodes were removed
+        self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0)
+        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
+        self.assertEqual(op_node_count(gm, "linear.default"), 0)
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index d71c0a35776..9086b2d0792 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -38,6 +38,14 @@
     "dequantize_affine.default",
 }
 
+_Q_OPS = {
+    "quantize_per_tensor.tensor",
+    "quantize_per_tensor.default",
+    "quantize_per_channel.default",
"quantize_per_token.default", + "quantize_affine.default", +} + ## ## Node type determination ## @@ -50,6 +58,13 @@ def is_dequant_node(node: torch.fx.Node) -> bool: return node_name in _DQ_OPS +def is_quant_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name in _Q_OPS + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False