|
65 | 65 | DiscreteActionProjection,
|
66 | 66 | DMControlEnv,
|
67 | 67 | DoubleToFloat,
|
| 68 | + EndOfLifeTransform, |
68 | 69 | EnvBase,
|
69 | 70 | EnvCreator,
|
70 | 71 | ExcludeTransform,
|
|
101 | 102 | VIPTransform,
|
102 | 103 | )
|
103 | 104 | from torchrl.envs.libs.dm_control import _has_dm_control
|
104 |
| -from torchrl.envs.libs.gym import _has_gym, GymEnv |
| 105 | +from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend |
105 | 106 | from torchrl.envs.transforms import VecNorm
|
106 | 107 | from torchrl.envs.transforms.r3m import _R3MNet
|
107 | 108 | from torchrl.envs.transforms.rlhf import KLRewardTransform
|
108 |
| -from torchrl.envs.transforms.transforms import _has_tv |
| 109 | +from torchrl.envs.transforms.transforms import _has_tv, FORWARD_NOT_IMPLEMENTED |
109 | 110 | from torchrl.envs.transforms.vc1 import _has_vc
|
110 | 111 | from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
|
111 | 112 | from torchrl.envs.utils import _replace_last, check_env_specs, step_mdp
|
@@ -8710,19 +8711,15 @@ def test_transform_env(self):
|
8710 | 8711 |
|
def test_transform_model(self):
    """Check that calling ``ActionMask`` directly raises ``RuntimeError``.

    The expected message is ``FORWARD_NOT_IMPLEMENTED`` formatted with the
    transform's type.
    """
    import re  # function-local so this test does not rely on a module import

    t = ActionMask()
    # ``pytest.raises(match=...)`` applies ``re.search``; escape the message so
    # the class repr and punctuation it contains are matched literally.
    expected = re.escape(FORWARD_NOT_IMPLEMENTED.format(type(t)))
    with pytest.raises(RuntimeError, match=expected):
        t(TensorDict({}, []))
|
8717 | 8716 |
|
def test_transform_rb(self):
    """Check that sampling a buffer holding ``ActionMask`` raises ``RuntimeError``.

    The transform is appended to a ``ReplayBuffer``, the buffer is filled, and
    ``rb.sample`` is expected to fail with ``FORWARD_NOT_IMPLEMENTED``
    formatted with the transform's type.
    """
    import re  # function-local so this test does not rely on a module import

    t = ActionMask()
    rb = ReplayBuffer(storage=LazyTensorStorage(100))
    rb.append_transform(t)
    rb.extend(TensorDict({"a": [1]}, [1]).expand(10))
    # ``pytest.raises(match=...)`` applies ``re.search``; escape the message so
    # the class repr and punctuation it contains are matched literally.
    expected = re.escape(FORWARD_NOT_IMPLEMENTED.format(type(t)))
    with pytest.raises(RuntimeError, match=expected):
        rb.sample(3)
|
8727 | 8724 |
|
8728 | 8725 | def test_transform_inverse(self):
|
@@ -8964,6 +8961,113 @@ def test_transform_no_env(self, batch):
|
8964 | 8961 | assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))
|
8965 | 8962 |
|
8966 | 8963 |
|
@pytest.mark.skipif(
    not _has_gym, reason="EndOfLifeTransform can only be tested when Gym is present."
)
class TestEndOfLife(TransformBase):
    """Tests for ``EndOfLifeTransform`` using the ``ALE/Breakout-v5`` env.

    Every environment is created under ``set_gym_backend("gymnasium")``.
    """

    def test_trans_parallel_env_check(self):
        """Transform on top of a ``ParallelEnv``: expects a warning and an error."""

        def make():
            with set_gym_backend("gymnasium"):
                return GymEnv("ALE/Breakout-v5")

        with pytest.warns(UserWarning, match="The base_env is not a gym env"):
            with pytest.raises(AttributeError):
                env = TransformedEnv(
                    ParallelEnv(2, make), transform=EndOfLifeTransform()
                )
                check_env_specs(env)

    def test_trans_serial_env_check(self):
        """Transform on top of a ``SerialEnv``: expects a ``UserWarning``."""

        def make():
            with set_gym_backend("gymnasium"):
                return GymEnv("ALE/Breakout-v5")

        with pytest.warns(UserWarning, match="The base_env is not a gym env"):
            env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform())
            check_env_specs(env)

    @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
    @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
    def test_single_trans_env_check(self, eol_key, lives_key):
        """Spec check of a single transformed env, with flat and nested keys."""
        with set_gym_backend("gymnasium"):
            env = TransformedEnv(
                GymEnv("ALE/Breakout-v5"),
                transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
            )
            # Run the spec check while the gymnasium backend is still selected.
            check_env_specs(env)

    @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
    @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
    def test_serial_trans_env_check(self, eol_key, lives_key):
        """Spec check of a ``SerialEnv`` whose workers are transformed envs."""

        def make():
            with set_gym_backend("gymnasium"):
                return TransformedEnv(
                    GymEnv("ALE/Breakout-v5"),
                    transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
                )

        env = SerialEnv(2, make)
        check_env_specs(env)

    @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
    @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
    def test_parallel_trans_env_check(self, eol_key, lives_key):
        """Spec check of a ``ParallelEnv`` whose workers are transformed envs."""

        def make():
            with set_gym_backend("gymnasium"):
                return TransformedEnv(
                    GymEnv("ALE/Breakout-v5"),
                    transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
                )

        env = ParallelEnv(2, make)
        check_env_specs(env)

    def test_transform_no_env(self):
        """``_step`` on a transform with no env must raise ``NO_PARENT_ERR``."""
        import re  # function-local so this test does not rely on a module import

        t = EndOfLifeTransform()
        # ``match`` is a regex; escape the formatted message to match it literally.
        expected = re.escape(t.NO_PARENT_ERR.format(type(t)))
        with pytest.raises(RuntimeError, match=expected):
            t._step(TensorDict({}, []), TensorDict({}, []))

    def test_transform_compose(self):
        """Same as ``test_transform_no_env`` but with the transform in ``Compose``."""
        import re  # function-local so this test does not rely on a module import

        t = EndOfLifeTransform()
        # ``match`` is a regex; escape the formatted message to match it literally.
        expected = re.escape(t.NO_PARENT_ERR.format(type(t)))
        with pytest.raises(RuntimeError, match=expected):
            Compose(t)._step(TensorDict({}, []), TensorDict({}, []))

    @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")])
    @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")])
    def test_transform_env(self, eol_key, lives_key):
        """Spec check plus ``register_keys`` on a ``DQNLoss`` and a ``GAE`` module.

        After ``register_keys`` is called, ``("next", eol_key)`` must appear in
        the ``in_keys`` of both the loss and the value estimator.
        """
        from tensordict.nn import TensorDictModule
        from torchrl.objectives import DQNLoss
        from torchrl.objectives.value import GAE

        with set_gym_backend("gymnasium"):
            env = TransformedEnv(
                GymEnv("ALE/Breakout-v5"),
                transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key),
            )
            # Run the spec check while the gymnasium backend is still selected.
            check_env_specs(env)
        loss = DQNLoss(nn.Identity(), action_space="categorical")
        env.transform.register_keys(loss)
        assert ("next", eol_key) in loss.in_keys
        gae = GAE(
            gamma=0.9,
            lmbda=0.9,
            value_network=TensorDictModule(nn.Identity(), ["x"], ["y"]),
        )
        env.transform.register_keys(gae)
        assert ("next", eol_key) in gae.in_keys

    def test_transform_model(self):
        """Using the transform inside ``nn.Sequential`` must raise ``RuntimeError``."""
        import re  # function-local so this test does not rely on a module import

        t = EndOfLifeTransform()
        # ``match`` is a regex; escape the formatted message to match it literally.
        expected = re.escape(FORWARD_NOT_IMPLEMENTED.format(type(t)))
        with pytest.raises(RuntimeError, match=expected):
            nn.Sequential(t)(TensorDict({}, []))

    def test_transform_rb(self):
        """Placeholder: nothing is exercised for the replay-buffer case."""

    def test_transform_inverse(self):
        """Placeholder: nothing is exercised for the inverse case."""
| 9070 | + |
if __name__ == "__main__":
    # Only the arguments argparse does not recognise are used; they are
    # forwarded to pytest after the fixed options.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
|
0 commit comments