Skip to content

Commit c72583f

Browse files
author
Vincent Moens
committed
[Feature, Test] Adding tests for envs that have no specs
ghstack-source-id: 4c75691 Pull Request resolved: #2621
1 parent 830f2f2 commit c72583f

File tree

4 files changed

+52
-2
lines changed

4 files changed

+52
-2
lines changed

test/mocking_classes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,3 +1996,17 @@ def _step(
19961996

19971997
def _set_seed(self, seed: Optional[int]):
19981998
...
1999+
2000+
2001+
class EnvThatDoesNothing(EnvBase):
2002+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
2003+
return TensorDict(batch_size=self.batch_size, device=self.device)
2004+
2005+
def _step(
2006+
self,
2007+
tensordict: TensorDictBase,
2008+
) -> TensorDictBase:
2009+
return TensorDict(batch_size=self.batch_size, device=self.device)
2010+
2011+
def _set_seed(self, seed):
2012+
...

test/test_env.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DiscreteActionConvMockEnvNumpy,
4545
DiscreteActionVecMockEnv,
4646
DummyModelBasedEnvBase,
47+
EnvThatDoesNothing,
4748
EnvWithDynamicSpec,
4849
EnvWithMetadata,
4950
HeterogeneousCountingEnv,
@@ -81,6 +82,7 @@
8182
DiscreteActionConvMockEnvNumpy,
8283
DiscreteActionVecMockEnv,
8384
DummyModelBasedEnvBase,
85+
EnvThatDoesNothing,
8486
EnvWithDynamicSpec,
8587
EnvWithMetadata,
8688
HeterogeneousCountingEnv,
@@ -3554,6 +3556,34 @@ def test_auto_spec():
35543556
env.check_env_specs(tensordict=td.copy())
35553557

35563558

3559+
def test_env_that_does_nothing():
3560+
env = EnvThatDoesNothing()
3561+
env.check_env_specs()
3562+
r = env.rollout(3)
3563+
r.exclude(
3564+
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
3565+
)
3566+
assert r.is_empty()
3567+
p_env = SerialEnv(2, EnvThatDoesNothing)
3568+
p_env.check_env_specs()
3569+
r = p_env.rollout(3)
3570+
r.exclude(
3571+
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
3572+
)
3573+
assert r.is_empty()
3574+
p_env = ParallelEnv(2, EnvThatDoesNothing)
3575+
try:
3576+
p_env.check_env_specs()
3577+
r = p_env.rollout(3)
3578+
r.exclude(
3579+
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
3580+
)
3581+
assert r.is_empty()
3582+
finally:
3583+
p_env.close()
3584+
del p_env
3585+
3586+
35573587
if __name__ == "__main__":
35583588
args, unknown = argparse.ArgumentParser().parse_known_args()
35593589
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811
24342434
apply_api_compatibility=apply_api_compatibility,
24352435
)
24362436

2437-
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
2438-
raise NotImplementedError("EnvBase.forward is not implemented")
2437+
def forward(self, *args, **kwargs):
2438+
raise NotImplementedError(
2439+
"EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
2440+
"a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. "
2441+
"Batched envs require constructors because environment instances may not always be serializable."
2442+
)
24392443

24402444
@abc.abstractmethod
24412445
def _step(

torchrl/envs/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ def __call__(self, tensordict):
287287
if self.validate(tensordict):
288288
if self.keep_other:
289289
out = self._exclude(self.exclude_from_root, tensordict, out=None)
290+
if out is None:
291+
out = tensordict.empty()
290292
else:
291293
out = next_td.empty()
292294
self._grab_and_place(

0 commit comments

Comments
 (0)