Skip to content

Commit 6f8d08d

Browse files
authored
Update export function for torch 2.6 (#2184)
Signed-off-by: Yi Liu <yiliu4@habana.ai>
1 parent 35ae68c commit 6f8d08d

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

neural_compressor/torch/export/pt2e_export.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
from typing import Any, Dict, Optional, Tuple, Union
1717

1818
import torch
19-
from torch._export import capture_pre_autograd_graph
2019
from torch.fx.graph_module import GraphModule
2120

2221
from neural_compressor.common.utils import logger
23-
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, get_torch_version, is_ipex_imported
22+
from neural_compressor.torch.utils import TORCH_VERSION_2_2_2, TORCH_VERSION_2_7_0, get_torch_version, is_ipex_imported
2423

2524
__all__ = ["export", "export_model_for_pt2e_quant"]
2625

@@ -53,6 +52,10 @@ def export_model_for_pt2e_quant(
5352
# Note 1: `capture_pre_autograd_graph` is also a short-term API, it will be
5453
# updated to use the official `torch.export` API when that is ready.
5554
cur_version = get_torch_version()
55+
if cur_version >= TORCH_VERSION_2_7_0:
56+
export_func = torch.export.export_for_training
57+
else:
58+
export_func = torch._export.capture_pre_autograd_graph
5659
if cur_version <= TORCH_VERSION_2_2_2: # pragma: no cover
5760
logger.warning(
5861
(
@@ -62,11 +65,13 @@ def export_model_for_pt2e_quant(
6265
),
6366
cur_version,
6467
)
65-
exported_model = capture_pre_autograd_graph(model, args=example_inputs)
68+
exported_model = export_func(model, args=example_inputs)
6669
else:
67-
exported_model = capture_pre_autograd_graph( # pylint: disable=E1123
70+
exported_model = export_func( # pylint: disable=E1123
6871
model, args=example_inputs, dynamic_shapes=dynamic_shapes
6972
)
73+
if cur_version >= TORCH_VERSION_2_7_0:
74+
exported_model = exported_model.module()
7075
exported_model._exported = True
7176
logger.info("Exported the model to Aten IR successfully.")
7277
except Exception as e:

neural_compressor/torch/utils/environ.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def get_ipex_version():
129129

130130

131131
TORCH_VERSION_2_2_2 = Version("2.2.2")
132+
TORCH_VERSION_2_7_0 = Version("2.7.0")
132133

133134

134135
def get_torch_version():

0 commit comments

Comments
 (0)