Skip to content

Commit 1fdebef

Browse files
Support for unified recipe registry and user interfaces (#12248) (#12319)
Summary: Implements the RFC: #12248 1. `BackendRecipeProvider` -> Abstract interface that all backends must implement while providing recipes. 1. `recipe_registry` -> Singleton registry that maintains `BackendRecipeProviders` 1. `RecipeType` -> Abstract enum, backends extend this to provide support for specific recipes. 1. `ExportRecipe` will have two class methods a. `get_recipe` -> Queries registry to get recipe w/ specific backend One can simply get and use recipes with: ## Using recipe as-is ``` recipe = ExportRecipe.get_recipe( XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL, group_size=32) export(eager_model, example_inputs, dynamic_shapes, recipe) ``` Differential Revision: D78034047
1 parent 4e29bc9 commit 1fdebef

File tree

10 files changed

+646
-33
lines changed

10 files changed

+646
-33
lines changed

export/TARGETS

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ python_library(
2424
deps = [
2525
":recipe",
2626
"//executorch/runtime:runtime",
27+
":recipe_registry"
2728
]
2829
)
2930

@@ -35,5 +36,30 @@ python_library(
3536
deps = [
3637
":export",
3738
":recipe",
39+
":recipe_registry",
40+
":recipe_provider"
3841
],
3942
)
43+
44+
45+
python_library(
46+
name = "recipe_registry",
47+
srcs = [
48+
"recipe_registry.py",
49+
],
50+
deps = [
51+
":recipe",
52+
":recipe_provider"
53+
],
54+
)
55+
56+
57+
python_library(
58+
name = "recipe_provider",
59+
srcs = [
60+
"recipe_provider.py",
61+
],
62+
deps = [
63+
":recipe",
64+
]
65+
)

export/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-strict
8+
79
"""
810
ExecuTorch export module.
911
@@ -12,13 +14,18 @@
1214
export management.
1315
"""
1416

15-
# pyre-strict
16-
1717
from .export import export, ExportSession
18-
from .recipe import ExportRecipe
18+
from .recipe import ExportRecipe, QuantizationRecipe, RecipeType
19+
from .recipe_provider import BackendRecipeProvider
20+
from .recipe_registry import recipe_registry
21+
1922

2023
__all__ = [
2124
"ExportRecipe",
25+
"QuantizationRecipe",
2226
"ExportSession",
2327
"export",
28+
"BackendRecipeProvider",
29+
"recipe_registry",
30+
"RecipeType",
2431
]

export/export.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from abc import ABC, abstractmethod
28
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
39

@@ -10,11 +16,14 @@
1016
ExecutorchProgramManager,
1117
to_edge_transform_and_lower,
1218
)
19+
from executorch.exir.program._program import _transform
1320
from executorch.exir.schema import Program
1421
from executorch.extension.export_util.utils import save_pte_program
1522
from executorch.runtime import Runtime, Verification
1623
from tabulate import tabulate
1724
from torch import nn
25+
26+
from torch._export.pass_base import PassType
1827
from torch.export import ExportedProgram
1928
from torchao.quantization import quantize_
2029
from torchao.quantization.pt2e import allow_exported_model_train_eval
@@ -95,9 +104,7 @@ class ExportStage(Stage):
95104

96105
def __init__(
97106
self,
98-
pre_edge_transform_passes: Optional[
99-
Callable[[ExportedProgram], ExportedProgram]
100-
] = None,
107+
pre_edge_transform_passes: Optional[List[PassType]] = None,
101108
) -> None:
102109
self._exported_program: Dict[str, ExportedProgram] = {}
103110
self._pre_edge_transform_passes = pre_edge_transform_passes
@@ -153,10 +160,10 @@ def run(
153160
)
154161

155162
# Apply pre-edge transform passes if available
156-
if self._pre_edge_transform_passes is not None:
157-
for pre_edge_transform_pass in self._pre_edge_transform_passes:
158-
self._exported_program[method_name] = pre_edge_transform_pass(
159-
self._exported_program[method_name]
163+
if pre_edge_transform_passes := self._pre_edge_transform_passes or []:
164+
for pass_ in pre_edge_transform_passes:
165+
self._exported_program[method_name] = _transform(
166+
self._exported_program[method_name], pass_
160167
)
161168

162169
def get_artifacts(self) -> Dict[str, ExportedProgram]:

export/recipe.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
from abc import ABCMeta, abstractmethod
7+
from dataclasses import dataclass
8+
from enum import Enum, EnumMeta
9+
from typing import List, Optional, Sequence
10+
11+
from executorch.exir._warnings import experimental
12+
13+
from executorch.exir.backend.partitioner import Partitioner
14+
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
15+
from executorch.exir.pass_manager import PassType
16+
from torchao.core.config import AOBaseConfig
17+
from torchao.quantization.pt2e.quantizer import Quantizer
18+
619

720
"""
821
Export recipe definitions for ExecuTorch.
@@ -11,18 +24,29 @@
1124
for ExecuTorch models, including export configurations and quantization recipes.
1225
"""
1326

14-
from dataclasses import dataclass
15-
from enum import Enum
16-
from typing import Callable, List, Optional, Sequence
1727

18-
from executorch.exir._warnings import experimental
28+
class RecipeTypeMeta(EnumMeta, ABCMeta):
29+
"""Metaclass that combines EnumMeta and ABCMeta"""
1930

20-
from executorch.exir.backend.partitioner import Partitioner
21-
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
22-
from executorch.exir.pass_manager import PassType
23-
from torch.export import ExportedProgram
24-
from torchao.core.config import AOBaseConfig
25-
from torchao.quantization.pt2e.quantizer import Quantizer
31+
pass
32+
33+
34+
class RecipeType(Enum, metaclass=RecipeTypeMeta):
35+
"""
36+
Base recipe type class that backends can extend to define their own recipe types.
37+
Backends should create their own enum classes that inherit from RecipeType:
38+
"""
39+
40+
@classmethod
41+
@abstractmethod
42+
def get_backend_name(cls) -> str:
43+
"""
44+
Return the backend name for this recipe type.
45+
46+
Returns:
47+
str: The backend name (e.g., "xnnpack", "qnn", etc.)
48+
"""
49+
pass
2650

2751

2852
class Mode(str, Enum):
@@ -52,7 +76,7 @@ class QuantizationRecipe:
5276
quantizers: Optional[List[Quantizer]] = None
5377
ao_base_config: Optional[List[AOBaseConfig]] = None
5478

55-
def get_quantizers(self) -> Optional[Quantizer]:
79+
def get_quantizers(self) -> Optional[List[Quantizer]]:
5680
"""
5781
Get the quantizer associated with this recipe.
5882
@@ -89,17 +113,40 @@ class ExportRecipe:
89113

90114
name: Optional[str] = None
91115
quantization_recipe: Optional[QuantizationRecipe] = None
92-
edge_compile_config: Optional[EdgeCompileConfig] = (
93-
None # pyre-ignore[11]: Type not defined
94-
)
95-
pre_edge_transform_passes: Optional[
96-
Callable[[ExportedProgram], ExportedProgram]
97-
| List[Callable[[ExportedProgram], ExportedProgram]]
98-
] = None
116+
# pyre-ignore[11]: Type not defined
117+
edge_compile_config: Optional[EdgeCompileConfig] = None
118+
pre_edge_transform_passes: Optional[Sequence[PassType]] = None
99119
edge_transform_passes: Optional[Sequence[PassType]] = None
100120
transform_check_ir_validity: bool = True
101121
partitioners: Optional[List[Partitioner]] = None
102-
executorch_backend_config: Optional[ExecutorchBackendConfig] = (
103-
None # pyre-ignore[11]: Type not defined
104-
)
122+
# pyre-ignore[11]: Type not defined
123+
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
105124
mode: Mode = Mode.RELEASE
125+
126+
@classmethod
127+
def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe":
128+
"""
129+
Get an export recipe from backend. Backend is automatically determined based on the
130+
passed recipe type.
131+
132+
Args:
133+
recipe: The type of recipe to create
134+
**kwargs: Recipe-specific parameters
135+
136+
Returns:
137+
ExportRecipe configured for the specified recipe type
138+
"""
139+
from .recipe_registry import recipe_registry
140+
141+
if not isinstance(recipe, RecipeType):
142+
raise ValueError(f"Invalid recipe type: {recipe}")
143+
144+
backend = recipe.get_backend_name()
145+
export_recipe = recipe_registry.create_recipe(recipe, backend, **kwargs)
146+
if export_recipe is None:
147+
supported = recipe_registry.get_supported_recipes(backend)
148+
raise ValueError(
149+
f"Recipe '{recipe.value}' not supported by '{backend}'. "
150+
f"Supported: {[r.value for r in supported]}"
151+
)
152+
return export_recipe

export/recipe_provider.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
Recipe registry for managing backend recipe providers.
11+
12+
This module provides the registry system for backend recipe providers and
13+
the abstract interface that all backends must implement.
14+
"""
15+
16+
from abc import ABC, abstractmethod
17+
from typing import Any, Optional, Sequence
18+
19+
from .recipe import ExportRecipe, RecipeType
20+
21+
22+
class BackendRecipeProvider(ABC):
23+
"""
24+
Abstract recipe provider that all backends must implement
25+
"""
26+
27+
@property
28+
@abstractmethod
29+
def backend_name(self) -> str:
30+
"""
31+
Name of the backend (ex: 'xnnpack', 'qnn' etc)
32+
"""
33+
pass
34+
35+
@abstractmethod
36+
def get_supported_recipes(self) -> Sequence[RecipeType]:
37+
"""
38+
Get list of supported recipes.
39+
"""
40+
pass
41+
42+
@abstractmethod
43+
def create_recipe(
44+
self, recipe_type: RecipeType, **kwargs: Any
45+
) -> Optional[ExportRecipe]:
46+
"""
47+
Create a recipe for the given type.
48+
Returns None if the recipe is not supported by this backend.
49+
50+
Args:
51+
recipe_type: The type of recipe to create
52+
**kwargs: Recipe-specific parameters (ex: group_size)
53+
54+
Returns:
55+
ExportRecipe if supported, None otherwise
56+
"""
57+
pass
58+
59+
def supports_recipe(self, recipe_type: RecipeType) -> bool:
60+
return recipe_type in self.get_supported_recipes()

export/recipe_registry.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Recipe registry for managing backend recipe providers.
9+
10+
This module provides the registry system for backend recipe providers and
11+
the abstract interface that all backends must implement.
12+
"""
13+
14+
from typing import Any, Dict, Optional, Sequence
15+
16+
from .recipe import ExportRecipe, RecipeType
17+
from .recipe_provider import BackendRecipeProvider
18+
19+
20+
class RecipeRegistry:
21+
"""Global registry for all backend recipe providers"""
22+
23+
_instance = None
24+
_initialized = False
25+
26+
def __new__(cls):
27+
if cls._instance is None:
28+
cls._instance = super().__new__(cls)
29+
return cls._instance
30+
31+
def __init__(self) -> None:
32+
# Only initialize once to avoid resetting state on subsequent calls
33+
if not RecipeRegistry._initialized:
34+
self._providers: Dict[str, BackendRecipeProvider] = {}
35+
RecipeRegistry._initialized = True
36+
37+
def register_backend_recipe_provider(self, provider: BackendRecipeProvider) -> None:
38+
"""
39+
Register a backend recipe provider
40+
"""
41+
self._providers[provider.backend_name] = provider
42+
43+
def create_recipe(
44+
self, recipe_type: RecipeType, backend: str, **kwargs: Any
45+
) -> Optional[ExportRecipe]:
46+
"""
47+
Create a recipe for a specific backend.
48+
49+
Args:
50+
recipe_type: The type of recipe to create
51+
backend: Backend name
52+
**kwargs: Recipe-specific parameters
53+
54+
Returns:
55+
ExportRecipe if supported, None if not supported
56+
"""
57+
if backend not in self._providers:
58+
raise ValueError(
59+
f"Backend '{backend}' not available. Available: {list(self._providers.keys())}"
60+
)
61+
62+
return self._providers[backend].create_recipe(recipe_type, **kwargs)
63+
64+
def get_supported_recipes(self, backend: str) -> Sequence[RecipeType]:
65+
"""
66+
Get list of recipes supported by a backend.
67+
68+
Args:
69+
backend: Backend name
70+
71+
Returns:
72+
List of supported recipe types
73+
"""
74+
if backend not in self._providers:
75+
raise ValueError(f"Backend '{backend}' not available")
76+
return self._providers[backend].get_supported_recipes()
77+
78+
def list_backends(self) -> Sequence[str]:
79+
"""
80+
Get list of all registered backends
81+
"""
82+
return list(self._providers.keys())
83+
84+
85+
# initialize recipe registry
86+
recipe_registry = RecipeRegistry()

0 commit comments

Comments
 (0)