From 481a8271544255529dc1af2137b6a269f0976388 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Wed, 7 May 2025 22:31:42 +0100 Subject: [PATCH] Arm backend: Replace Tensor_Scalar with Tensor_Tensor in sqrt op Signed-off-by: Elena Zhelezina Change-Id: Ia3e596f855ec97b0ad59161bccc906b13e96c770 --- .../annotate_channels_last_dim_order_pass.py | 2 + backends/arm/_passes/decompose_sqrt_pass.py | 66 +++++++++++++------ backends/arm/_passes/insert_table_ops.py | 18 ++++- .../tosa_supported_operators.py | 4 +- .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/test/ops/test_sqrt.py | 4 +- backends/arm/tosa_partitioner.py | 7 ++ 7 files changed, 77 insertions(+), 25 deletions(-) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 029dc421920..f887fdbab56 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -61,6 +61,8 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): """ if node.op == "placeholder": # node is an input, weight or bias node + if not node.users: + return False consumer_node = list(node.users)[0] if self.is_weight_node_for_depthwise_conv2d(consumer_node): return True diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index d4a678affea..75cfb6caeee 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -4,36 +4,62 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe + +from typing import Any, Dict, Tuple + import torch + from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass -edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,) -aten_sqrt_ops = ( - torch.ops.aten.sqrt.default, - torch.ops.aten.sqrt_.default, -) +class DecomposeSqrtPass(ExportPass): + def __init__(self) -> None: + super().__init__() + + # We cache constant tensor for the exponent + self._half_cache: Dict[Tuple[Any, Any], Any] = {} + self.SQRT_TO_POW = { + exir_ops.edge.aten.sqrt.default: exir_ops.edge.aten.pow.Tensor_Tensor, + torch.ops.aten.sqrt.default: torch.ops.aten.pow.Tensor_Tensor, + torch.ops.aten.sqrt_.default: torch.ops.aten.pow.Tensor_Tensor, + } -def get_sqrt_decomposition(op) -> tuple: - # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor" - if op in edge_sqrt_ops: - return exir_ops.edge.aten.pow.Tensor_Scalar - if op in aten_sqrt_ops: - return torch.ops.aten.pow.Tensor_Scalar - raise RuntimeError(f"Can't get sqrt decomposition for op {op}") + def _get_half_tensor( + self, + dtype: Any, + device: Any, + meta: Any, + ) -> Any: + # Choose a floating dtype for 0.5 + if dtype in (torch.float16, torch.float32, torch.float64): + half_dtype = dtype + else: + half_dtype = torch.float32 + key = (half_dtype, device) + if key not in self._half_cache: + half = super().call_operator( + exir_ops.edge.aten.full.default, + ([], 0.5), + {"dtype": half_dtype, "device": device}, + meta, + ) + self._half_cache[key] = half -class DecomposeSqrtPass(ExportPass): + return self._half_cache[key] - def call_operator(self, op, args, kwargs, meta): - """ - Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support. - """ + def call_operator(self, op: Any, args: tuple, kwargs: dict, meta: Any) -> Any: - if op not in (edge_sqrt_ops + aten_sqrt_ops): + if op not in self.SQRT_TO_POW: return super().call_operator(op, args, kwargs, meta) - pow_op = get_sqrt_decomposition(op) + if len(args) != 1: + raise ValueError(f"Expected 1 arg to sqrt, got {len(args)}") + + x = args[0] + pow_op = self.SQRT_TO_POW[op] + + half = self._get_half_tensor(x.data.dtype, x.data.device, meta) - return super().call_operator(pow_op, (args[0], 0.5), {}, meta) + return super().call_operator(pow_op, (x, half), {}, meta) diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 402ed0253c0..428306d7c4b 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -57,6 +57,7 @@ class TableOps: # Targets that must be treated explicitly special_table_ops: Set[EdgeOpOverload] = { + exir_ops.edge.aten.pow.Tensor_Tensor, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.gelu.default, } @@ -75,6 +76,13 @@ def __getitem__(self, node: Node): return self.unary_table_ops[target] elif target in self.special_table_ops: match target: + case exir_ops.edge.aten.pow.Tensor_Tensor: + # Exponent is a constant. Embed it into a lambda. + exp_node = node.args[1] + exp = float( + self.exported_program.state_dict[exp_node.name].item() # type: ignore[union-attr] + ) + return lambda x: torch.pow(x, exp).flatten() case exir_ops.edge.aten.pow.Tensor_Scalar: # Exponent is a constant. Embed it into a lambda. exp = cast(int, node.args[1]) @@ -283,8 +291,16 @@ def call(self, graph_module: GraphModule) -> PassResult: modified = True if modified: + graph_module.graph.eliminate_dead_code() + + # Remove any placeholder with zero users + for ph in list(graph_module.graph.nodes): + if ph.op == "placeholder" and len(ph.users) == 0: + graph_module.graph.erase_node(ph) + self.exported_program.state_dict.pop(ph.name, None) + # retrace the graph to update the fake tensor types graph_module = super().call(graph_module).graph_module - graph_module.recompile() + return PassResult(graph_module, modified) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 77d2f1011fa..9cea1f65f2f 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -226,6 +226,7 @@ def is_node_supported( exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.pow.Tensor_Scalar, exir_ops.edge.aten.pow.Tensor_Tensor, + torch.ops.aten.pow.Tensor_Tensor, exir_ops.edge.aten.where.self, operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -306,8 +307,6 @@ class CheckProperQuantization(OperatorSupportBase): exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.convolution.default, - exir_ops.edge.aten.full.default, - exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.max_pool2d_with_indices.default, @@ -410,6 +409,7 @@ def is_node_supported( input_quantized = input_quantized or all( (input_node.target in dq_ops) + or (node.name == "aten_pow_tensor_tensor") or (not get_first_fake_tensor(input_node).dtype.is_floating_point) for input_node in node.all_input_nodes ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5c2f7822097..7b6d25b66ba 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -212,6 +212,7 @@ def _match_pattern( torch.ops.aten.hardswish_.default, torch.ops.aten.full_like.default, torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.pow.Tensor_Tensor, torch.ops.aten.gelu.default, ] diff --git a/backends/arm/test/ops/test_sqrt.py b/backends/arm/test/ops/test_sqrt.py index 0c79f534656..617d5ecbf25 100644 --- a/backends/arm/test/ops/test_sqrt.py +++ b/backends/arm/test/ops/test_sqrt.py @@ -21,8 +21,8 @@ class Sqrt(torch.nn.Module): aten_op_MI = "torch.ops.aten.sqrt.default" exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor" - aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar" - exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar" + aten_op_BI = "torch.ops.aten.pow.Tensor_Tensor" + exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor" def __init__(self): super().__init__() diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index ee7d1733f37..369ebadc1ff 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -118,6 +118,13 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: if is_partitioned(node): for input in node.all_input_nodes: + if input.target in ( + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.full_like.default, + ): + continue + if is_dequant_node(input): + continue if is_partitioned(input): continue if get_first_fake_tensor(input).dtype.is_floating_point: