|
338 | 338 | # `AOTI documentation <https://pytorch.org/docs/main/torch.compiler_aot_inductor.html>`_:
|
339 | 339 | #
|
340 | 340 |
|
341 |
| -from tempfile import TemporaryDirectory |
342 |
| - |
343 |
| -from torch._inductor import aoti_compile_and_package, aoti_load_package |
344 |
| - |
345 |
| -with TemporaryDirectory() as tmpdir: |
346 |
| - path = str(Path(tmpdir) / "model.pt2") |
347 |
| - with torch.no_grad(): |
348 |
| - pkg_path = aoti_compile_and_package( |
349 |
| - exported_policy, |
350 |
| - args=(), |
351 |
| - kwargs={"pixels": pixels}, |
352 |
| - # Specify the generated shared library path |
353 |
| - package_path=path, |
354 |
| - ) |
355 |
| - print("pkg_path", pkg_path) |
356 |
| - |
357 |
| - compiled_module = aoti_load_package(pkg_path) |
358 |
| - |
359 |
| -print(compiled_module(pixels=pixels)) |
360 |
| - |
361 |
| -##################################### |
362 |
| -# An extra feature of AOTInductor is its capacity of dealing with dynamic shapes. This can be useful if you don't know |
363 |
| -# the shape of your input data ahead of time. For instance, we may want to run our policy for one, two or more |
364 |
| -# observations at a time. For this, let us re-export our policy, marking a new unsqueezed batch dimension as dynamic: |
365 |
| - |
366 |
| -batch_dim = torch.export.Dim("batch", min=1, max=32) |
367 |
| -pixels_unsqueeze = pixels.unsqueeze(0) |
368 |
| -exported_dynamic_policy = torch.export.export( |
369 |
| - policy_transform, |
370 |
| - args=(), |
371 |
| - kwargs={"pixels": pixels_unsqueeze}, |
372 |
| - strict=False, |
373 |
| - dynamic_shapes={"pixels": {0: batch_dim}}, |
374 |
| -) |
375 |
| -# Then recompile and export |
376 |
| -pkg_path = aoti_compile_and_package( |
377 |
| - exported_dynamic_policy, |
378 |
| - args=(), |
379 |
| - kwargs={"pixels": pixels_unsqueeze}, |
380 |
| - package_path=path, |
381 |
| -) |
| 341 | +# from tempfile import TemporaryDirectory |
| 342 | +# |
| 343 | +# from torch._inductor import aoti_compile_and_package, aoti_load_package |
| 344 | +# |
| 345 | +# with TemporaryDirectory() as tmpdir: |
| 346 | +# path = str(Path(tmpdir) / "model.pt2") |
| 347 | +# with torch.no_grad(): |
| 348 | +# pkg_path = aoti_compile_and_package( |
| 349 | +# exported_policy, |
| 350 | +# args=(), |
| 351 | +# kwargs={"pixels": pixels}, |
| 352 | +# # Specify the generated shared library path |
| 353 | +# package_path=path, |
| 354 | +# ) |
| 355 | +# print("pkg_path", pkg_path) |
| 356 | +# |
| 357 | +# compiled_module = aoti_load_package(pkg_path) |
| 358 | +# |
| 359 | +# print(compiled_module(pixels=pixels)) |
382 | 360 |
|
383 | 361 | #####################################
|
384 |
| -# More information about this can be found in the |
385 |
| -# `AOTInductor tutorial <https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`_. |
386 | 362 | #
|
387 | 363 | # Exporting TorchRL models with ONNX
|
388 | 364 | # ----------------------------------
|
|
0 commit comments