From 75de0c24efa606bccb30596a119178121efb0df0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 10 Mar 2025 13:15:09 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- test/_utils_internal.py | 20 ++++++++++++++------ test/test_env.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 05fdada16d2..eb99f8d5f0a 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -267,7 +267,9 @@ def _make_envs( transformed_in, transformed_out, N, - device="cpu", + p_env_device=None, + env_device=None, + # device="cpu", kwargs=None, local_mp_ctx=mp_ctx, ): @@ -275,13 +277,13 @@ def _make_envs( if not transformed_in: def create_env_fn(): - return GymEnv(env_name, frame_skip=frame_skip, device=device) + return GymEnv(env_name, frame_skip=frame_skip, device=env_device) else: if env_name == PONG_VERSIONED(): def create_env_fn(): - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device) in_keys = list(base_env.observation_spec.keys(True, True))[:1] return TransformedEnv( base_env, @@ -292,7 +294,7 @@ def create_env_fn(): def create_env_fn(): - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device) in_keys = list(base_env.observation_spec.keys(True, True))[:1] return TransformedEnv( @@ -305,9 +307,15 @@ def create_env_fn(): env0 = create_env_fn() env_parallel = ParallelEnv( - N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx + N, + create_env_fn, + create_env_kwargs=kwargs, + mp_start_method=local_mp_ctx, + device=p_env_device, + ) + env_serial = SerialEnv( + N, create_env_fn, create_env_kwargs=kwargs, device=p_env_device ) - env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs) for key in env0.observation_spec.keys(True, True): obs_key = key diff --git a/test/test_env.py b/test/test_env.py index 79f11abc4eb..df258e6eb69 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1464,12 +1464,29 @@ def make_env(): "transformed_in,transformed_out", [[True, True], [False, False]] ) # 1226: effociency @pytest.mark.parametrize("static_seed", [False, True]) + @pytest.mark.parametrize("penv_device", ["cpu", None]) + @pytest.mark.parametrize("env_device", ["cpu", None]) + @pytest.mark.parametrize("bwad", [True, False]) def test_parallel_env_seed( - self, env_name, frame_skip, transformed_in, transformed_out, static_seed + self, + env_name, + frame_skip, + transformed_in, + transformed_out, + static_seed, + penv_device, + env_device, + bwad, ): env_name = env_name() env_parallel, env_serial, _, _ = _make_envs( - env_name, frame_skip, transformed_in, transformed_out, 5 + env_name, + frame_skip, + transformed_in, + transformed_out, + 5, + p_env_device=penv_device, + env_device=env_device, ) try: out_seed_serial = env_serial.set_seed(0, static_seed=static_seed) @@ -1479,7 +1496,10 @@ def test_parallel_env_seed( torch.manual_seed(0) td_serial = env_serial.rollout( - max_steps=10, auto_reset=False, tensordict=td0_serial + max_steps=10, + auto_reset=False, + tensordict=td0_serial, + break_when_any_done=bwad, ).contiguous() key = "pixels" if "pixels" in td_serial.keys() else "observation" torch.testing.assert_close( @@ -1494,7 +1514,10 @@ def test_parallel_env_seed( torch.manual_seed(0) assert out_seed_parallel == out_seed_serial td_parallel = env_parallel.rollout( - max_steps=10, auto_reset=False, tensordict=td0_parallel + max_steps=10, + auto_reset=False, + tensordict=td0_parallel, + break_when_any_done=bwad, ).contiguous() torch.testing.assert_close( td_parallel[:, :-1].get(("next", key)), td_parallel[:, 1:].get(key) @@ -1670,7 +1693,7 @@ def test_parallel_env_device( frame_skip, transformed_in=transformed_in, transformed_out=transformed_out, - device=device, + env_device=device, N=N, local_mp_ctx="spawn", )