Merged
docs/source/using-executorch-export.md (12 changes: 7 additions & 5 deletions)
@@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,

 ```python
 from executorch.exir.passes.external_constants_pass import (
-    delegate_external_constants_pass,
+    delegate_external_constants_pass_unlifted,
 )
-partial_function = partial(
-    delegate_external_constants_pass,
-    ep=exported_program,
+# Tag the unlifted ep.module().
+tagged_module = exported_program.module()
+delegate_external_constants_pass_unlifted(
+    module=tagged_module,
     gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
 )

+# Re-export to get the EP.
+exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
 executorch_program = to_edge_transform_and_lower(
     exported_program,
-    transform_passes = [partial_function],
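As a quick orientation for the new docs flow: the sketch below strings the pieces together end to end, using a toy `torch.nn.Module`, example inputs, and the XNNPACK partitioner as stand-ins (none of these names come from the PR). Writing the resulting `.pte`/`.ptd` files to disk is left to the surrounding export guide.

```python
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)


class ToyLinear(torch.nn.Module):
    """Stand-in model, for illustration only."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


inputs = (torch.randn(1, 8),)
exported_program = export(ToyLinear(), inputs)

# Tag constants on the unlifted ep.module(); every tagged weight is destined
# for "model.ptd".
tagged_module = exported_program.module()
delegate_external_constants_pass_unlifted(
    module=tagged_module,
    gen_tag_fn=lambda node: "model",
)

# Re-export the tagged module, then lower and serialize as usual.
exported_program = export(tagged_module, inputs)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
).to_executorch()
```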
examples/models/llama/export_llama_lib.py (7 changes: 5 additions & 2 deletions)
@@ -1079,7 +1079,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901

     if llm_config.backend.xnnpack.enabled:
         if llm_config.export.foundation_weights_file is not None:
-            gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
+            gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
                 else None
@@ -1089,8 +1089,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
                 delegate_external_constants_pass_unlifted,
             )

+            assert (
+                builder_exported.pre_autograd_graph_module is not None
+            ), "pre_autograd_graph_module shouldn't be None here"
             delegate_external_constants_pass_unlifted(
-                gm=builder_exported.pre_autograd_graph_module,
+                module=builder_exported.pre_autograd_graph_module,
                 gen_tag_fn=gen_tag_fn,
             )

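One detail worth calling out in the type change above: `gen_tag_fn` may now return `None`, which effectively opts a node out of external-weight tagging and is how the LoRA weights are kept out of the foundation weights file here. A standalone sketch of such a tag function (the helper name is illustrative, not from the PR):

```python
from typing import Callable, Optional

import torch


def make_foundation_tag_fn(
    foundation_file: str,
) -> Callable[[torch.fx.Node], Optional[str]]:
    """Build a tag function that routes non-LoRA weights to `foundation_file`."""

    def gen_tag_fn(node: torch.fx.Node) -> Optional[str]:
        # Returning None opts this constant out, so LoRA weights are not
        # routed into the shared foundation weights file.
        return foundation_file if "lora" not in node.name else None

    return gen_tag_fn
```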
exir/passes/external_constants_pass.py (43 changes: 6 additions & 37 deletions)
@@ -88,53 +88,22 @@ def external_mutable_weights_pass(
     return PassResult(gm, mutated)


-def delegate_external_constants_pass(
-    gm: GraphModule,
-    ep: ExportedProgram,
-    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
-) -> PassResult:
-    """
-    Tag external constants before to_backend.
-    Note: this pass must be run after run_decompositions(), as tags on
-    constants are removed then.
-    Args:
-        gm: GraphModule to tag.
-        ep: ExportedProgram, to distinguish if a node is a constant.
-        gen_tag_fn: node -> str callable indicating the tag for the node.
-    Returns:
-        PassResult: The resulting gm, and if it was mutated or not.
-    """
-    mutated = False
-    for module in gm.modules():
-        if not isinstance(module, torch.fx.GraphModule):
-            continue
-        for node in module.graph.nodes:
-            if node.op == "placeholder" and is_param_node(ep, node):
-                if gen_tag_fn is not None:
-                    node.meta.setdefault("custom", {})
-                    node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
-                    mutated = True
-    return PassResult(gm, mutated)
-
-
 # Note: this pass must be run on an unlifted graph, e.g. ep.module(),
 # and not on a lifted graph, e.g. ep.graph_module.
 # This is using 'get_attr' to tag constants, which only appears in
 # unlifted graphs.
 def delegate_external_constants_pass_unlifted(
-    gm: GraphModule,
-    gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
+    module: torch.nn.Module,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> PassResult:
     mutated = False
-    for module in gm.modules():
-        if not isinstance(module, torch.fx.GraphModule):
+    for m in module.modules():
+        if not isinstance(m, torch.fx.GraphModule):
             continue
-        for node in module.graph.nodes:
+        for node in m.graph.nodes:
             if node.op == "get_attr":
                 if gen_tag_fn is not None:
                     node.meta.setdefault("custom", {})
                     node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                     mutated = True
-    return PassResult(gm, mutated)
+    return PassResult(module, mutated)
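Since the rewritten pass walks `get_attr` nodes on an arbitrary `torch.nn.Module`, the note above about unlifted graphs is the key usage constraint: run it on `ep.module()`, not on the lifted `ep.graph_module`. A small self-contained sketch (toy module and tag name assumed, not from the PR) that tags a module and then inspects the resulting metadata:

```python
import torch
from torch.export import export

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_unlifted,
)


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


ep = export(Toy(), (torch.randn(1, 4),))
unlifted = ep.module()  # get_attr nodes only exist on the unlifted module

delegate_external_constants_pass_unlifted(
    module=unlifted,
    gen_tag_fn=lambda node: "toy_weights",
)

# Every get_attr node (i.e. every constant) should now carry the tag.
for m in unlifted.modules():
    if isinstance(m, torch.fx.GraphModule):
        for node in m.graph.nodes:
            if node.op == "get_attr":
                print(node.name, node.meta["custom"]["delegate_constant_tag"])
```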
test/models/export_delegated_program.py (13 changes: 5 additions & 8 deletions)
@@ -11,7 +11,6 @@
 import os
 import sys

-from functools import partial
 from typing import Dict, final, Optional, Sequence, Type

 import executorch.exir as exir
@@ -28,7 +27,7 @@
     ExecutorBackend,
 )
 from executorch.exir.passes.external_constants_pass import (
-    delegate_external_constants_pass,
+    delegate_external_constants_pass_unlifted,
 )
 from executorch.exir.program import ExecutorchProgramManager
 from torch import nn
@@ -173,17 +172,15 @@ def forward(self, *args, **kwargs):
         XnnpackPartitioner,
     )

-    transform_passes = []
     if external_constants:
-        partial_function = partial(
-            delegate_external_constants_pass,
-            ep=exported_program,
+        tagged_module = exported_program.module()
+        delegate_external_constants_pass_unlifted(
+            module=tagged_module,
             gen_tag_fn=lambda x: module_class.__name__,
         )
-        transform_passes.append(partial_function)
+        exported_program = export(tagged_module, args=inputs, strict=True)
     executorch_program = to_edge_transform_and_lower(
         exported_program,
-        transform_passes=transform_passes,
         compile_config=edge_config,
         partitioner=[XnnpackPartitioner()],
     ).to_executorch(config=et_config)