Commit 908ca39

Author: Vincent Moens

Revert "[BugFix] Fix Isaac" (#2118)

1 parent 422f1ac · commit 908ca39

File tree

4 files changed: 43 additions & 72 deletions

test/test_libs.py

Lines changed: 18 additions & 27 deletions
@@ -2,7 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib.util
+import importlib
+import os
+from contextlib import nullcontext
+from pathlib import Path
+
+from torchrl._utils import logger as torchrl_logger
+
+from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
+
+from torchrl.envs.transforms import ActionMask, TransformedEnv
+from torchrl.modules import MaskedCategorical

 _has_isaac = importlib.util.find_spec("isaacgym") is not None

@@ -11,13 +21,11 @@
     import isaacgym  # noqa
     import isaacgymenvs  # noqa
     from torchrl.envs.libs.isaacgym import IsaacGymEnv
+
 import argparse
 import importlib
-import os

 import time
-from contextlib import nullcontext
-from pathlib import Path
 from sys import platform
 from typing import Optional, Union

@@ -49,8 +57,7 @@
     TensorDictSequential,
 )
 from torch import nn
-
-from torchrl._utils import implement_for, logger as torchrl_logger
+from torchrl._utils import implement_for
 from torchrl.collectors.collectors import SyncDataCollector
 from torchrl.data import (
     BinaryDiscreteTensorSpec,

@@ -67,8 +74,6 @@
 )
 from torchrl.data.datasets.atari_dqn import AtariDQNExperienceReplay
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
-
-from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay
 from torchrl.data.datasets.minari_data import MinariExperienceReplay
 from torchrl.data.datasets.openml import OpenMLExperienceReplay
 from torchrl.data.datasets.openx import OpenXExperienceReplay

@@ -109,21 +114,13 @@
 from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv
 from torchrl.envs.libs.smacv2 import _has_smacv2, SMACv2Env
 from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
-
-from torchrl.envs.transforms import ActionMask, TransformedEnv
 from torchrl.envs.utils import (
     check_env_specs,
     ExplorationType,
     MarlGroupMapType,
     RandomPolicy,
 )
-from torchrl.modules import (
-    ActorCriticOperator,
-    MaskedCategorical,
-    MLP,
-    SafeModule,
-    ValueOperator,
-)
+from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator

 _has_d4rl = importlib.util.find_spec("d4rl") is not None

@@ -3087,28 +3084,22 @@ def test_data(self, dataset):
 )
 @pytest.mark.parametrize("num_envs", [10, 20])
 @pytest.mark.parametrize("device", get_default_devices())
-@pytest.mark.parametrize("from_pixels", [True, False])
 class TestIsaacGym:
     @classmethod
-    def _run_on_proc(cls, q, task, num_envs, device, from_pixels):
+    def _run_on_proc(cls, q, task, num_envs, device):
         try:
-            env = IsaacGymEnv(
-                task=task, num_envs=num_envs, device=device, from_pixels=from_pixels
-            )
+            env = IsaacGymEnv(task=task, num_envs=num_envs, device=device)
             check_env_specs(env)
             q.put(("succeeded!", None))
         except Exception as err:
             q.put(("failed!", err))
             raise err

-    def test_env(self, task, num_envs, device, from_pixels):
+    def test_env(self, task, num_envs, device):
         from torch import multiprocessing as mp

         q = mp.Queue(1)
-        self._run_on_proc(q, task, num_envs, device, from_pixels)
-        proc = mp.Process(
-            target=self._run_on_proc, args=(q, task, num_envs, device, from_pixels)
-        )
+        proc = mp.Process(target=self._run_on_proc, args=(q, task, num_envs, device))
         try:
             proc.start()
             msg, error = q.get()
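
The reverted test drives the spec check through a separate process, so a simulator crash cannot take down the whole test session. The snippet below is a minimal sketch of that same process-and-queue pattern, not part of the commit: GymEnv("CartPole-v1") is a stand-in for IsaacGymEnv (which needs an Isaac Gym installation), and _run_on_proc simply mirrors the classmethod above.

# Sketch only: same subprocess/queue pattern as TestIsaacGym, with a stand-in env.
from torch import multiprocessing as mp

from torchrl.envs.libs.gym import GymEnv  # stand-in for IsaacGymEnv
from torchrl.envs.utils import check_env_specs


def _run_on_proc(q):
    try:
        env = GymEnv("CartPole-v1")  # hypothetical stand-in task
        check_env_specs(env)  # validates observation/action/reward/done specs
        q.put(("succeeded!", None))
    except Exception as err:
        q.put(("failed!", err))
        raise err


if __name__ == "__main__":
    q = mp.Queue(1)
    proc = mp.Process(target=_run_on_proc, args=(q,))
    proc.start()
    msg, error = q.get()
    proc.join()
    print(msg)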

torchrl/envs/libs/gym.py

Lines changed: 6 additions & 10 deletions
@@ -943,9 +943,6 @@ def _reward_space(self, env):  # noqa: F811
         return rs

     def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
-        # If batch_size is provided, we se it to tell what batch size must be used
-        # instead of self.batch_size
-        cur_batch_size = self.batch_size if batch_size is None else torch.Size([])
         action_spec = _gym_to_torchrl_spec_transform(
             env.action_space,
             device=self.device,

@@ -959,14 +956,14 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
         if not isinstance(observation_spec, CompositeSpec):
             if self.from_pixels:
                 observation_spec = CompositeSpec(
-                    pixels=observation_spec, shape=cur_batch_size
+                    pixels=observation_spec, shape=self.batch_size
                 )
             else:
                 observation_spec = CompositeSpec(
-                    observation=observation_spec, shape=cur_batch_size
+                    observation=observation_spec, shape=self.batch_size
                 )
-        elif observation_spec.shape[: len(cur_batch_size)] != cur_batch_size:
-            observation_spec.shape = cur_batch_size
+        elif observation_spec.shape[: len(self.batch_size)] != self.batch_size:
+            observation_spec.shape = self.batch_size

         reward_space = self._reward_space(env)
         if reward_space is not None:

@@ -986,11 +983,10 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None:  # noqa: F821
             observation_spec = observation_spec.expand(
                 *batch_size, *observation_spec.shape
             )
-
         self.done_spec = self._make_done_spec()
         self.action_spec = action_spec
-        if reward_spec.shape[: len(cur_batch_size)] != cur_batch_size:
-            self.reward_spec = reward_spec.expand(*cur_batch_size, *reward_spec.shape)
+        if reward_spec.shape[: len(self.batch_size)] != self.batch_size:
+            self.reward_spec = reward_spec.expand(*self.batch_size, *reward_spec.shape)
         else:
             self.reward_spec = reward_spec
         self.observation_spec = observation_spec
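
For reference, a minimal sketch of the spec handling this hunk reverts to: when the observation spec is not already a CompositeSpec it is wrapped in one carrying the env's own batch_size, and the reward spec is expanded to that batch size when its leading dimensions do not match. The specs and the batch size of 4 below are placeholders, not the wrapper code itself.

# Sketch only: illustrates the batch_size-based wrapping/expansion with placeholder specs.
import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

batch_size = torch.Size([4])  # e.g. a vectorized env with 4 sub-envs

# leaf spec already carries the batch dimension, as the gym spec transform would produce
observation_spec = UnboundedContinuousTensorSpec(shape=(4, 3))
observation_spec = CompositeSpec(observation=observation_spec, shape=batch_size)

reward_spec = UnboundedContinuousTensorSpec(shape=(1,))
if reward_spec.shape[: len(batch_size)] != batch_size:
    reward_spec = reward_spec.expand(*batch_size, *reward_spec.shape)

print(observation_spec.shape)  # torch.Size([4])
print(reward_spec.shape)       # torch.Size([4, 1])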

torchrl/envs/libs/isaacgym.py

Lines changed: 17 additions & 33 deletions
@@ -14,7 +14,6 @@
 import torch

 from tensordict import TensorDictBase
-from torchrl.data import CompositeSpec
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs.utils import _classproperty, make_composite_from_td

@@ -50,23 +49,19 @@ def __init__(
         warnings.warn(
             "IsaacGym environment support is an experimental feature that may change in the future."
         )
+        num_envs = env.num_envs
         super().__init__(
-            env, torch.device(env.device), batch_size=torch.Size([]), **kwargs
+            env, torch.device(env.device), batch_size=torch.Size([num_envs]), **kwargs
         )
         if not hasattr(self, "task"):
             # by convention in IsaacGymEnvs
             self.task = env.__name__

     def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         super()._make_specs(env, batch_size=self.batch_size)
-        self.full_done_spec = CompositeSpec(
-            {
-                key: spec.squeeze(-1)
-                for key, spec in self.full_done_spec.items(True, True)
-            },
-            shape=self.batch_size,
-        )
-
+        self.full_done_spec = {
+            key: spec.squeeze(-1) for key, spec in self.full_done_spec.items(True, True)
+        }
         self.observation_spec["obs"] = self.observation_spec["observation"]
         del self.observation_spec["observation"]

@@ -83,18 +78,7 @@ def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
         obs_spec.unlock_()
         obs_spec.update(specs)
         obs_spec.lock_()
-
-    def _output_transform(self, output):
-        obs, reward, done, info = output
-        if self.from_pixels:
-            obs["pixels"] = self._env.render(mode="rgb_array")
-        return obs, reward, done ^ done, done, done, info
-
-    def _reset_output_transform(self, reset_data):
-        reset_data.pop("reward", None)
-        if self.from_pixels:
-            reset_data["pixels"] = self._env.render(mode="rgb_array")
-        return reset_data, {}
+        self.__dict__["full_observation_spec"] = obs_spec

     @classmethod
     def _make_envs(cls, *, task, num_envs, device, seed=None, headless=True, **kwargs):

@@ -141,8 +125,15 @@ def read_done(
         done = done.bool()
         return terminated, truncated, done, done.any()

-    def read_reward(self, total_reward):
-        return total_reward
+    def read_reward(self, total_reward, step_reward):
+        """Reads a reward and the total reward so far (in the frame skip loop) and returns a sum of the two.
+
+        Args:
+            total_reward (torch.Tensor or TensorDict): total reward so far in the step
+            step_reward (reward in the format provided by the inner env): reward of this particular step
+
+        """
+        return total_reward + step_reward

     def read_obs(
         self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]

@@ -192,13 +183,6 @@ def __init__(self, task=None, *, env=None, num_envs, device, **kwargs):
             raise RuntimeError("Cannot provide both `task` and `env` arguments.")
         elif env is not None:
             task = env
-        from_pixels = kwargs.pop("from_pixels", False)
-        envs = self._make_envs(
-            task=task,
-            num_envs=num_envs,
-            device=device,
-            virtual_screen_capture=True,
-            **kwargs,
-        )
+        envs = self._make_envs(task=task, num_envs=num_envs, device=device, **kwargs)
         self.task = task
-        super().__init__(envs, from_pixels=from_pixels, **kwargs)
+        super().__init__(envs, **kwargs)
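
The restored read_reward(total_reward, step_reward) follows the accumulation contract described in its docstring: inside the frame-skip loop, the wrapper keeps a running total and adds each inner step's reward to it. Below is a minimal, standalone sketch of that contract; the loop, the frame_skip value, and the per-step reward are illustrative placeholders, not library code.

# Sketch only: frame-skip reward accumulation as in the restored read_reward contract.
import torch


def read_reward(total_reward, step_reward):
    # reverted signature: return the running total plus this step's reward
    return total_reward + step_reward


frame_skip = 4
total_reward = torch.zeros(())
for _ in range(frame_skip):
    step_reward = torch.tensor(1.0)  # stand-in for the inner env's per-step reward
    total_reward = read_reward(total_reward, step_reward)

print(total_reward)  # tensor(4.)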

torchrl/envs/libs/jumanji.py

Lines changed: 2 additions & 2 deletions
@@ -67,8 +67,8 @@ def _jumanji_to_torchrl_spec_transform(
         dtype = numpy_to_torch_dtype_dict[spec.dtype]
         return BoundedTensorSpec(
             shape=shape,
-            low=np.asarray(spec.minimum),
-            high=np.asarray(spec.maximum),
+            low=np.asarray(spec.low),
+            high=np.asarray(spec.high),
             dtype=dtype,
             device=device,
         )
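
This hunk goes back to reading the bounds of a Jumanji bounded spec from its low/high attributes before building a BoundedTensorSpec. The sketch below shows the same conversion with a hypothetical DummyBoundedArray standing in for jumanji.specs.BoundedArray, so it runs without Jumanji installed; the keyword arguments mirror the call in the diff.

# Sketch only: converting a bounded spec's low/high fields into a BoundedTensorSpec.
from dataclasses import dataclass

import numpy as np
import torch
from torchrl.data import BoundedTensorSpec


@dataclass
class DummyBoundedArray:  # hypothetical stand-in for jumanji.specs.BoundedArray
    shape: tuple
    dtype: np.dtype
    low: float
    high: float


spec = DummyBoundedArray(shape=(3,), dtype=np.dtype("float32"), low=-1.0, high=1.0)
torchrl_spec = BoundedTensorSpec(
    shape=spec.shape,
    low=np.asarray(spec.low),
    high=np.asarray(spec.high),
    dtype=torch.float32,  # what numpy_to_torch_dtype_dict[spec.dtype] would resolve to
    device="cpu",
)
print(torchrl_spec)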
