@@ -621,64 +621,45 @@ def __init__(
         )

         with torch.no_grad():
-            self._tensordict_out = env.fake_tensordict()
+            self._tensordict_out = self.env.fake_tensordict()
+        # If the policy has a valid spec, we use it
         if (
-            hasattr(self.policy, "spec")
-            and self.policy.spec is not None
-            and all(
-                v is not None for v in self.policy.spec.values(True, True)
-            )  # if a spec is None, we don't know anything about it
-            # and set(self.policy.spec.keys(True, True)) == set(self.policy.out_keys)
-            and any(
-                key not in self._tensordict_out.keys(isinstance(key, tuple))
-                for key in self.policy.spec.keys(True, True)
-            )
-        ):
-            # if policy spec is non-empty, all the values are not None and the keys
-            # match the out_keys we assume the user has given all relevant information
-            # the policy could have more keys than the env:
-            policy_spec = self.policy.spec
-            if policy_spec.ndim < self._tensordict_out.ndim:
-                policy_spec = policy_spec.expand(self._tensordict_out.shape)
-            for key, spec in policy_spec.items(True, True):
-                if key in self._tensordict_out.keys(isinstance(key, tuple)):
-                    continue
-                self._tensordict_out.set(key, spec.zero())
-            self._tensordict_out = (
-                self._tensordict_out.unsqueeze(-1)
-                .expand(*env.batch_size, self.frames_per_batch)
-                .clone()
-            )
-        elif (
             hasattr(self.policy, "spec")
             and self.policy.spec is not None
             and all(v is not None for v in self.policy.spec.values(True, True))
-            and all(
-                key in self._tensordict_out.keys(isinstance(key, tuple))
-                for key in self.policy.spec.keys(True, True)
-            )
         ):
-            # reach this if the policy has specs and they match with the fake tensordict
-            self._tensordict_out = (
-                self._tensordict_out.unsqueeze(-1)
-                .expand(*env.batch_size, self.frames_per_batch)
-                .clone()
-            )
+            if any(
+                key not in self._tensordict_out.keys(isinstance(key, tuple))
+                for key in self.policy.spec.keys(True, True)
+            ):
+                # if policy spec is non-empty, all the values are not None and the keys
+                # match the out_keys we assume the user has given all relevant information
+                # the policy could have more keys than the env:
+                policy_spec = self.policy.spec
+                if policy_spec.ndim < self._tensordict_out.ndim:
+                    policy_spec = policy_spec.expand(self._tensordict_out.shape)
+                for key, spec in policy_spec.items(True, True):
+                    if key in self._tensordict_out.keys(isinstance(key, tuple)):
+                        continue
+                    self._tensordict_out.set(key, spec.zero())
+
         else:
             # otherwise, we perform a small number of steps with the policy to
             # determine the relevant keys with which to pre-populate _tensordict_out.
             # This is the safest thing to do if the spec has None fields or if there is
             # no spec at all.
             # See #505 for additional context.
+            self._tensordict_out.update(self._tensordict)
             with torch.no_grad():
-                self._tensordict_out = self._tensordict_out.to(self.device)
-                self._tensordict_out = self.policy(self._tensordict_out).unsqueeze(-1)
-            self._tensordict_out = (
-                self._tensordict_out.expand(*env.batch_size, self.frames_per_batch)
-                .clone()
-                .zero_()
-            )
-        # in addition to outputs of the policy, we add traj_ids and step_count to
+                self._tensordict_out = self.policy(self._tensordict_out.to(self.device))
+
+        self._tensordict_out = (
+            self._tensordict_out.unsqueeze(-1)
+            .expand(*env.batch_size, self.frames_per_batch)
+            .clone()
+            .zero_()
+        )
+        # in addition to outputs of the policy, we add traj_ids to
         # _tensordict_out which will be collected during rollout
         self._tensordict_out = self._tensordict_out.to(self.storing_device)
         self._tensordict_out.set(
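
For readers following the pre-allocation logic above, here is a minimal standalone sketch (not part of the diff) of the unsqueeze/expand/clone/zero_ pattern, assuming the standalone tensordict package is installed; the key names, shapes, and the batch_size/frames_per_batch values are illustrative and not taken from the collector itself.

import torch
from tensordict import TensorDict

batch_size = (4,)      # illustrative env batch size
frames_per_batch = 8   # illustrative number of frames stored per batch

# stand-in for env.fake_tensordict(): one step's worth of dummy data
fake = TensorDict(
    {
        "observation": torch.randn(*batch_size, 3),
        "action": torch.randn(*batch_size, 2),
    },
    batch_size=batch_size,
)

# a key the env does not know about can be pre-populated with zeros
# before expanding (the collector does this from the policy spec)
fake.set("sample_log_prob", torch.zeros(*batch_size, 1))

# append a time dimension, expand it to frames_per_batch, then materialize
# (.clone()) and zero the buffer that the rollout will later fill in place
out = fake.unsqueeze(-1).expand(*batch_size, frames_per_batch).clone().zero_()
print(out.batch_size)  # torch.Size([4, 8])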