Skip to content

Commit da5618b

Browse files
authored
Fixes device settings in env tutorials (#2151)
# Description The environment examples were not setting the device properly. ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there Signed-off-by: Mayank Mittal <12863862+Mayankm96@users.noreply.github.com>
1 parent ba31408 commit da5618b

File tree

5 files changed

+17
-5
lines changed

5 files changed

+17
-5
lines changed

scripts/tutorials/03_envs/create_cartpole_base_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def main():
140140
# parse the arguments
141141
env_cfg = CartpoleEnvCfg()
142142
env_cfg.scene.num_envs = args_cli.num_envs
143+
env_cfg.sim.device = args_cli.device
143144
# setup base environment
144145
env = ManagerBasedEnv(cfg=env_cfg)
145146

scripts/tutorials/03_envs/create_cube_base_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def __post_init__(self):
304304
self.sim.dt = 0.01
305305
self.sim.physics_material = self.scene.terrain.physics_material
306306
self.sim.render_interval = 2 # render interval should be a multiple of decimation
307+
self.sim.device = args_cli.device
307308
# viewer settings
308309
self.viewer.eye = (5.0, 5.0, 5.0)
309310
self.viewer.lookat = (0.0, 0.0, 2.0)

scripts/tutorials/03_envs/create_quadruped_base_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __post_init__(self):
194194
# simulation settings
195195
self.sim.dt = 0.005 # simulation timestep -> 200 Hz physics
196196
self.sim.physics_material = self.scene.terrain.physics_material
197+
self.sim.device = args_cli.device
197198
# update sensor update periods
198199
# we tick all the sensors based on the smallest update period (physics update period)
199200
if self.scene.height_scanner is not None:

scripts/tutorials/03_envs/policy_inference_in_usd.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def main():
5757
file_content = omni.client.read_file(policy_path)[2]
5858
file = io.BytesIO(memoryview(file_content).tobytes())
5959
policy = torch.jit.load(file)
60+
61+
# setup environment
6062
env_cfg = H1RoughEnvCfg_PLAY()
6163
env_cfg.scene.num_envs = 1
6264
env_cfg.curriculum = None
@@ -65,13 +67,19 @@ def main():
6567
terrain_type="usd",
6668
usd_path=f"{ISAAC_NUCLEUS_DIR}/Environments/Simple_Warehouse/warehouse.usd",
6769
)
68-
env_cfg.sim.device = "cpu"
69-
env_cfg.sim.use_fabric = False
70+
env_cfg.sim.device = args_cli.device
71+
if args_cli.device == "cpu":
72+
env_cfg.sim.use_fabric = False
73+
74+
# create environment
7075
env = ManagerBasedRLEnv(cfg=env_cfg)
76+
77+
# run inference with the policy
7178
obs, _ = env.reset()
72-
while simulation_app.is_running():
73-
action = policy(obs["policy"]) # run inference
74-
obs, _, _, _, _ = env.step(action)
79+
with torch.inference_mode():
80+
while simulation_app.is_running():
81+
action = policy(obs["policy"])
82+
obs, _, _, _, _ = env.step(action)
7583

7684

7785
if __name__ == "__main__":

scripts/tutorials/03_envs/run_cartpole_rl_env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def main():
4545
# create environment configuration
4646
env_cfg = CartpoleEnvCfg()
4747
env_cfg.scene.num_envs = args_cli.num_envs
48+
env_cfg.sim.device = args_cli.device
4849
# setup RL environment
4950
env = ManagerBasedRLEnv(cfg=env_cfg)
5051

0 commit comments

Comments
 (0)