Skip to content

Commit 304e707

Browse files
author
Vincent Moens
committed
[Doc] torchrl_demo.py revamp
ghstack-source-id: 2f00878 Pull Request resolved: #2561
1 parent 2f3b4cd commit 304e707

File tree

2 files changed

+207
-168
lines changed

2 files changed

+207
-168
lines changed

tutorials/sphinx-tutorials/export.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
from pathlib import Path
5252

5353
import numpy as np
54-
import tensordict.utils
5554

5655
import torch
5756

@@ -360,6 +359,31 @@
360359
print(compiled_module(pixels=pixels))
361360

362361
#####################################
362+
# An extra feature of AOTInductor is its capacity of dealing with dynamic shapes. This can be useful if you don't know
363+
# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more
364+
# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic:
365+
366+
batch_dim = torch.export.Dim("batch", min=1, max=32)
367+
pixels_unsqueeze = pixels.unsqueeze(0)
368+
exported_dynamic_policy = torch.export.export(
369+
policy_transform,
370+
args=(),
371+
kwargs={"pixels": pixels_unsqueeze},
372+
strict=False,
373+
dynamic_shapes={"pixels": {0: batch_dim}},
374+
)
375+
# Then recompile and export
376+
pkg_path = aoti_compile_and_package(
377+
exported_dynamic_policy,
378+
args=(),
379+
kwargs={"pixels": pixels_unsqueeze},
380+
package_path=path,
381+
)
382+
383+
#####################################
384+
# More information about this can be found in the
385+
# `AOTInductor tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`_.
386+
#
363387
# Exporting TorchRL models with ONNX
364388
# ----------------------------------
365389
#

0 commit comments

Comments
 (0)