Commit 45d7520

Remove the legacy export
Differential Revision: D77761473
Pull Request resolved: #12218
1 parent d952326 commit 45d7520
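This removes the QNN-only legacy export path: the use_legacy_export flag, the unittest.mock.patch of torch._utils_internal.export_training_ir_rollout_check, and the strict torch.export.export() call it guarded. After this change every model goes through export_for_training. For context, a minimal sketch of what the deleted path effectively did, reconstructed from the removed lines below; the toy module is illustrative only, and torch._utils_internal is a private PyTorch module whose hook may not exist in newer releases:

# Sketch of the legacy behavior this commit deletes (mirrors the removed code,
# not a supported API). May fail on PyTorch versions that no longer expose
# torch._utils_internal.export_training_ir_rollout_check.
from unittest.mock import patch

import torch


class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


# The legacy path disabled the training-IR rollout check so that a strict
# torch.export.export() produced the functional graph QNN required.
with patch.object(
    torch._utils_internal,
    "export_training_ir_rollout_check",
    return_value=False,
):
    exported = torch.export.export(ToyModule(), (torch.randn(2, 4),), strict=True)
print(type(exported))  # an ExportedProgram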

2 files changed: 13 additions, 43 deletions

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 1 deletion
@@ -1216,7 +1216,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         calibration_seq_length=llm_config.quantization.calibration_seq_length,
         calibration_data=llm_config.quantization.calibration_data,
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(

extension/llm/export/builder.py

Lines changed: 13 additions & 42 deletions
@@ -14,7 +14,6 @@
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch

 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -96,7 +95,6 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
-        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.
@@ -117,7 +115,6 @@ def __init__(
         self.verbose = verbose
         self.metadata = metadata
         self.dynamic_shapes = dynamic_shapes
-        self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program

         # Note: treat this as the source of truth for the result of
@@ -228,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.use_legacy_export:
-                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
-                # See issue: https://github.com/pytorch/executorch/issues/7373
-
-                with patch.object(
-                    torch._utils_internal,
-                    "export_training_ir_rollout_check",
-                    return_value=False,
-                ):
-                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
-                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
-                    exported_module = torch.export.export(
-                        self.model if not module else module,
-                        self.example_inputs,
-                        self.example_kwarg_inputs,
-                        dynamic_shapes=dynamic_shape,
-                        strict=True,
-                    )
+            if module:
+                logging.info("Re-exporting with:")
             else:
-                if module:
-                    logging.info("Re-exporting with:")
-                else:
-                    logging.info("Exporting with:")
-                logging.info(f"inputs: {self.example_inputs}")
-                logging.info(f"kwargs: {self.example_kwarg_inputs}")
-                logging.info(f"dynamic shapes: {dynamic_shape}")
-                exported_module = export_for_training(
-                    self.model if not module else module,
-                    self.example_inputs,
-                    kwargs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    strict=True,
-                )
+                logging.info("Exporting with:")
+            logging.info(f"inputs: {self.example_inputs}")
+            logging.info(f"kwargs: {self.example_kwarg_inputs}")
+            logging.info(f"dynamic shapes: {dynamic_shape}")
+            exported_module = export_for_training(
+                self.model if not module else module,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                strict=True,
+            )
         return exported_module

     def export(self) -> "LLMEdgeManager":
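With the branch gone, _export always takes the export_for_training route. A self-contained sketch of that single path, assuming only public torch.export APIs; the toy model, shapes, and Dim bounds are illustrative, and the ExecuTorch-specific wiring around the call is omitted:

import torch
from torch.export import Dim, export_for_training


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) * 2.0


example_inputs = (torch.randn(2, 8),)
# Let the batch dimension vary, mirroring how dynamic_shapes is threaded through.
dynamic_shapes = {"x": {0: Dim("batch", min=1, max=32)}}

exported_module = export_for_training(
    TinyModel(),
    example_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=True,
)
# The resulting ExportedProgram can be run directly or handed on to edge lowering.
print(exported_module.module()(torch.randn(4, 8)).shape)  # torch.Size([4, 8])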
@@ -446,13 +424,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
             self.export()

         override_export_behaviour = contextlib.nullcontext()
-        if self.use_legacy_export:
-            override_export_behaviour = patch.object(
-                torch._utils_internal,
-                "export_training_ir_rollout_check",
-                return_value=False,
-            )
-
         with override_export_behaviour:
             self.edge_manager = export_to_edge(
                 self.pre_autograd_graph_module,  # pyre-fixme[6]
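A note on the remaining with-block: override_export_behaviour is now always contextlib.nullcontext(), so the with statement is a pure pass-through. A minimal illustration of that stdlib behavior:

import contextlib

override_export_behaviour = contextlib.nullcontext()
with override_export_behaviour:
    # Runs exactly as if there were no context manager at all.
    print("export_to_edge body executes unchanged")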
