Skip to content

Commit 37c01cc

Browse files
author
Vincent Moens
authored
[Feature] End-of-life transform (#1605)
1 parent 244f93a commit 37c01cc

File tree

10 files changed

+325
-79
lines changed

10 files changed

+325
-79
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -476,6 +476,7 @@ to be able to create this other composition:
476476
DiscreteActionProjection
477477
DoubleToFloat
478478
DTypeCastTransform
479+
EndOfLifeTransform
479480
ExcludeTransform
480481
FiniteTensorDictCheck
481482
FlattenObservation

examples/a2c/a2c_atari.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
7676
)
7777

7878
# use end-of-life as done key
79-
loss_module.set_keys(done="eol", terminated="eol")
79+
loss_module.set_keys(done="end-of-life", terminated="end-of-life")
8080

8181
# Create optimizer
8282
optim = torch.optim.Adam(

examples/a2c/utils_atari.py

Lines changed: 2 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,12 @@
77
import torch.nn
88
import torch.optim
99
from tensordict.nn import TensorDictModule
10-
from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec
10+
from torchrl.data import CompositeSpec
1111
from torchrl.data.tensor_specs import DiscreteBox
1212
from torchrl.envs import (
1313
CatFrames,
1414
DoubleToFloat,
15+
EndOfLifeTransform,
1516
EnvCreator,
1617
ExplorationType,
1718
GrayScale,
@@ -23,7 +24,6 @@
2324
RewardSum,
2425
StepCounter,
2526
ToTensorImage,
26-
Transform,
2727
TransformedEnv,
2828
VecNorm,
2929
)
@@ -42,38 +42,6 @@
4242
# --------------------------------------------------------------------
4343

4444

45-
class EndOfLifeTransform(Transform):
46-
"""Registers the end-of-life signal from a Gym env with a `lives` method.
47-
48-
Done by DeepMind for the DQN and co. It helps value estimation.
49-
"""
50-
51-
def _step(self, tensordict, next_tensordict):
52-
lives = self.parent.base_env._env.unwrapped.ale.lives()
53-
end_of_life = torch.tensor(
54-
[tensordict["lives"] < lives], device=self.parent.device
55-
)
56-
end_of_life = end_of_life | next_tensordict.get("done")
57-
next_tensordict.set("eol", end_of_life)
58-
next_tensordict.set("lives", lives)
59-
return next_tensordict
60-
61-
def reset(self, tensordict):
62-
lives = self.parent.base_env._env.unwrapped.ale.lives()
63-
end_of_life = False
64-
tensordict.set("eol", [end_of_life])
65-
tensordict.set("lives", lives)
66-
return tensordict
67-
68-
def transform_observation_spec(self, observation_spec):
69-
full_done_spec = self.parent.output_spec["full_done_spec"]
70-
observation_spec["eol"] = full_done_spec["done"].clone()
71-
observation_spec["lives"] = UnboundedDiscreteTensorSpec(
72-
self.parent.batch_size, device=self.parent.device
73-
)
74-
return observation_spec
75-
76-
7745
def make_base_env(
7846
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
7947
):

examples/ppo/ppo_atari.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821
7979
)
8080

8181
# use end-of-life as done key
82-
loss_module.set_keys(done="eol", terminated="eol")
82+
loss_module.set_keys(done="end-of-life", terminated="end-of-life")
8383

8484
# Create optimizer
8585
optim = torch.optim.Adam(

examples/ppo/utils_atari.py

Lines changed: 2 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,11 @@
77
import torch.optim
88
from tensordict.nn import TensorDictModule
99
from torchrl.data import CompositeSpec
10-
from torchrl.data.tensor_specs import DiscreteBox, UnboundedDiscreteTensorSpec
10+
from torchrl.data.tensor_specs import DiscreteBox
1111
from torchrl.envs import (
1212
CatFrames,
1313
DoubleToFloat,
14+
EndOfLifeTransform,
1415
EnvCreator,
1516
ExplorationType,
1617
GrayScale,
@@ -22,7 +23,6 @@
2223
RewardSum,
2324
StepCounter,
2425
ToTensorImage,
25-
Transform,
2626
TransformedEnv,
2727
VecNorm,
2828
)
@@ -41,38 +41,6 @@
4141
# --------------------------------------------------------------------
4242

4343

44-
class EndOfLifeTransform(Transform):
45-
"""Registers the end-of-life signal from a Gym env with a `lives` method.
46-
47-
Done by DeepMind for the DQN and co. It helps value estimation.
48-
"""
49-
50-
def _step(self, tensordict, next_tensordict):
51-
lives = self.parent.base_env._env.unwrapped.ale.lives()
52-
end_of_life = torch.tensor(
53-
[tensordict["lives"] < lives], device=self.parent.device
54-
)
55-
end_of_life = end_of_life | next_tensordict.get("done")
56-
next_tensordict.set("eol", end_of_life)
57-
next_tensordict.set("lives", lives)
58-
return next_tensordict
59-
60-
def reset(self, tensordict):
61-
lives = self.parent.base_env._env.unwrapped.ale.lives()
62-
end_of_life = False
63-
tensordict.set("eol", [end_of_life])
64-
tensordict.set("lives", lives)
65-
return tensordict
66-
67-
def transform_observation_spec(self, observation_spec):
68-
full_done_spec = self.parent.output_spec["full_done_spec"]
69-
observation_spec["eol"] = full_done_spec["done"].clone()
70-
observation_spec["lives"] = UnboundedDiscreteTensorSpec(
71-
self.parent.batch_size, device=self.parent.device
72-
)
73-
return observation_spec
74-
75-
7644
def make_base_env(
7745
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
7846
):

test/test_transforms.py

Lines changed: 112 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@
6565
DiscreteActionProjection,
6666
DMControlEnv,
6767
DoubleToFloat,
68+
EndOfLifeTransform,
6869
EnvBase,
6970
EnvCreator,
7071
ExcludeTransform,
@@ -101,11 +102,11 @@
101102
VIPTransform,
102103
)
103104
from torchrl.envs.libs.dm_control import _has_dm_control
104-
from torchrl.envs.libs.gym import _has_gym, GymEnv
105+
from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
105106
from torchrl.envs.transforms import VecNorm
106107
from torchrl.envs.transforms.r3m import _R3MNet
107108
from torchrl.envs.transforms.rlhf import KLRewardTransform
108-
from torchrl.envs.transforms.transforms import _has_tv
109+
from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED
109110
from torchrl.envs.transforms.vc1 import _has_vc
110111
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
111112
from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
@@ -8710,19 +8711,15 @@ def test_transform_env(self):
87108711

87118712
def test_transform_model(self):
87128713
t = ActionMask()
8713-
with pytest.raises(
8714-
RuntimeError, match="ActionMask must be executed within an environment"
8715-
):
8714+
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
87168715
t(TensorDict({}, []))
87178716

87188717
def test_transform_rb(self):
87198718
t = ActionMask()
87208719
rb = ReplayBuffer(storage=LazyTensorStorage(100))
87218720
rb.append_transform(t)
87228721
rb.extend(TensorDict({"a": [1]}, [1]).expand(10))
8723-
with pytest.raises(
8724-
RuntimeError, match="ActionMask must be executed within an environment"
8725-
):
8722+
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
87268723
rb.sample(3)
87278724

87288725
def test_transform_inverse(self):
@@ -8964,6 +8961,113 @@ def test_transform_no_env(self, batch):
89648961
assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))
89658962

89668963

8964+
@pytest.mark.skipif(
8965+
not _has_gym, reason="EndOfLifeTransform can only be tested when Gym is present."
8966+
)
8967+
class TestEndOfLife(TransformBase):
8968+
def test_trans_parallel_env_check(self):
8969+
def make():
8970+
with set_gym_backend("gymnasium"):
8971+
return GymEnv("ALE/Breakout-v5")
8972+
8973+
with pytest.warns(UserWarning, match="The base_env is not a gym env"):
8974+
with pytest.raises(AttributeError):
8975+
env = TransformedEnv(
8976+
ParallelEnv(2, make), transform=EndOfLifeTransform()
8977+
)
8978+
check_env_specs(env)
8979+
8980+
def test_trans_serial_env_check(self):
8981+
def make():
8982+
with set_gym_backend("gymnasium"):
8983+
return GymEnv("ALE/Breakout-v5")
8984+
8985+
with pytest.warns(UserWarning, match="The base_env is not a gym env"):
8986+
env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform())
8987+
check_env_specs(env)
8988+
8989+
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
8990+
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
8991+
def test_single_trans_env_check(self, eol_key, lives_key):
8992+
with set_gym_backend("gymnasium"):
8993+
env = TransformedEnv(
8994+
GymEnv("ALE/Breakout-v5"),
8995+
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
8996+
)
8997+
check_env_specs(env)
8998+
8999+
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
9000+
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
9001+
def test_serial_trans_env_check(self, eol_key, lives_key):
9002+
def make():
9003+
with set_gym_backend("gymnasium"):
9004+
return TransformedEnv(
9005+
GymEnv("ALE/Breakout-v5"),
9006+
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
9007+
)
9008+
9009+
env = SerialEnv(2, make)
9010+
check_env_specs(env)
9011+
9012+
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
9013+
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
9014+
def test_parallel_trans_env_check(self, eol_key, lives_key):
9015+
def make():
9016+
with set_gym_backend("gymnasium"):
9017+
return TransformedEnv(
9018+
GymEnv("ALE/Breakout-v5"),
9019+
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
9020+
)
9021+
9022+
env = ParallelEnv(2, make)
9023+
check_env_specs(env)
9024+
9025+
def test_transform_no_env(self):
9026+
t = EndOfLifeTransform()
9027+
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
9028+
t._step(TensorDict({}, []), TensorDict({}, []))
9029+
9030+
def test_transform_compose(self):
9031+
t = EndOfLifeTransform()
9032+
with pytest.raises(RuntimeError, match=t.NO_PARENT_ERR.format(type(t))):
9033+
Compose(t)._step(TensorDict({}, []), TensorDict({}, []))
9034+
9035+
@pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
9036+
@pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
9037+
def test_transform_env(self, eol_key, lives_key):
9038+
from tensordict.nn import TensorDictModule
9039+
from torchrl.objectives import DQNLoss
9040+
from torchrl.objectives.value import GAE
9041+
9042+
with set_gym_backend("gymnasium"):
9043+
env = TransformedEnv(
9044+
GymEnv("ALE/Breakout-v5"),
9045+
transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
9046+
)
9047+
check_env_specs(env)
9048+
loss = DQNLoss(nn.Identity(), action_space="categorical")
9049+
env.transform.register_keys(loss)
9050+
assert ("next", eol_key) in loss.in_keys
9051+
gae = GAE(
9052+
gamma=0.9,
9053+
lmbda=0.9,
9054+
value_network=TensorDictModule(nn.Identity(), ["x"], ["y"]),
9055+
)
9056+
env.transform.register_keys(gae)
9057+
assert ("next", eol_key) in gae.in_keys
9058+
9059+
def test_transform_model(self):
9060+
t = EndOfLifeTransform()
9061+
with pytest.raises(RuntimeError, match=FORWARD_NOT_IMPLEMENTED.format(type(t))):
9062+
nn.Sequential(t)(TensorDict({}, []))
9063+
9064+
def test_transform_rb(self):
9065+
pass
9066+
9067+
def test_transform_inverse(self):
9068+
pass
9069+
9070+
89679071
if __name__ == "__main__":
89689072
args, unknown = argparse.ArgumentParser().parse_known_args()
89699073
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
DiscreteActionProjection,
4646
DoubleToFloat,
4747
DTypeCastTransform,
48+
EndOfLifeTransform,
4849
ExcludeTransform,
4950
FiniteTensorDictCheck,
5051
FlattenObservation,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .gym_transforms import EndOfLifeTransform
67
from .r3m import R3MTransform
78
from .rlhf import KLRewardTransform
89
from .transforms import (

0 commit comments

Comments
 (0)