|
42 | 42 | CatFrames,
|
43 | 43 | CatTensors,
|
44 | 44 | ChessEnv,
|
| 45 | + ConditionalSkip, |
45 | 46 | DoubleToFloat,
|
46 | 47 | EnvBase,
|
47 | 48 | EnvCreator,
|
|
70 | 71 | check_marl_grouping,
|
71 | 72 | make_composite_from_td,
|
72 | 73 | MarlGroupMapType,
|
| 74 | + RandomPolicy, |
73 | 75 | step_mdp,
|
74 | 76 | )
|
75 | 77 | from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
|
|
131 | 133 | EnvWithMetadata,
|
132 | 134 | HeterogeneousCountingEnv,
|
133 | 135 | HeterogeneousCountingEnvPolicy,
|
| 136 | + HistoryTransform, |
134 | 137 | MockBatchedLockedEnv,
|
135 | 138 | MockBatchedUnLockedEnv,
|
136 | 139 | MockSerialEnv,
|
|
170 | 173 | EnvWithMetadata,
|
171 | 174 | HeterogeneousCountingEnv,
|
172 | 175 | HeterogeneousCountingEnvPolicy,
|
| 176 | + HistoryTransform, |
173 | 177 | MockBatchedLockedEnv,
|
174 | 178 | MockBatchedUnLockedEnv,
|
175 | 179 | MockSerialEnv,
|
@@ -3629,8 +3633,11 @@ def test_serial(self, bwad, use_buffers):
|
def test_parallel(self, bwad, use_buffers):
    """Roll out two ``EnvWithMetadata`` workers in a ``ParallelEnv``.

    Each worker is expected to report ``0 .. N-1`` under the
    ``"non_tensor"`` entry of the rollout. The env is closed in all cases
    so worker processes do not outlive a failing assertion.
    """
    N = 50
    expected = [list(range(N)) for _ in range(2)]
    env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
    try:
        rollout = env.rollout(N, break_when_any_done=bwad)
        non_tensor = rollout.get("non_tensor").tolist()
        assert non_tensor == expected
    finally:
        # Tolerate an env that was already closed along the way.
        env.close(raise_if_closed=False)
3634 | 3641 |
|
3635 | 3642 | class AddString(Transform):
|
3636 | 3643 | def __init__(self):
|
@@ -3662,19 +3669,22 @@ def test_partial_reset(self, batched):
|
3662 | 3669 | env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
|
3663 | 3670 | else:
|
3664 | 3671 | env = SerialEnv(2, [env0, env1])
|
3665 |
| - s = env.reset() |
3666 |
| - i = 0 |
3667 |
| - for i in range(10): # noqa: B007 |
3668 |
| - s, s_ = env.step_and_maybe_reset( |
3669 |
| - s.set("action", torch.ones(2, 1, dtype=torch.int)) |
3670 |
| - ) |
3671 |
| - if s.get(("next", "done")).any(): |
3672 |
| - break |
3673 |
| - s = s_ |
3674 |
| - assert i == 5 |
3675 |
| - assert (s["next", "done"] == torch.tensor([[True], [False]])).all() |
3676 |
| - assert s_["string"] == ["0", "6"] |
3677 |
| - assert s["next", "string"] == ["6", "6"] |
| 3672 | + try: |
| 3673 | + s = env.reset() |
| 3674 | + i = 0 |
| 3675 | + for i in range(10): # noqa: B007 |
| 3676 | + s, s_ = env.step_and_maybe_reset( |
| 3677 | + s.set("action", torch.ones(2, 1, dtype=torch.int)) |
| 3678 | + ) |
| 3679 | + if s.get(("next", "done")).any(): |
| 3680 | + break |
| 3681 | + s = s_ |
| 3682 | + assert i == 5 |
| 3683 | + assert (s["next", "done"] == torch.tensor([[True], [False]])).all() |
| 3684 | + assert s_["string"] == ["0", "6"] |
| 3685 | + assert s["next", "string"] == ["6", "6"] |
| 3686 | + finally: |
| 3687 | + env.close(raise_if_closed=False) |
3678 | 3688 |
|
3679 | 3689 | @pytest.mark.skipif(not _has_transformers, reason="transformers required")
|
3680 | 3690 | def test_str2str_env_tokenizer(self):
|
@@ -4182,6 +4192,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
|
4182 | 4192 | assert (td[3].get("next") != 0).any()
|
4183 | 4193 |
|
4184 | 4194 |
|
class TestEnvWithHistory:
    """Tests for environments that carry a ``HistoryTransform``.

    Covers a plain ``CountingEnv`` + ``HistoryTransform`` stack and a variant
    that additionally skips steps through ``ConditionalSkip``, each exercised
    stand-alone, inside batched envs (serial / parallel) and through a
    ``SyncDataCollector``.
    """

    @pytest.fixture(autouse=True, scope="class")
    def set_capture(self):
        """Disable non-tensor stack capture and transformed-env auto-unwrapping.

        Both context managers stay active for every test of the class and are
        exited when the class-scoped fixture is torn down.
        """
        with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
            False
        ):
            yield

    def _make_env(self, device, max_steps=10):
        """Return a ``CountingEnv`` with a ``HistoryTransform`` appended."""
        return CountingEnv(device=device, max_steps=max_steps).append_transform(
            HistoryTransform()
        )

    def _make_skipping_env(self, device, max_steps=10):
        """Return the history env wrapped with ``ConditionalSkip`` and ``StepCounter``.

        The skip predicate fires whenever ``step_count % 3 == 2``.
        """
        env = self._make_env(device=device, max_steps=max_steps)
        # skip every 3 steps
        env = env.append_transform(
            ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
        )
        # StepCounter is added as an outer TransformedEnv layer rather than
        # appended to the existing transform stack.
        env = TransformedEnv(env, StepCounter())
        return env

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_env_history_base(self, device):
        """The history env passes ``check_env_specs``."""
        env = self._make_env(device)
        env.check_env_specs()

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_skipping_history_env(self, device):
        """The skipping env passes ``check_env_specs`` and can be rolled out."""
        env = self._make_skipping_env(device)
        env.check_env_specs()
        # Smoke test only: the rollout must complete; its content is not inspected.
        env.rollout(100)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_env_history_base_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """The history env passes spec checks inside serial / parallel batched envs."""
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            # The batched env is expected to run without shared buffers here.
            assert not env._use_buffers
            env.check_env_specs(break_when_any_done="both")
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_skipping_history_env_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        """The skipping env passes spec checks inside serial / parallel batched envs."""
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_skipping_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_env_history_base_collector(self, device_env, collector_cls):
        """Consecutive collected frames agree on the first history entry.

        For each pair of adjacent frames in a batch, the first element of the
        root ``"history"`` content of frame ``i + 1`` must equal the first
        element of the ``("next", "history")`` content of frame ``i``.
        """
        env = self._make_env(device_env)
        collector = collector_cls(
            env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
        )
        try:
            for d in collector:
                for i in range(d.shape[0] - 1):
                    assert (
                        d[i + 1]["history"].content[0]
                        == d[i]["next", "history"].content[0]
                    )
        finally:
            # Release the collector (and the env it owns) even when an
            # assertion fails, mirroring the cleanup of the batched tests.
            collector.shutdown()

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_skipping_history_env_collector(self, device_env, collector_cls):
        """History only advances on steps that are not skipped.

        ``count`` tracks the position within the current episode. When
        ``count % 3 == 2`` the ``("next", "history")`` content must be
        unchanged from the previous frame; otherwise its last entry must be
        the previous last entry plus one.
        """
        env = self._make_skipping_env(device_env, max_steps=10)
        frames_per_batch = 5
        collector = collector_cls(
            env,
            lambda td: td.update(env.full_action_spec.one()),
            total_frames=35,
            frames_per_batch=frames_per_batch,
        )
        count = 1
        try:
            for d in collector:
                # Frame 0 has no predecessor within the batch, so start at 1.
                for k in range(1, frames_per_batch):
                    if len(d[k]["history"].content) == 2:
                        # A two-entry history restarts the counter
                        # (presumably the first frame after a reset -- TODO confirm).
                        count = 1
                        continue
                    if count % 3 == 2:
                        assert (
                            d[k]["next", "history"].content
                            == d[k - 1]["next", "history"].content
                        ), (d["next", "history"].content, k, count)
                    else:
                        assert d[k]["next", "history"].content[-1] == str(
                            int(d[k - 1]["next", "history"].content[-1]) + 1
                        ), (d["next", "history"].content, k, count)
                    count += 1
                # Account for frame 0 of the next batch, which the inner loop
                # does not inspect.
                count += 1
        finally:
            collector.shutdown()
| 4311 | + |
| 4312 | + |
if __name__ == "__main__":
    # Arguments argparse does not recognise are forwarded verbatim to pytest.
    cli_parser = argparse.ArgumentParser()
    args, unknown = cli_parser.parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
|
0 commit comments