
Arm backend: Replace Tensor_Scalar with Tensor_Tensor in sqrt op #11783

Open · wants to merge 3 commits into main
```diff
@@ -61,6 +61,8 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         if node.op == "placeholder":
             # node is an input, weight or bias node
+            if not node.users:
+                return False
             consumer_node = list(node.users)[0]
             if self.is_weight_node_for_depthwise_conv2d(consumer_node):
                 return True
```
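For readers skimming the diff: the new guard matters because a `torch.fx` placeholder can be left with no consumers once later passes erase its users, and `list(node.users)[0]` then raises `IndexError`. A minimal sketch, using a hypothetical stand-in class rather than the real `torch.fx.Node`:

```python
# Hypothetical stand-in for torch.fx.Node; fx stores `users` as a dict.
class FakeNode:
    def __init__(self, op, users):
        self.op = op
        self.users = users

unused_placeholder = FakeNode("placeholder", {})

# What the old code effectively did on an unused placeholder:
try:
    consumer_node = list(unused_placeholder.users)[0]
except IndexError:
    print("old code: IndexError on a placeholder with no users")

# The added guard short-circuits instead:
if not unused_placeholder.users:
    print("new code: returns False")
```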
backends/arm/_passes/decompose_sqrt_pass.py: 46 additions, 20 deletions

```diff
@@ -4,36 +4,62 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

+from typing import Any, Dict, Tuple
+
 import torch

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

-edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
-aten_sqrt_ops = (
-    torch.ops.aten.sqrt.default,
-    torch.ops.aten.sqrt_.default,
-)

+class DecomposeSqrtPass(ExportPass):
+    def __init__(self) -> None:
+        super().__init__()
+
+        # We cache the constant tensor for the exponent
+        self._half_cache: Dict[Tuple[Any, Any], Any] = {}
+        self.SQRT_TO_POW = {
+            exir_ops.edge.aten.sqrt.default: exir_ops.edge.aten.pow.Tensor_Tensor,
+            torch.ops.aten.sqrt.default: torch.ops.aten.pow.Tensor_Tensor,
+            torch.ops.aten.sqrt_.default: torch.ops.aten.pow.Tensor_Tensor,
+        }

-def get_sqrt_decomposition(op) -> tuple:
-    # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
-    if op in edge_sqrt_ops:
-        return exir_ops.edge.aten.pow.Tensor_Scalar
-    if op in aten_sqrt_ops:
-        return torch.ops.aten.pow.Tensor_Scalar
-    raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
+    def _get_half_tensor(
+        self,
+        dtype: Any,
+        device: Any,
+        meta: Any,
+    ) -> Any:
+        # Choose a floating dtype for 0.5
+        if dtype in (torch.float16, torch.float32, torch.float64):
+            half_dtype = dtype
+        else:
+            half_dtype = torch.float32
+
+        key = (half_dtype, device)
+        if key not in self._half_cache:
+            half = super().call_operator(
+                exir_ops.edge.aten.full.default,
+                ([], 0.5),
+                {"dtype": half_dtype, "device": device},
+                meta,
+            )
+            self._half_cache[key] = half

-class DecomposeSqrtPass(ExportPass):
+        return self._half_cache[key]

-    def call_operator(self, op, args, kwargs, meta):
-        """
-        Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.
-        """
+    def call_operator(self, op: Any, args: tuple, kwargs: dict, meta: Any) -> Any:

-        if op not in (edge_sqrt_ops + aten_sqrt_ops):
+        if op not in self.SQRT_TO_POW:
             return super().call_operator(op, args, kwargs, meta)

-        pow_op = get_sqrt_decomposition(op)
+        if len(args) != 1:
+            raise ValueError(f"Expected 1 arg to sqrt, got {len(args)}")
+
+        x = args[0]
+        pow_op = self.SQRT_TO_POW[op]
+
+        half = self._get_half_tensor(x.data.dtype, x.data.device, meta)

-        return super().call_operator(pow_op, (args[0], 0.5), {}, meta)
+        return super().call_operator(pow_op, (x, half), {}, meta)
```
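As a quick eager-mode sanity check of the new decomposition, with plain `torch` calls standing in for the edge-dialect ops: `pow.Tensor_Tensor` with a rank-0 constant `0.5` reproduces `sqrt`, and the dtype fallback in `_get_half_tensor` can be exercised the same way.

```python
import torch

x = torch.rand(4, dtype=torch.float32) + 0.1

# The pass now emits aten.full.default([], 0.5) feeding pow.Tensor_Tensor,
# instead of the old pow.Tensor_Scalar(x, 0.5).
half = torch.full([], 0.5, dtype=torch.float32)
assert torch.allclose(torch.pow(x, half), torch.sqrt(x))

# Mirrors the dtype choice in _get_half_tensor: keep half/float/double,
# fall back to float32 for anything else.
for dtype in (torch.float16, torch.float64, torch.int32):
    half_dtype = dtype if dtype in (torch.float16, torch.float32, torch.float64) else torch.float32
    print(dtype, "->", half_dtype)
```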
backends/arm/_passes/insert_table_ops.py: 17 additions, 1 deletion

```diff
@@ -57,6 +57,7 @@ class TableOps:

     # Targets that must be treated explicitly
     special_table_ops: Set[EdgeOpOverload] = {
+        exir_ops.edge.aten.pow.Tensor_Tensor,
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.gelu.default,
     }
@@ -75,6 +76,13 @@ def __getitem__(self, node: Node):
             return self.unary_table_ops[target]
         elif target in self.special_table_ops:
             match target:
+                case exir_ops.edge.aten.pow.Tensor_Tensor:
+                    # Exponent is a constant. Embed it into a lambda.
+                    exp_node = node.args[1]
+                    exp = float(
+                        self.exported_program.state_dict[exp_node.name].item()  # type: ignore[union-attr]
+                    )
+                    return lambda x: torch.pow(x, exp).flatten()
                 case exir_ops.edge.aten.pow.Tensor_Scalar:
                     # Exponent is a constant. Embed it into a lambda.
                     exp = cast(int, node.args[1])
@@ -283,8 +291,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 modified = True

         if modified:
+            graph_module.graph.eliminate_dead_code()
+
+            # Remove any placeholder with zero users
+            for ph in list(graph_module.graph.nodes):
+                if ph.op == "placeholder" and len(ph.users) == 0:
+                    graph_module.graph.erase_node(ph)
+                    self.exported_program.state_dict.pop(ph.name, None)
+
             # retrace the graph to update the fake tensor types
             graph_module = super().call(graph_module).graph_module

             graph_module.recompile()

         return PassResult(graph_module, modified)
```
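A hedged sketch of the new `Tensor_Tensor` branch: because the exponent input is now a graph constant (the `0.5` materialized by the decomposition), its value can be read from the exported program's `state_dict` and baked into the table lambda. The `state_dict` dict and tensor name below are illustrative stand-ins, not the real exported-program contents.

```python
import torch

# Illustrative stand-ins for the exported program's state dict and the
# exponent placeholder's name.
state_dict = {"_tensor_constant_half": torch.full([], 0.5)}
exp_node_name = "_tensor_constant_half"

# Same extraction as the diff: read the constant, close over it.
exp = float(state_dict[exp_node_name].item())
table_fn = lambda x: torch.pow(x, exp).flatten()

x = torch.linspace(0.0, 4.0, steps=5)
print(table_fn(x))  # tensor([0.0000, 1.0000, 1.4142, 1.7321, 2.0000])
```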
backends/arm/operator_support/tosa_supported_operators.py: 2 additions, 2 deletions

```diff
@@ -227,6 +227,7 @@ def is_node_supported(
             exir_ops.edge.aten.squeeze_copy.dims,
             exir_ops.edge.aten.pow.Tensor_Scalar,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            torch.ops.aten.pow.Tensor_Tensor,
             exir_ops.edge.aten.where.self,
             operator.getitem,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -308,8 +309,6 @@ class CheckProperQuantization(OperatorSupportBase):
         exir_ops.edge.aten.avg_pool2d.default,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.convolution.default,
-        exir_ops.edge.aten.full.default,
-        exir_ops.edge.aten.full_like.default,
         exir_ops.edge.aten.hardtanh.default,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.max_pool2d_with_indices.default,
@@ -412,6 +411,7 @@ def is_node_supported(

         input_quantized = input_quantized or all(
             (input_node.target in dq_ops)
+            or (node.name == "aten_pow_tensor_tensor")
             or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
             for input_node in node.all_input_nodes
         )
```
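The new `node.name == "aten_pow_tensor_tensor"` clause is easiest to see with the check reduced to plain data: the decomposed sqrt's `0.5` exponent is a floating-point constant that no dequantize node feeds, so without the exemption the `all(...)` predicate would reject the node. A minimal sketch with tuples standing in for fx nodes:

```python
# Tuples (target, is_floating_point) stand in for fx input nodes.
dq_ops = {"dequantize_per_tensor"}

def input_quantized(node_name, inputs):
    return all(
        (target in dq_ops)
        or (node_name == "aten_pow_tensor_tensor")
        or (not is_fp)
        for target, is_fp in inputs
    )

# x arrives via dequant; the exponent is a floating-point constant.
inputs = [("dequantize_per_tensor", True), ("full", True)]
print(input_quantized("aten_pow_tensor_tensor", inputs))  # True: exempted
print(input_quantized("some_other_op", inputs))           # False: const input rejected
```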
backends/arm/quantizer/quantization_annotator.py: 1 addition, 0 deletions

```diff
@@ -212,6 +212,7 @@ def _match_pattern(
         torch.ops.aten.hardswish_.default,
         torch.ops.aten.full_like.default,
         torch.ops.aten.pow.Tensor_Scalar,
+        torch.ops.aten.pow.Tensor_Tensor,
         torch.ops.aten.gelu.default,
     ]
```
backends/arm/test/ops/test_sqrt.py: 2 additions, 2 deletions

```diff
@@ -21,8 +21,8 @@ class Sqrt(torch.nn.Module):
     aten_op_MI = "torch.ops.aten.sqrt.default"
     exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"

-    aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar"
-    exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"
+    aten_op_BI = "torch.ops.aten.pow.Tensor_Tensor"
+    exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"

     def __init__(self):
         super().__init__()
```
backends/arm/tosa_partitioner.py: 7 additions, 0 deletions

```diff
@@ -118,6 +118,13 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:

         if is_partitioned(node):
             for input in node.all_input_nodes:
+                if input.target in (
+                    exir_ops.edge.aten.full.default,
+                    exir_ops.edge.aten.full_like.default,
+                ):
+                    continue
+                if is_dequant_node(input):
+                    continue
                 if is_partitioned(input):
                     continue
                 if get_first_fake_tensor(input).dtype.is_floating_point:
```
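A hedged sketch of the three new skips, with plain tuples standing in for fx nodes; the diff truncates before showing what happens to a floating-point input that survives all three skips, so that reaction is left as a comment.

```python
FULL_OPS = {"aten.full.default", "aten.full_like.default"}
DQ_OPS = {"quantized_decomposed.dequantize_per_tensor.default"}

def unexpected_fp_inputs(inputs, partitioned):
    # inputs: (name, target, is_floating_point) tuples standing in for fx nodes.
    offending = []
    for name, target, is_fp in inputs:
        if target in FULL_OPS:      # constants such as the decomposed 0.5 exponent
            continue
        if target in DQ_OPS:        # dequantized inputs
            continue
        if name in partitioned:     # inputs already inside the partition
            continue
        if is_fp:
            offending.append(name)  # the pass's reaction here is elided in the diff
    return offending

inputs = [
    ("half", "aten.full.default", True),
    ("x_dq", "quantized_decomposed.dequantize_per_tensor.default", True),
]
print(unexpected_fp_inputs(inputs, partitioned=set()))  # []
```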