Skip to content

Commit 9d292a0

Browse files
author
Vincent Moens
committed
[BugFix,Doc] Revert dynamic shape in export tutorial
ghstack-source-id: fc85621 Pull Request resolved: #2563
1 parent 304e707 commit 9d292a0

File tree

1 file changed

+19
-43
lines changed

1 file changed

+19
-43
lines changed

tutorials/sphinx-tutorials/export.py

Lines changed: 19 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -338,51 +338,27 @@
338338
# `AOTI documentation <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_:
339339
#
340340

341-
from tempfile import TemporaryDirectory
342-
343-
from torch._inductor import aoti_compile_and_package, aoti_load_package
344-
345-
with TemporaryDirectory() as tmpdir:
346-
path = str(Path(tmpdir) / "model.pt2")
347-
with torch.no_grad():
348-
pkg_path = aoti_compile_and_package(
349-
exported_policy,
350-
args=(),
351-
kwargs={"pixels": pixels},
352-
# Specify the generated shared library path
353-
package_path=path,
354-
)
355-
print("pkg_path", pkg_path)
356-
357-
compiled_module = aoti_load_package(pkg_path)
358-
359-
print(compiled_module(pixels=pixels))
360-
361-
#####################################
362-
# An extra feature of AOTInductor is its capacity of dealing with dynamic shapes. This can be useful if you don't know
363-
# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more
364-
# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic:
365-
366-
batch_dim = torch.export.Dim("batch", min=1, max=32)
367-
pixels_unsqueeze = pixels.unsqueeze(0)
368-
exported_dynamic_policy = torch.export.export(
369-
policy_transform,
370-
args=(),
371-
kwargs={"pixels": pixels_unsqueeze},
372-
strict=False,
373-
dynamic_shapes={"pixels": {0: batch_dim}},
374-
)
375-
# Then recompile and export
376-
pkg_path = aoti_compile_and_package(
377-
exported_dynamic_policy,
378-
args=(),
379-
kwargs={"pixels": pixels_unsqueeze},
380-
package_path=path,
381-
)
341+
# from tempfile import TemporaryDirectory
342+
#
343+
# from torch._inductor import aoti_compile_and_package, aoti_load_package
344+
#
345+
# with TemporaryDirectory() as tmpdir:
346+
# path = str(Path(tmpdir) / "model.pt2")
347+
# with torch.no_grad():
348+
# pkg_path = aoti_compile_and_package(
349+
# exported_policy,
350+
# args=(),
351+
# kwargs={"pixels": pixels},
352+
# # Specify the generated shared library path
353+
# package_path=path,
354+
# )
355+
# print("pkg_path", pkg_path)
356+
#
357+
# compiled_module = aoti_load_package(pkg_path)
358+
#
359+
# print(compiled_module(pixels=pixels))
382360

383361
#####################################
384-
# More information about this can be found in the
385-
# `AOTInductor tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`_.
386362
#
387363
# Exporting TorchRL models with ONNX
388364
# ----------------------------------

0 commit comments

Comments
 (0)