Skip to content

Commit 4156249

Browse files
authored
Fixes device parsing in policy inference tutorial (#2250)
# Description The model in the policy inference tutorial needs to be moved to the correct device based on the command line argument input. Otherwise, it will cause device mismatch errors as policy defaults to the CPU while environment defaults to GPU. ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - Bug fix (non-breaking change which fixes an issue) ## Screenshots Please attach before and after screenshots of the change if applicable. <!-- Example: | Before | After | | ------ | ----- | | _gif/png before_ | _gif/png after_ | To upload images to a PR -- simply drag and drop an image while in edit mode and it should upload the image directly. You can then paste that source into the above before/after sections. --> ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [ ] I have added my name to the `CONTRIBUTORS.md` or my name already exists there <!-- As you go through the checklist above, you can mark something as done by putting an x character in it For example, - [x] I have done this task - [ ] I have not done this task -->
1 parent 01d6d5c commit 4156249

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

scripts/tutorials/03_envs/policy_inference_in_usd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656
policy_path = os.path.abspath(args_cli.checkpoint)
5757
file_content = omni.client.read_file(policy_path)[2]
5858
file = io.BytesIO(memoryview(file_content).tobytes())
59-
policy = torch.jit.load(file)
59+
policy = torch.jit.load(file, map_location=args_cli.device)
6060

6161
# setup environment
6262
env_cfg = H1RoughEnvCfg_PLAY()

0 commit comments

Comments
 (0)