import logging
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -96,7 +95,6 @@ def __init__(
        verbose: bool = False,
        metadata: Optional[dict] = None,
        dynamic_shapes: Optional[Any] = None,
-        use_legacy_export: bool = False,
        save_exported_program: bool = False,
    ):
        # Store necessary constructor arguments.
@@ -117,7 +115,6 @@ def __init__(
        self.verbose = verbose
        self.metadata = metadata
        self.dynamic_shapes = dynamic_shapes
-        self.use_legacy_export = use_legacy_export
        self.save_exported_program = save_exported_program

        # Note: treat this as the source of truth for the result of
@@ -228,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.use_legacy_export:
-                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
-                # See issue: https://github.com/pytorch/executorch/issues/7373
-
-                with patch.object(
-                    torch._utils_internal,
-                    "export_training_ir_rollout_check",
-                    return_value=False,
-                ):
-                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
-                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
-                    exported_module = torch.export.export(
-                        self.model if not module else module,
-                        self.example_inputs,
-                        self.example_kwarg_inputs,
-                        dynamic_shapes=dynamic_shape,
-                        strict=True,
-                    )
+            if module:
+                logging.info("Re-exporting with:")
            else:
-                if module:
-                    logging.info("Re-exporting with:")
-                else:
-                    logging.info("Exporting with:")
-                logging.info(f"inputs: {self.example_inputs}")
-                logging.info(f"kwargs: {self.example_kwarg_inputs}")
-                logging.info(f"dynamic shapes: {dynamic_shape}")
-                exported_module = export_for_training(
-                    self.model if not module else module,
-                    self.example_inputs,
-                    kwargs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    strict=True,
-                )
+                logging.info("Exporting with:")
+            logging.info(f"inputs: {self.example_inputs}")
+            logging.info(f"kwargs: {self.example_kwarg_inputs}")
+            logging.info(f"dynamic shapes: {dynamic_shape}")
+            exported_module = export_for_training(
+                self.model if not module else module,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                strict=True,
+            )
        return exported_module

    def export(self) -> "LLMEdgeManager":
@@ -446,13 +424,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
        self.export()

        override_export_behaviour = contextlib.nullcontext()
-        if self.use_legacy_export:
-            override_export_behaviour = patch.object(
-                torch._utils_internal,
-                "export_training_ir_rollout_check",
-                return_value=False,
-            )
-
        with override_export_behaviour:
            self.edge_manager = export_to_edge(
                self.pre_autograd_graph_module,  # pyre-fixme[6]
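For context on the last hunk: with the legacy branch gone, override_export_behaviour stays a plain contextlib.nullcontext(), so the with block is always a no-op wrapper around export_to_edge. A minimal sketch of that pattern, with illustrative names not taken from the repo:

import contextlib

# Default to a do-nothing context; a real context manager could be swapped in before entering the block.
override_export_behaviour = contextlib.nullcontext()
with override_export_behaviour:
    print("the export-to-edge call would run here unchanged")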