
Commit 23e2121

Author: Vincent Moens
[BugFix] Fix Isaac (#2072)
1 parent 160a946 commit 23e2121

4 files changed: +72 −43 lines


test/test_libs.py

Lines changed: 27 additions & 18 deletions
@@ -2,17 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
-import os
-from contextlib import nullcontext
-from pathlib import Path
-
-from torchrl._utils import logger as torchrl_logger
-
-from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
-
-from torchrl.envs.transforms import ActionMask, TransformedEnv
-from torchrl.modules import MaskedCategorical
+import importlib.util
 
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
 
@@ -21,11 +11,13 @@
     import isaacgym  # noqa
    import isaacgymenvs  # noqa
    from torchrl.envs.libs.isaacgym import IsaacGymEnv
-
 import argparse
 import importlib
+import os
 
 import time
+from contextlib import nullcontext
+from pathlib import Path
 from sys import platform
 from typing import Optional, Union
 
@@ -57,7 +49,8 @@
     TensorDictSequential,
 )
 from torch import nn
-from torchrl._utils import implement_for
+
+from torchrl._utils import implement_for, logger as torchrl_logger
 from torchrl.collectors.collectors import SyncDataCollector
 from torchrl.data import (
     BinaryDiscreteTensorSpec,
@@ -74,6 +67,8 @@
 )
 from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
+
+from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
 from torchrl.data.datasets.minari_data import MinariExperienceReplay
 from torchrl.data.datasets.openml import OpenMLExperienceReplay
 from torchrl.data.datasets.openx import OpenXExperienceReplay
@@ -114,13 +109,21 @@
 from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
 from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
 from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
+
+from torchrl.envs.transforms import ActionMask, TransformedEnv
 from torchrl.envs.utils import (
     check_env_specs,
     ExplorationType,
     MarlGroupMapType,
     RandomPolicy,
 )
-from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator
+from torchrl.modules import (
+    ActorCriticOperator,
+    MaskedCategorical,
+    MLP,
+    SafeModule,
+    ValueOperator,
+)
 
 _has_d4rl = importlib.util.find_spec("d4rl") is not None
 
@@ -3084,22 +3087,28 @@ def test_data(self, dataset):
 )
 @pytest.mark.parametrize("num_envs", [10, 20])
 @pytest.mark.parametrize("device", get_default_devices())
+@pytest.mark.parametrize("from_pixels", [True, False])
 class TestIsaacGym:
     @classmethod
-    def _run_on_proc(cls, q, task, num_envs, device):
+    def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
         try:
-            env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
+            env = IsaacGymEnv(
+                task=task, num_envs=num_envs, device=device, from_pixels=from_pixels
+            )
             check_env_specs(env)
             q.put(("succeeded!", None))
         except Exception as err:
             q.put(("failed!", err))
             raise err
 
-    def test_env(self, task, num_envs, device):
+    def test_env(self, task, num_envs, device, from_pixels):
         from torch import multiprocessing as mp
 
         q = mp.Queue(1)
-        proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
+        self._run_on_proc(q, task, num_envs, device, from_pixels)
+        proc = mp.Process(
+            target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels)
+        )
         try:
             proc.start()
             msg, error = q.get()
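
The test now parametrizes over from_pixels and runs the spec check in a spawned process, reporting the result back through a queue. A stripped-down sketch of that pattern follows; make_env, _check_in_proc and run_check are illustrative names that are not part of the commit.

from torch import multiprocessing as mp

from torchrl.envs.utils import check_env_specs


def _check_in_proc(q, make_env, from_pixels):
    # Build the environment and validate its specs inside the child process,
    # then report the outcome through the queue.
    try:
        env = make_env(from_pixels=from_pixels)
        check_env_specs(env)
        q.put(("succeeded!", None))
    except Exception as err:
        q.put(("failed!", err))
        raise


def run_check(make_env, from_pixels=False):
    q = mp.Queue(1)
    proc = mp.Process(target=_check_in_proc, args=(q, make_env, from_pixels))
    proc.start()
    msg, error = q.get()
    proc.join()
    if msg != "succeeded!":
        raise error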

torchrl/envs/libs/gym.py

Lines changed: 10 additions & 6 deletions
@@ -943,6 +943,9 @@ def _reward_space(self, env):  # noqa: F811
         return rs
 
     def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
+        # If batch_size is provided, we use it to tell what batch size must be
+        # used instead of self.batch_size
+        cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
         action_spec = _gym_to_torchrl_spec_transform(
             env.action_space,
             device=self.device,
@@ -956,14 +959,14 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
         if not isinstance(observation_spec, CompositeSpec):
             if self.from_pixels:
                 observation_spec = CompositeSpec(
-                    pixels=observation_spec, shape=self.batch_size
+                    pixels=observation_spec, shape=cur_batch_size
                 )
             else:
                 observation_spec = CompositeSpec(
-                    observation=observation_spec, shape=self.batch_size
+                    observation=observation_spec, shape=cur_batch_size
                 )
-        elif observation_spec.shape[: len(self.batch_size)] != self.batch_size:
-            observation_spec.shape = self.batch_size
+        elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
+            observation_spec.shape = cur_batch_size
 
         reward_space = self._reward_space(env)
         if reward_space is not None:
@@ -983,10 +986,11 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
             observation_spec = observation_spec.expand(
                 *batch_size, *observation_spec.shape
             )
+
         self.done_spec = self._make_done_spec()
         self.action_spec = action_spec
-        if reward_spec.shape[: len(self.batch_size)] != self.batch_size:
-            self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape)
+        if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
+            self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
         else:
             self.reward_spec = reward_spec
         self.observation_spec = observation_spec
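
The change above routes the spec shapes through cur_batch_size: when _make_specs receives an explicit batch_size, the specs are first built unbatched (torch.Size([])) and then expanded to the requested batch size afterwards. A minimal sketch of that expansion step, using an arbitrary spec class and shapes for illustration only:

import torch
from torchrl.data import UnboundedContinuousTensorSpec

# Build an unbatched spec first, then expand it to the env batch size,
# mirroring the cur_batch_size logic above (values are illustrative only).
unbatched = UnboundedContinuousTensorSpec(shape=torch.Size([3]))
batched = unbatched.expand(16, *unbatched.shape)
assert batched.shape == torch.Size([16, 3])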

torchrl/envs/libs/isaacgym.py

Lines changed: 33 additions & 17 deletions
@@ -14,6 +14,7 @@
 import torch
 
 from tensordict import TensorDictBase
+from torchrl.data import CompositeSpec
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs.utils import _classproperty, make_composite_from_td
 
@@ -49,19 +50,23 @@ def __init__(
         warnings.warn(
             "IsaacGym environment support is an experimental feature that may change in the future."
         )
-        num_envs = env.num_envs
         super().__init__(
-            env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs
+            env, torch.device(env.device), batch_size=torch.Size([]), **kwargs
         )
         if not hasattr(self, "task"):
             # by convention in IsaacGymEnvs
             self.task = env.__name__
 
     def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         super()._make_specs(env, batch_size=self.batch_size)
-        self.full_done_spec = {
-            key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True)
-        }
+        self.full_done_spec = CompositeSpec(
+            {
+                key: spec.squeeze(-1)
+                for key, spec in self.full_done_spec.items(True, True)
+            },
+            shape=self.batch_size,
+        )
+
         self.observation_spec["obs"] = self.observation_spec["observation"]
         del self.observation_spec["observation"]
 
@@ -78,7 +83,18 @@ def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         obs_spec.unlock_()
         obs_spec.update(specs)
         obs_spec.lock_()
-        self.__dict__["full_observation_spec"] = obs_spec
+
+    def _output_transform(self, output):
+        obs, reward, done, info = output
+        if self.from_pixels:
+            obs["pixels"] = self._env.render(mode="rgb_array")
+        return obs, reward, done ^ done, done, done, info
+
+    def _reset_output_transform(self, reset_data):
+        reset_data.pop("reward", None)
+        if self.from_pixels:
+            reset_data["pixels"] = self._env.render(mode="rgb_array")
+        return reset_data, {}
 
     @classmethod
     def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs):
@@ -125,15 +141,8 @@ def read_done(
         done = done.bool()
         return terminated, truncated, done, done.any()
 
-    def read_reward(self, total_reward, step_reward):
-        """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
-
-        Args:
-            total_reward (torch.Tensor or TensorDict): total reward so far in the step
-            step_reward (reward in the format provided by the inner env): reward of this particular step
-
-        """
-        return total_reward + step_reward
+    def read_reward(self, total_reward):
+        return total_reward
 
     def read_obs(
         self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
@@ -183,6 +192,13 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
             raise RuntimeError("Cannot provide both `task` and `env` arguments.")
         elif env is not None:
             task = env
-        envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
+        from_pixels = kwargs.pop("from_pixels", False)
+        envs = self._make_envs(
+            task=task,
+            num_envs=num_envs,
+            device=device,
+            virtual_screen_capture=True,
+            **kwargs,
+        )
         self.task = task
-        super().__init__(envs, **kwargs)
+        super().__init__(envs, from_pixels=from_pixels, **kwargs)
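
With the wrapper changes above, pixel rendering can be requested at construction time and the rendered frame is added under the "pixels" key. A usage sketch, assuming isaacgym and isaacgymenvs are installed and the "Ant" task is available; all argument values are illustrative:

from torchrl.envs.libs.isaacgym import IsaacGymEnv

# from_pixels=True asks the wrapper to attach a rendered frame alongside the
# regular observations (exposed as "obs" after the spec rename above).
env = IsaacGymEnv(task="Ant", num_envs=64, device="cuda:0", from_pixels=True)
td = env.reset()
print(td["pixels"].shape, td["obs"].shape)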

torchrl/envs/libs/jumanji.py

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform(
         dtype = numpy_to_torch_dtype_dict[spec.dtype]
         return BoundedTensorSpec(
             shape=shape,
-            low=np.asarray(spec.low),
-            high=np.asarray(spec.high),
+            low=np.asarray(spec.minimum),
+            high=np.asarray(spec.maximum),
             dtype=dtype,
             device=device,
         )
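
The jumanji fix reads the bounds from minimum/maximum rather than low/high, matching the attribute names on jumanji's bounded specs (which follow the dm_env spec interface). A small sketch of the attributes being read, assuming jumanji's specs.BoundedArray mirrors the dm_env signature; shape and dtype are illustrative:

import numpy as np
from jumanji import specs

# A bounded spec carries its bounds as `minimum`/`maximum`, not `low`/`high`.
spec = specs.BoundedArray(shape=(3,), dtype=np.float32, minimum=-1.0, maximum=1.0)
low, high = np.asarray(spec.minimum), np.asarray(spec.maximum)
print(low, high)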
