|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 | 5 |
|
6 |
| -import gymnasium as gym |
7 | 6 | import numpy as np
|
8 | 7 | import torch.nn
|
9 | 8 | import torch.optim
|
10 | 9 | from tensordict.nn import TensorDictModule
|
11 |
| -from torchrl.data import CompositeSpec |
| 10 | +from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec |
12 | 11 | from torchrl.data.tensor_specs import DiscreteBox
|
13 | 12 | from torchrl.envs import (
|
14 | 13 | CatFrames,
|
15 |
| - default_info_dict_reader, |
16 | 14 | DoubleToFloat,
|
17 | 15 | EnvCreator,
|
18 | 16 | ExplorationType,
|
19 | 17 | GrayScale,
|
| 18 | + GymEnv, |
20 | 19 | NoopResetEnv,
|
21 | 20 | ParallelEnv,
|
22 | 21 | Resize,
|
23 | 22 | RewardClipping,
|
24 | 23 | RewardSum,
|
25 | 24 | StepCounter,
|
26 | 25 | ToTensorImage,
|
| 26 | + Transform, |
27 | 27 | TransformedEnv,
|
28 | 28 | VecNorm,
|
29 | 29 | )
|
30 |
| -from torchrl.envs.libs.gym import GymWrapper |
31 | 30 | from torchrl.modules import (
|
32 | 31 | ActorValueOperator,
|
33 | 32 | ConvNet,
|
|
43 | 42 | # --------------------------------------------------------------------
|
44 | 43 |
|
45 | 44 |
|
46 |
| -class EpisodicLifeEnv(gym.Wrapper): |
47 |
| - def __init__(self, env): |
48 |
| - """Make end-of-life == end-of-episode, but only reset on true game over. |
49 |
| - Done by DeepMind for the DQN and co. It helps value estimation. |
50 |
| - """ |
51 |
| - gym.Wrapper.__init__(self, env) |
52 |
| - self.lives = 0 |
| 45 | +class EndOfLifeTransform(Transform): |
| 46 | + """Registers the end-of-life signal from a Gym env with a `lives` method. |
53 | 47 |
|
54 |
| - def step(self, action): |
55 |
| - obs, rew, done, truncate, info = self.env.step(action) |
56 |
| - lives = self.env.unwrapped.ale.lives() |
57 |
| - info["end_of_life"] = False |
58 |
| - if (lives < self.lives) or done: |
59 |
| - info["end_of_life"] = True |
60 |
| - self.lives = lives |
61 |
| - return obs, rew, done, truncate, info |
| 48 | + Done by DeepMind for the DQN and co. It helps value estimation. |
| 49 | + """ |
62 | 50 |
|
63 |
| - def reset(self, **kwargs): |
64 |
| - reset_data = self.env.reset(**kwargs) |
65 |
| - self.lives = self.env.unwrapped.ale.lives() |
66 |
| - return reset_data |
| 51 | + def _step(self, tensordict, next_tensordict): |
| 52 | + lives = self.parent.base_env._env.unwrapped.ale.lives() |
| 53 | + end_of_life = torch.tensor( |
| 54 | + [tensordict["lives"] < lives], device=self.parent.device |
| 55 | + ) |
| 56 | + end_of_life = end_of_life | next_tensordict.get("done") |
| 57 | + next_tensordict.set("eol", end_of_life) |
| 58 | + next_tensordict.set("lives", lives) |
| 59 | + return next_tensordict |
| 60 | + |
| 61 | + def reset(self, tensordict): |
| 62 | + lives = self.parent.base_env._env.unwrapped.ale.lives() |
| 63 | + end_of_life = False |
| 64 | + tensordict.set("eol", [end_of_life]) |
| 65 | + tensordict.set("lives", lives) |
| 66 | + return tensordict |
| 67 | + |
| 68 | + def transform_observation_spec(self, observation_spec): |
| 69 | + full_done_spec = self.parent.output_spec["full_done_spec"] |
| 70 | + observation_spec["eol"] = full_done_spec["done"].clone() |
| 71 | + observation_spec["lives"] = UnboundedDiscreteTensorSpec( |
| 72 | + self.parent.batch_size, device=self.parent.device |
| 73 | + ) |
| 74 | + return observation_spec |
67 | 75 |
|
68 | 76 |
|
69 | 77 | def make_base_env(
|
70 | 78 | env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
|
71 | 79 | ):
|
72 |
| - env = gym.make(env_name) |
73 |
| - if not is_test: |
74 |
| - env = EpisodicLifeEnv(env) |
75 |
| - env = GymWrapper( |
76 |
| - env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device |
| 80 | + env = GymEnv( |
| 81 | + env_name, |
| 82 | + frame_skip=frame_skip, |
| 83 | + from_pixels=True, |
| 84 | + pixels_only=False, |
| 85 | + device=device, |
77 | 86 | )
|
78 | 87 | env = TransformedEnv(env)
|
79 | 88 | env.append_transform(NoopResetEnv(noops=30, random=True))
|
80 | 89 | if not is_test:
|
81 |
| - reader = default_info_dict_reader(["end_of_life"]) |
82 |
| - env.set_info_dict_reader(reader) |
| 90 | + env.append_transform(EndOfLifeTransform()) |
83 | 91 | return env
|
84 | 92 |
|
85 | 93 |
|
|
0 commit comments