@@ -584,23 +584,24 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                f"got {tensordict.batch_size} and {self.batch_size}"
            )

-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
-        """Performs a random step in the environment given the action_spec attribute.
+    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+        """Performs a random action given the action_spec attribute.

        Args:
-            tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
+            tensordict (TensorDictBase, optional): tensordict where the resulting action should be written.

        Returns:
-            a tensordict object with the new observation after a random step in the environment. The action will
-            be stored with the "action" key.
+            a tensordict object with the "action" entry updated with a random
+            sample from the action-spec.

        """
        shape = torch.Size([])
        if tensordict is None:
            tensordict = TensorDict(
                {}, device=self.device, batch_size=self.batch_size, _run_checks=False
            )
-        elif not self.batch_locked and not self.batch_size:
+
+        if not self.batch_locked and not self.batch_size:
            shape = tensordict.shape
        elif not self.batch_locked and tensordict.shape != self.batch_size:
            raise RuntimeError(
@@ -611,6 +612,20 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa
            )
        action = self.action_spec.rand(shape)
        tensordict.set("action", action)
+        return tensordict
+
+    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+        """Performs a random step in the environment given the action_spec attribute.
+
+        Args:
+            tensordict (TensorDictBase, optional): tensordict where the resulting info should be written.
+
+        Returns:
+            a tensordict object with the new observation after a random step in the environment. The action will
+            be stored with the "action" key.
+
+        """
+        tensordict = self.rand_action(tensordict)
        return self.step(tensordict)

    @property
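A brief usage sketch of the resulting API (not part of the diff): it assumes a Gym-backed TorchRL environment, with the environment name chosen purely for illustration. `rand_action` only writes a sample from `action_spec` under the "action" key, while `rand_step` now composes it with `step`; the default `rollout` policy in the next hunk relies on the same split.

    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    td = env.reset()

    # rand_action only samples from env.action_spec and stores it under "action"
    td = env.rand_action(td)
    td = env.step(td)

    # rand_step is now equivalent to rand_action followed by step
    td2 = env.rand_step(env.reset())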
@@ -680,7 +695,7 @@ def rollout(
        if policy is None:

            def policy(td):
-                self.rand_step(td)
+                self.rand_action(td)
                return td

        tensordicts = []
@@ -796,16 +811,18 @@ def to(self, device: DEVICE_TYPING) -> EnvBase:
    def fake_tensordict(self) -> TensorDictBase:
        """Returns a fake tensordict with key-value pairs that match in shape, device and dtype what can be expected during an environment rollout."""
        input_spec = self.input_spec
-        fake_input = input_spec.zero()
        observation_spec = self.observation_spec
        fake_obs = observation_spec.zero()
+        fake_input = input_spec.zero()
+        # the input and output key may match, but the output prevails
+        # Hence we generate the input, and override using the output
+        fake_in_out = fake_input.clone().update(fake_obs)
        reward_spec = self.reward_spec
        fake_reward = reward_spec.zero()
        fake_td = TensorDict(
            {
-                **fake_obs,
+                **fake_in_out,
                "next": fake_obs.clone(),
-                **fake_input,
                "reward": fake_reward,
                "done": torch.zeros(
                    (*self.batch_size, 1), dtype=torch.bool, device=self.device
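To make the merge order in `fake_tensordict` concrete, here is a small sketch (not part of the diff): the keys and shapes are made up, and it assumes the usual `TensorDict.update` semantics of overwriting existing entries.

    import torch
    from torchrl.data import TensorDict

    # hypothetical specs: "observation" appears in both input and output
    fake_input = TensorDict({"observation": torch.zeros(3), "action": torch.zeros(1)}, [])
    fake_obs = TensorDict({"observation": torch.ones(3)}, [])

    # the input is generated first and then overridden by the output, so the output prevails
    fake_in_out = fake_input.clone().update(fake_obs)
    assert (fake_in_out["observation"] == 1).all()  # output value wins
    assert "action" in fake_in_out.keys()           # input-only keys are kept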