@@ -16,11 +16,10 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
-from torch._export import capture_pre_autograd_graph
 from torch.fx.graph_module import GraphModule
 
 from neural_compressor.common.utils import logger
-from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
+from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, TORCH_VERSION_2_7_0, get_torch_version, is_ipex_imported
 
 __all__ = ["export", "export_model_for_pt2e_quant"]
 
@@ -53,6 +52,10 @@ def export_model_for_pt2e_quant(
             # Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
             # updated to use the official `torch.export` API when that is ready.
             cur_version = get_torch_version()
+            if cur_version >= TORCH_VERSION_2_7_0:
+                export_func = torch.export.export_for_training
+            else:
+                export_func = torch._export.capture_pre_autograd_graph
             if cur_version <= TORCH_VERSION_2_2_2:  # pragma: no cover
                 logger.warning(
                     (
@@ -62,11 +65,13 @@ def export_model_for_pt2e_quant(
                     ),
                     cur_version,
                 )
-                exported_model = capture_pre_autograd_graph(model, args=example_inputs)
+                exported_model = export_func(model, args=example_inputs)
             else:
-                exported_model = capture_pre_autograd_graph(  # pylint: disable=E1123
+                exported_model = export_func(  # pylint: disable=E1123
                     model, args=example_inputs, dynamic_shapes=dynamic_shapes
                 )
+            if cur_version >= TORCH_VERSION_2_7_0:
+                exported_model = exported_model.module()
             exported_model._exported = True
             logger.info("Exported the model to Aten IR successfully.")
     except Exception as e:
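
For context, here is a minimal, self-contained sketch of the version gate this patch introduces. The toy `Small` module, the example input, and the `packaging`-based version check stand in for `get_torch_version()` and are illustrative assumptions, not part of the patch. The patch treats the private `torch._export.capture_pre_autograd_graph` as unavailable from torch 2.7 on; `torch.export.export_for_training` returns an `ExportedProgram`, so `.module()` is needed to recover a graph module comparable to what the old API produced:

```python
import torch
from packaging.version import Version  # assumption: stand-in for get_torch_version()


class Small(torch.nn.Module):  # hypothetical toy model for illustration
    def forward(self, x):
        return torch.relu(x) * 2.0


model = Small().eval()
example_inputs = (torch.randn(2, 3),)

if Version(torch.__version__) >= Version("2.7.0"):
    # export_for_training returns an ExportedProgram; unwrap it to a graph module
    exported = torch.export.export_for_training(model, args=example_inputs)
    graph_module = exported.module()
else:
    # pre-2.7 path: the short-term private API this patch replaces
    graph_module = torch._export.capture_pre_autograd_graph(model, args=example_inputs)

print(type(graph_module))  # an FX-style graph module ready for PT2E quantization
```

Routing both torch versions through a single `export_func` indirection, as the patch does, keeps the `dynamic_shapes` branching in one place instead of duplicating it per version.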