@@ -3015,6 +3015,52 @@ def add_truncated_keys(self) -> EnvBase:
3015
3015
self .__dict__ ["_done_keys" ] = None
3016
3016
return self
3017
3017
3018
+ def step_mdp (self , next_tensordict : TensorDictBase ) -> TensorDictBase :
3019
+ """Advances the environment state by one step using the provided `next_tensordict`.
3020
+
3021
+ This method updates the environment's state by transitioning from the current
3022
+ state to the next, as defined by the `next_tensordict`. The resulting tensordict
3023
+ includes updated observations and any other relevant state information, with
3024
+ keys managed according to the environment's specifications.
3025
+
3026
+ Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently
3027
+ handle the transition of state, observation, action, reward, and done keys. The
3028
+ :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
3029
+ exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
3030
+ is created with `exclude_action=False`, meaning that action keys are retained in
3031
+ the root tensordict.
3032
+
3033
+ Args:
3034
+ next_tensordict (TensorDictBase): A tensordict containing the state of the
3035
+ environment at the next time step. This tensordict should include keys
3036
+ for observations, actions, rewards, and done flags, as defined by the
3037
+ environment's specifications.
3038
+
3039
+ Returns:
3040
+ TensorDictBase: A new tensordict representing the environment state after
3041
+ advancing by one step.
3042
+
3043
+ .. note:: The method ensures that the environment's key specifications are validated
3044
+ against the provided `next_tensordict`, issuing warnings if discrepancies
3045
+ are found.
3046
+
3047
+ .. note:: This method is designed to work efficiently with environments that have
3048
+ consistent key specifications, leveraging the `_StepMDP` class to minimize
3049
+ overhead.
3050
+
3051
+ Example:
3052
+ >>> from torchrl.envs import GymEnv
3053
+ >>> env = GymEnv("Pendulum-1")
3054
+ >>> data = env.reset()
3055
+ >>> for i in range(10):
3056
+ ... # compute action
3057
+ ... env.rand_action(data)
3058
+ ... # Perform action
3059
+ ... next_data = env.step(reset_data)
3060
+ ... data = env.step_mdp(next_data)
3061
+ """
3062
+ return self ._step_mdp (next_tensordict )
3063
+
3018
3064
@property
3019
3065
def _step_mdp (self ):
3020
3066
step_func = self .__dict__ .get ("_step_mdp_value" )
0 commit comments