Commit 4bc40a8

Author: Vincent Moens
[Feature] env.step_mdp
ghstack-source-id: 145e37c
Pull Request resolved: #2636
1 parent 30d21e5 commit 4bc40a8

1 file changed: +46 -0 lines changed

torchrl/envs/common.py

Lines changed: 46 additions & 0 deletions
@@ -3015,6 +3015,52 @@ def add_truncated_keys(self) -> EnvBase:
         self.__dict__["_done_keys"] = None
         return self

+    def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+        """Advances the environment state by one step using the provided `next_tensordict`.
+
+        This method updates the environment's state by transitioning from the current
+        state to the next, as defined by the `next_tensordict`. The resulting tensordict
+        includes updated observations and any other relevant state information, with
+        keys managed according to the environment's specifications.
+
+        Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently
+        handle the transition of state, observation, action, reward, and done keys. The
+        :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and
+        exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance
+        is created with `exclude_action=False`, meaning that action keys are retained in
+        the root tensordict.
+
+        Args:
+            next_tensordict (TensorDictBase): A tensordict containing the state of the
+                environment at the next time step. This tensordict should include keys
+                for observations, actions, rewards, and done flags, as defined by the
+                environment's specifications.
+
+        Returns:
+            TensorDictBase: A new tensordict representing the environment state after
+                advancing by one step.
+
+        .. note:: The method ensures that the environment's key specifications are validated
+            against the provided `next_tensordict`, issuing warnings if discrepancies
+            are found.
+
+        .. note:: This method is designed to work efficiently with environments that have
+            consistent key specifications, leveraging the `_StepMDP` class to minimize
+            overhead.
+
+        Example:
+            >>> from torchrl.envs import GymEnv
+            >>> env = GymEnv("Pendulum-v1")
+            >>> data = env.reset()
+            >>> for i in range(10):
+            ...     # Compute action
+            ...     data = env.rand_action(data)
+            ...     # Perform action
+            ...     next_data = env.step(data)
+            ...     data = env.step_mdp(next_data)
+        """
+        return self._step_mdp(next_tensordict)
+
     @property
     def _step_mdp(self):
         step_func = self.__dict__.get("_step_mdp_value")
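
The docstring describes two ideas: entries under "next" are promoted to the root while action keys are retained, and the include/exclude key sets are computed once and reused on every call. The following is a minimal sketch of that behavior, not TorchRL's implementation. Plain dicts stand in for tensordicts, `StepMDPSketch` and its parameters are hypothetical names, and dropping the reward (`exclude_reward=True`) is an assumption about the default rather than something stated in this commit.

```python
# Illustrative sketch only: plain dicts stand in for tensordicts, and
# StepMDPSketch is a hypothetical name, not TorchRL's _StepMDP.


class StepMDPSketch:
    """Precompute which keys to carry over, then reuse them on every call."""

    def __init__(self, action_keys, reward_keys,
                 exclude_action=False, exclude_reward=True):
        # Keys dropped from the "next" entry when it is promoted to the root.
        # Assumption: rewards are excluded by default.
        self.drop_from_next = set(reward_keys) if exclude_reward else set()
        # Root keys carried over unchanged. With exclude_action=False
        # (the setting named in the docstring) the action stays in the root.
        self.keep_from_root = set() if exclude_action else set(action_keys)

    def __call__(self, data):
        nxt = data["next"]
        # Start from the retained root keys (the action).
        out = {k: data[k] for k in self.keep_from_root if k in data}
        # Promote everything under "next" except the excluded keys.
        out.update({k: v for k, v in nxt.items()
                    if k not in self.drop_from_next})
        return out


# The key sets are built once, outside the rollout loop.
step_mdp = StepMDPSketch(action_keys=["action"], reward_keys=["reward"])

# Shape of a step result: root entries for time t, "next" entries for t + 1.
transition = {
    "observation": 0.1,
    "action": 1,
    "done": False,
    "next": {"observation": 0.2, "reward": 1.0, "done": False},
}
new_root = step_mdp(transition)
```

In this sketch `new_root` holds the action from time t together with the observation and done flag from t + 1, which is the shape the next call to `env.step` would consume in the docstring's loop.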
