Commit a02679b

[BugFix] Improve collector buffer initialisation when policy spec is unavailable (#1547)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 802f0e4 commit a02679b

2 files changed (+47 −46 lines)

test/test_collector.py

Lines changed: 20 additions & 0 deletions
@@ -1397,6 +1397,26 @@ def test_reset_heterogeneous_envs():
     ).all()


+def test_policy_with_mask():
+    env = CountingBatchedEnv(start_val=torch.tensor(10), max_steps=torch.tensor(1e5))
+
+    def policy(td):
+        obs = td.get("observation")
+        # This policy cannot work with obs all 0s
+        if not obs.any():
+            raise AssertionError
+        action = obs.clone()
+        td.set("action", action)
+        return td
+
+    collector = SyncDataCollector(
+        env, policy=policy, frames_per_batch=10, total_frames=20
+    )
+    for _ in collector:
+        break
+    collector.shutdown()
+
+
 class TestNestedEnvsCollector:
     def test_multi_collector_nested_env_consistency(self, seed=1):
         env = NestedCountingEnv()
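
For context, a minimal standalone sketch (not part of the commit) of the failure mode this test guards against. The names masking_policy, fake_td and reset_td are hypothetical stand-ins for the collector's policy and for the outputs of env.fake_tensordict() and env.reset(); fake_tensordict() returns zero-filled entries, so a policy that masks on its observation would previously be handed a blank input during buffer initialisation.

import torch
from tensordict import TensorDict

def masking_policy(td):
    # hypothetical policy that, like the one in test_policy_with_mask, rejects all-zero observations
    obs = td.get("observation")
    if not obs.any():
        raise AssertionError("blank (all-zero) observation")
    td.set("action", obs.clone())
    return td

fake_td = TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4])          # stand-in for env.fake_tensordict()
reset_td = TensorDict({"observation": torch.full((4, 3), 10.0)}, batch_size=[4])  # stand-in for env.reset()

# masking_policy(fake_td)                         # pre-fix behaviour: raises on the blank fake data
masking_policy(fake_td.clone().update(reset_td))  # post-fix behaviour: real reset data is merged in first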

torchrl/collectors/collectors.py

Lines changed: 27 additions & 46 deletions
@@ -621,64 +621,45 @@ def __init__(
         )

         with torch.no_grad():
-            self._tensordict_out = env.fake_tensordict()
+            self._tensordict_out = self.env.fake_tensordict()
+        # If the policy has a valid spec, we use it
         if (
-            hasattr(self.policy, "spec")
-            and self.policy.spec is not None
-            and all(
-                v is not None for v in self.policy.spec.values(True, True)
-            )  # if a spec is None, we don't know anything about it
-            # and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys)
-            and any(
-                key not in self._tensordict_out.keys(isinstance(key, tuple))
-                for key in self.policy.spec.keys(True, True)
-            )
-        ):
-            # if policy spec is non-empty, all the values are not None and the keys
-            # match the out_keys we assume the user has given all relevant information
-            # the policy could have more keys than the env:
-            policy_spec = self.policy.spec
-            if policy_spec.ndim < self._tensordict_out.ndim:
-                policy_spec = policy_spec.expand(self._tensordict_out.shape)
-            for key, spec in policy_spec.items(True, True):
-                if key in self._tensordict_out.keys(isinstance(key, tuple)):
-                    continue
-                self._tensordict_out.set(key, spec.zero())
-            self._tensordict_out = (
-                self._tensordict_out.unsqueeze(-1)
-                .expand(*env.batch_size, self.frames_per_batch)
-                .clone()
-            )
-        elif (
             hasattr(self.policy, "spec")
             and self.policy.spec is not None
             and all(v is not None for v in self.policy.spec.values(True, True))
-            and all(
-                key in self._tensordict_out.keys(isinstance(key, tuple))
-                for key in self.policy.spec.keys(True, True)
-            )
         ):
-            # reach this if the policy has specs and they match with the fake tensordict
-            self._tensordict_out = (
-                self._tensordict_out.unsqueeze(-1)
-                .expand(*env.batch_size, self.frames_per_batch)
-                .clone()
-            )
+            if any(
+                key not in self._tensordict_out.keys(isinstance(key, tuple))
+                for key in self.policy.spec.keys(True, True)
+            ):
+                # if policy spec is non-empty, all the values are not None and the keys
+                # match the out_keys we assume the user has given all relevant information
+                # the policy could have more keys than the env:
+                policy_spec = self.policy.spec
+                if policy_spec.ndim < self._tensordict_out.ndim:
+                    policy_spec = policy_spec.expand(self._tensordict_out.shape)
+                for key, spec in policy_spec.items(True, True):
+                    if key in self._tensordict_out.keys(isinstance(key, tuple)):
+                        continue
+                    self._tensordict_out.set(key, spec.zero())
+
         else:
             # otherwise, we perform a small number of steps with the policy to
             # determine the relevant keys with which to pre-populate _tensordict_out.
             # This is the safest thing to do if the spec has None fields or if there is
             # no spec at all.
             # See #505 for additional context.
+            self._tensordict_out.update(self._tensordict)
             with torch.no_grad():
-                self._tensordict_out = self._tensordict_out.to(self.device)
-                self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1)
-                self._tensordict_out = (
-                    self._tensordict_out.expand(*env.batch_size, self.frames_per_batch)
-                    .clone()
-                    .zero_()
-                )
-        # in addition to outputs of the policy, we add traj_ids and step_count to
+                self._tensordict_out = self.policy(self._tensordict_out.to(self.device))
+
+        self._tensordict_out = (
+            self._tensordict_out.unsqueeze(-1)
+            .expand(*env.batch_size, self.frames_per_batch)
+            .clone()
+            .zero_()
+        )
+        # in addition to outputs of the policy, we add traj_ids to
         # _tensordict_out which will be collected during rollout
         self._tensordict_out = self._tensordict_out.to(self.storing_device)
         self._tensordict_out.set(
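
Assuming the rest of SyncDataCollector.__init__ behaves as shown in the hunk above, the no-spec fallback can be summarised with this hedged, standalone sketch; fake_td, reset_td, policy, batch_size and frames_per_batch are hypothetical stand-ins for the collector's internals, not TorchRL API.

import torch
from tensordict import TensorDict

batch_size = (4,)
frames_per_batch = 10

fake_td = TensorDict({"observation": torch.zeros(*batch_size, 3)}, batch_size=batch_size)          # env.fake_tensordict() stand-in
reset_td = TensorDict({"observation": torch.full((*batch_size, 3), 10.0)}, batch_size=batch_size)  # env.reset() stand-in

def policy(td):
    td.set("action", td.get("observation").clone())
    return td

# New behaviour: seed the fake tensordict with real reset data before calling
# the policy, so key discovery never feeds the policy blank inputs.
out = fake_td.clone()
out.update(reset_td)
with torch.no_grad():
    out = policy(out)

# The single-step output is then expanded into the rollout buffer and zeroed,
# mirroring the unsqueeze/expand/clone/zero_ chain factored out in the diff.
buffer = out.unsqueeze(-1).expand(*batch_size, frames_per_batch).clone().zero_()
print(buffer.batch_size)  # torch.Size([4, 10])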
