Skip to content

Commit cf82161

Browse files
authored
Remove torch.export.export_for_inference
Differential Revision: D71069057 Pull Request resolved: #1877
1 parent 64bcf4c commit cf82161

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

examples/sam2_amg_server/compile_export_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,10 @@ def aot_compile(
119119
"triton.cudagraphs": True,
120120
}
121121

122-
from torch.export import export_for_inference
122+
from torch.export import export_for_training
123123

124-
exported = export_for_inference(fn, sample_args, sample_kwargs)
124+
exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
125+
exported.run_decompositions()
125126
output_path = torch._inductor.aoti_compile_and_package(
126127
exported,
127128
package_path=str(path),

examples/sam2_vos_example/compile_export_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,10 @@ def aot_compile(
8282
"triton.cudagraphs": True,
8383
}
8484

85-
from torch.export import export_for_inference
85+
from torch.export import export_for_training
8686

87-
exported = export_for_inference(fn, sample_args, sample_kwargs)
87+
exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
88+
exported.run_decompositions()
8889
output_path = torch._inductor.aoti_compile_and_package(
8990
exported,
9091
package_path=str(path),

0 commit comments

Comments
 (0)