Skip to content

Commit f858e0d

Browse files
authored
Arm backend: Introduce TOSA backend dialect (#12195)
### Summary Adds a TOSA backend dialect so it's possible to convert Edge IR to specific TOSA operators in the Arm backends. This is done to enable control of types and additional arguments available for TOSA operators in comparison with the Edge IR. The operators are registered into the exir.backend.tosa namespace and is only traceable, not executable since it's only used in the lowering step to a TOSA serialization format. ### Test plan Existing CI tests Signed-off-by: Per Åstrand <per.astrand@arm.com>
1 parent fd677ac commit f858e0d

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-unsafe
9+
10+
import executorch.backends.arm.tosa.dialect # noqa: unused
911
from executorch.backends.arm._passes import (
1012
AddBiasPass,
1113
AnnotateChannelsLastDimOrder,

backends/arm/tosa/dialect/lib.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Callable
7+
8+
from executorch.exir.dialects._ops import _BACKEND_OP_LIB, ops as exir_ops
9+
from torch.library import Library, register_fake
10+
from torchgen.model import FunctionSchema
11+
12+
# create a torch library for the TOSA dialect
13+
# This defines a library to include Backend Dialect Operators in Executorch
14+
tosa_lib = Library("tosa", "DEF")
15+
16+
17+
def register_tosa_dialect_op(op_schema, func) -> Callable:
18+
if tosa_lib.ns not in _BACKEND_OP_LIB:
19+
_BACKEND_OP_LIB.append(tosa_lib.ns)
20+
21+
if "::" in op_schema:
22+
raise ValueError("The schema should not contain a namespace.")
23+
24+
# Parse the op_schema into a FunctionSchema
25+
func_schema = FunctionSchema.parse(op_schema)
26+
overload_name = func_schema.name.overload_name
27+
if overload_name:
28+
raise ValueError(
29+
"The TOSA dialect does not support overload names in the op schema."
30+
)
31+
32+
opname = func_schema.name.name.base
33+
tosa_lib.define(op_schema)
34+
35+
overload_name = "default"
36+
op_qualified_name = f"{tosa_lib.ns}::{opname}"
37+
38+
register_fake(op_qualified_name, func, lib=tosa_lib)
39+
40+
op = getattr(getattr(getattr(exir_ops.backend, tosa_lib.ns), opname), overload_name)
41+
42+
# For now, since the TOSA operators are only used for lowering and serialization in the backend
43+
# the op doesn't need to be callable. This can be changed in the future if needed to support
44+
# execution of TOSA ops directly.
45+
def not_callable():
46+
raise RuntimeError("TOSA dialect op is not callable")
47+
48+
op.__equvalent_callable__ = not_callable
49+
50+
return op
51+
52+
53+
class TosaValueError(ValueError):
54+
def __init__(self, message="A TOSA value error occurred", *args, **kwargs):
55+
super().__init__(message, *args, **kwargs)
56+
self.op = kwargs.get("op", None)
57+
58+
def __str__(self):
59+
base_message = super().__str__()
60+
if self.op is not None:
61+
return f"{base_message} (TOSA op: {self.op})"
62+
return base_message
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Callable, Iterable, List, ParamSpec, TypeVar
8+
9+
from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op
10+
11+
from executorch.backends.arm.tosa_specification import (
12+
get_context_spec,
13+
TosaSpecification,
14+
)
15+
16+
P = ParamSpec("P")
17+
R = TypeVar("R")
18+
19+
# The list of registered ops are not yet used, except for registration
20+
_tosa_registered_ops: dict[TosaSpecification, list[Callable]] = {
21+
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
22+
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
23+
}
24+
25+
# Mapping to ensure we only register a given function once.
26+
_registered_tosa_ops_by_func: dict[Callable, Callable] = {}
27+
28+
29+
def register_tosa_op(
30+
op_schema: str, tosa_specs: Iterable[TosaSpecification]
31+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
32+
"""
33+
Decorator for registering a TOSA operation.
34+
35+
Parameters:
36+
op_schema : A string that defines the operation schema.
37+
tosa_specs : Iterable of TOSA specification strings,
38+
e.g. ("TOSA-1.0+INT", "TOSA-1.0+FP").
39+
40+
The decorated function is registered with the given op_schema by calling
41+
register_tosa_dialect_op(op_schema, func) only once per function. The resulting
42+
callable is then inserted into _tosa_registered_ops for each spec.
43+
"""
44+
45+
def decorator(func: Callable[P, R]) -> Callable[P, R]:
46+
# Only call register_tosa_dialect_op if the function hasn't been registered yet.
47+
if func not in _registered_tosa_ops_by_func:
48+
op_callable = register_tosa_dialect_op(op_schema, func)
49+
_registered_tosa_ops_by_func[func] = op_callable
50+
else:
51+
op_callable = _registered_tosa_ops_by_func[func]
52+
53+
# For each TOSA spec, ensure the operation is added only once.
54+
for spec in tosa_specs:
55+
if spec not in _tosa_registered_ops:
56+
raise ValueError(f"TOSA spec {spec} not listed for registrations")
57+
if op_callable not in _tosa_registered_ops[spec]:
58+
_tosa_registered_ops[spec].append(op_callable)
59+
60+
# return the original function
61+
return func
62+
63+
return decorator
64+
65+
66+
def get_registered_tosa_ops() -> List[Callable]:
67+
tosa_spec = get_context_spec()
68+
return _tosa_registered_ops[tosa_spec]

0 commit comments

Comments
 (0)