|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | + |
| 7 | +from typing import Callable, Iterable, List, ParamSpec, TypeVar |
| 8 | + |
| 9 | +from executorch.backends.arm.tosa.dialect.lib import register_tosa_dialect_op |
| 10 | + |
| 11 | +from executorch.backends.arm.tosa_specification import ( |
| 12 | + get_context_spec, |
| 13 | + TosaSpecification, |
| 14 | +) |
| 15 | + |
| 16 | +P = ParamSpec("P") |
| 17 | +R = TypeVar("R") |
| 18 | + |
| 19 | +# The list of registered ops are not yet used, except for registration |
| 20 | +_tosa_registered_ops: dict[TosaSpecification, list[Callable]] = { |
| 21 | + TosaSpecification.create_from_string("TOSA-1.0+FP"): [], |
| 22 | + TosaSpecification.create_from_string("TOSA-1.0+INT"): [], |
| 23 | +} |
| 24 | + |
| 25 | +# Mapping to ensure we only register a given function once. |
| 26 | +_registered_tosa_ops_by_func: dict[Callable, Callable] = {} |
| 27 | + |
| 28 | + |
| 29 | +def register_tosa_op( |
| 30 | + op_schema: str, tosa_specs: Iterable[TosaSpecification] |
| 31 | +) -> Callable[[Callable[P, R]], Callable[P, R]]: |
| 32 | + """ |
| 33 | + Decorator for registering a TOSA operation. |
| 34 | +
|
| 35 | + Parameters: |
| 36 | + op_schema : A string that defines the operation schema. |
| 37 | + tosa_specs : Iterable of TOSA specification strings, |
| 38 | + e.g. ("TOSA-1.0+INT", "TOSA-1.0+FP"). |
| 39 | +
|
| 40 | + The decorated function is registered with the given op_schema by calling |
| 41 | + register_tosa_dialect_op(op_schema, func) only once per function. The resulting |
| 42 | + callable is then inserted into _tosa_registered_ops for each spec. |
| 43 | + """ |
| 44 | + |
| 45 | + def decorator(func: Callable[P, R]) -> Callable[P, R]: |
| 46 | + # Only call register_tosa_dialect_op if the function hasn't been registered yet. |
| 47 | + if func not in _registered_tosa_ops_by_func: |
| 48 | + op_callable = register_tosa_dialect_op(op_schema, func) |
| 49 | + _registered_tosa_ops_by_func[func] = op_callable |
| 50 | + else: |
| 51 | + op_callable = _registered_tosa_ops_by_func[func] |
| 52 | + |
| 53 | + # For each TOSA spec, ensure the operation is added only once. |
| 54 | + for spec in tosa_specs: |
| 55 | + if spec not in _tosa_registered_ops: |
| 56 | + raise ValueError(f"TOSA spec {spec} not listed for registrations") |
| 57 | + if op_callable not in _tosa_registered_ops[spec]: |
| 58 | + _tosa_registered_ops[spec].append(op_callable) |
| 59 | + |
| 60 | + # return the original function |
| 61 | + return func |
| 62 | + |
| 63 | + return decorator |
| 64 | + |
| 65 | + |
| 66 | +def get_registered_tosa_ops() -> List[Callable]: |
| 67 | + tosa_spec = get_context_spec() |
| 68 | + return _tosa_registered_ops[tosa_spec] |
0 commit comments