Skip to content

Commit 821d8bc

Browse files
authored
[BugFix] Minor fixes PPO / A2C examples (#1591)
1 parent 1697102 commit 821d8bc

File tree

4 files changed

+51
-40
lines changed

4 files changed

+51
-40
lines changed

examples/a2c/a2c_atari.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def main(cfg: "DictConfig"): # noqa: F821
7575
critic_coef=cfg.loss.critic_coef,
7676
)
7777

78+
# use end-of-life as done key
79+
loss_module.set_keys(done="eol", terminated="eol")
80+
7881
# Create optimizer
7982
optim = torch.optim.Adam(
8083
loss_module.parameters(),

examples/a2c/utils_atari.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,30 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import gymnasium as gym
76
import numpy as np
87
import torch.nn
98
import torch.optim
109
from tensordict.nn import TensorDictModule
11-
from torchrl.data import CompositeSpec
10+
from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec
1211
from torchrl.data.tensor_specs import DiscreteBox
1312
from torchrl.envs import (
1413
CatFrames,
15-
default_info_dict_reader,
1614
DoubleToFloat,
1715
EnvCreator,
1816
ExplorationType,
1917
GrayScale,
18+
GymEnv,
2019
NoopResetEnv,
2120
ParallelEnv,
2221
Resize,
2322
RewardClipping,
2423
RewardSum,
2524
StepCounter,
2625
ToTensorImage,
26+
Transform,
2727
TransformedEnv,
2828
VecNorm,
2929
)
30-
from torchrl.envs.libs.gym import GymWrapper
3130
from torchrl.modules import (
3231
ActorValueOperator,
3332
ConvNet,
@@ -43,43 +42,52 @@
4342
# --------------------------------------------------------------------
4443

4544

46-
class EpisodicLifeEnv(gym.Wrapper):
47-
def __init__(self, env):
48-
"""Make end-of-life == end-of-episode, but only reset on true game over.
49-
Done by DeepMind for the DQN and co. It helps value estimation.
50-
"""
51-
gym.Wrapper.__init__(self, env)
52-
self.lives = 0
45+
class EndOfLifeTransform(Transform):
46+
"""Registers the end-of-life signal from a Gym env with a `lives` method.
5347
54-
def step(self, action):
55-
obs, rew, done, truncate, info = self.env.step(action)
56-
lives = self.env.unwrapped.ale.lives()
57-
info["end_of_life"] = False
58-
if (lives < self.lives) or done:
59-
info["end_of_life"] = True
60-
self.lives = lives
61-
return obs, rew, done, truncate, info
48+
Done by DeepMind for the DQN and co. It helps value estimation.
49+
"""
6250

63-
def reset(self, **kwargs):
64-
reset_data = self.env.reset(**kwargs)
65-
self.lives = self.env.unwrapped.ale.lives()
66-
return reset_data
51+
def _step(self, tensordict, next_tensordict):
52+
lives = self.parent.base_env._env.unwrapped.ale.lives()
53+
end_of_life = torch.tensor(
54+
[tensordict["lives"] < lives], device=self.parent.device
55+
)
56+
end_of_life = end_of_life | next_tensordict.get("done")
57+
next_tensordict.set("eol", end_of_life)
58+
next_tensordict.set("lives", lives)
59+
return next_tensordict
60+
61+
def reset(self, tensordict):
62+
lives = self.parent.base_env._env.unwrapped.ale.lives()
63+
end_of_life = False
64+
tensordict.set("eol", [end_of_life])
65+
tensordict.set("lives", lives)
66+
return tensordict
67+
68+
def transform_observation_spec(self, observation_spec):
69+
full_done_spec = self.parent.output_spec["full_done_spec"]
70+
observation_spec["eol"] = full_done_spec["done"].clone()
71+
observation_spec["lives"] = UnboundedDiscreteTensorSpec(
72+
self.parent.batch_size, device=self.parent.device
73+
)
74+
return observation_spec
6775

6876

6977
def make_base_env(
7078
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
7179
):
72-
env = gym.make(env_name)
73-
if not is_test:
74-
env = EpisodicLifeEnv(env)
75-
env = GymWrapper(
76-
env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device
80+
env = GymEnv(
81+
env_name,
82+
frame_skip=frame_skip,
83+
from_pixels=True,
84+
pixels_only=False,
85+
device=device,
7786
)
7887
env = TransformedEnv(env)
7988
env.append_transform(NoopResetEnv(noops=30, random=True))
8089
if not is_test:
81-
reader = default_info_dict_reader(["end_of_life"])
82-
env.set_info_dict_reader(reader)
90+
env.append_transform(EndOfLifeTransform())
8391
return env
8492

8593

examples/ppo/ppo_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821
7979
)
8080

8181
# use end-of-life as done key
82-
loss_module.set_keys(done="eol")
82+
loss_module.set_keys(done="eol", terminated="eol")
8383

8484
# Create optimizer
8585
optim = torch.optim.Adam(

examples/ppo/utils_atari.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import gymnasium as gym
76
import torch.nn
87
import torch.optim
98
from tensordict.nn import TensorDictModule
109
from torchrl.data import CompositeSpec
1110
from torchrl.data.tensor_specs import DiscreteBox, UnboundedDiscreteTensorSpec
1211
from torchrl.envs import (
1312
CatFrames,
14-
default_info_dict_reader,
1513
DoubleToFloat,
1614
EnvCreator,
1715
ExplorationType,
1816
GrayScale,
17+
GymEnv,
1918
NoopResetEnv,
2019
ParallelEnv,
2120
Resize,
@@ -27,7 +26,6 @@
2726
TransformedEnv,
2827
VecNorm,
2928
)
30-
from torchrl.envs.libs.gym import GymWrapper
3129
from torchrl.modules import (
3230
ActorValueOperator,
3331
ConvNet,
@@ -78,15 +76,17 @@ def transform_observation_spec(self, observation_spec):
7876
def make_base_env(
7977
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
8078
):
81-
env = gym.make(env_name)
82-
env = GymWrapper(
83-
env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device
79+
env = GymEnv(
80+
env_name,
81+
frame_skip=frame_skip,
82+
from_pixels=True,
83+
pixels_only=False,
84+
device=device,
8485
)
85-
env = TransformedEnv(env, EndOfLifeTransform())
86+
env = TransformedEnv(env)
8687
env.append_transform(NoopResetEnv(noops=30, random=True))
8788
if not is_test:
88-
reader = default_info_dict_reader(["end_of_life"])
89-
env.set_info_dict_reader(reader)
89+
env.append_transform(EndOfLifeTransform())
9090
return env
9191

9292

0 commit comments

Comments (0)