Skip to content

Commit 6c7d233

Browse files
author
Vincent Moens
committed
[Test] More comprehensive tests for auto_spec
ghstack-source-id: 7535249 Pull Request resolved: #2640
1 parent ef5a37d commit 6c7d233

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

test/mocking_classes.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,14 +1931,18 @@ def __init__(self):
19311931
tensor=Unbounded(3),
19321932
non_tensor=NonTensor(shape=()),
19331933
)
1934+
self._saved_obs_spec = self.observation_spec.clone()
19341935
self.state_spec = Composite(
19351936
non_tensor=NonTensor(shape=()),
19361937
)
1938+
self._saved_state_spec = self.state_spec.clone()
19371939
self.reward_spec = Unbounded(1)
1940+
self._saved_full_reward_spec = self.full_reward_spec.clone()
19381941
self.action_spec = Unbounded(1)
1942+
self._saved_full_action_spec = self.full_action_spec.clone()
19391943

19401944
def _reset(self, tensordict):
1941-
data = self.observation_spec.zero()
1945+
data = self._saved_obs_spec.zero()
19421946
data.set_non_tensor("non_tensor", 0)
19431947
data.update(self.full_done_spec.zero())
19441948
return data
@@ -1947,10 +1951,10 @@ def _step(
19471951
self,
19481952
tensordict: TensorDictBase,
19491953
) -> TensorDictBase:
1950-
data = self.observation_spec.zero()
1954+
data = self._saved_obs_spec.zero()
19511955
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
19521956
data.update(self.full_done_spec.zero())
1953-
data.update(self.full_reward_spec.zero())
1957+
data.update(self._saved_full_reward_spec.zero())
19541958
return data
19551959

19561960
def _set_seed(self, seed: Optional[int]):

test/test_env.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,8 +3553,13 @@ def test_single_env_spec():
35533553
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
35543554

35553555

3556-
def test_auto_spec():
3557-
env = CountingEnv()
3556+
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
3557+
def test_auto_spec(env_type):
3558+
if env_type is EnvWithMetadata:
3559+
obs_vals = ["tensor", "non_tensor"]
3560+
else:
3561+
obs_vals = "observation"
3562+
env = env_type()
35583563
td = env.reset()
35593564

35603565
policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
@@ -3577,7 +3582,7 @@ def test_auto_spec():
35773582
shape=env.full_state_spec.shape, device=env.full_state_spec.device
35783583
)
35793584
env._action_keys = ["action"]
3580-
env.auto_specs_(policy, tensordict=td.copy())
3585+
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
35813586
env.check_env_specs(tensordict=td.copy())
35823587

35833588

0 commit comments

Comments
 (0)