Skip to content

Commit ef5a37d

Browse files
author
Vincent Moens
committed
[Quality,BE] Better doc for step_mdp
ghstack-source-id: 1f5aed6 Pull Request resolved: #2639
1 parent dd26ae7 commit ef5a37d

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

torchrl/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def _can_be_pickled(obj):
829829
def _make_ordinal_device(device: torch.device):
830830
if device is None:
831831
return device
832+
device = torch.device(device)
832833
if device.type == "cuda" and device.index is None:
833834
return torch.device("cuda", index=torch.cuda.current_device())
834835
if device.type == "mps" and device.index is None:

torchrl/envs/utils.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import re
1515
import warnings
1616
from enum import Enum
17-
from typing import Any, Dict, List, Union
17+
from typing import Any, Dict, List
1818

1919
import torch
2020

@@ -339,48 +339,47 @@ def step_mdp(
339339
exclude_reward: bool = True,
340340
exclude_done: bool = False,
341341
exclude_action: bool = True,
342-
reward_keys: Union[NestedKey, List[NestedKey]] = "reward",
343-
done_keys: Union[NestedKey, List[NestedKey]] = "done",
344-
action_keys: Union[NestedKey, List[NestedKey]] = "action",
342+
reward_keys: NestedKey | List[NestedKey] = "reward",
343+
done_keys: NestedKey | List[NestedKey] = "done",
344+
action_keys: NestedKey | List[NestedKey] = "action",
345345
) -> TensorDictBase:
346346
"""Creates a new tensordict that reflects a step in time of the input tensordict.
347347
348348
Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
349-
The arguments allow for a precise control over what should be kept and what
349+
The arguments allow for precise control over what should be kept and what
350350
should be copied from the ``"next"`` entry. The default behavior is:
351-
move the observation entries, reward and done states to the root, exclude
352-
the current action and keep all extra keys (non-action, non-done, non-reward).
351+
move the observation entries, reward, and done states to the root, exclude
352+
the current action, and keep all extra keys (non-action, non-done, non-reward).
353353
354354
Args:
355-
tensordict (TensorDictBase): tensordict with keys to be renamed
356-
next_tensordict (TensorDictBase, optional): destination tensordict
357-
keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept.
355+
tensordict (TensorDictBase): The tensordict with keys to be renamed.
356+
next_tensordict (TensorDictBase, optional): The destination tensordict. If ``None``, a new tensordict is created.
357+
keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept.
358358
Default is ``True``.
359-
exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded
359+
exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded
360360
from the resulting tensordict. If ``False``, it will be copied (and replaced)
361-
from the ``"next"`` entry (if present).
362-
Default is ``True``.
363-
exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded
361+
from the ``"next"`` entry (if present). Default is ``True``.
362+
exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded
364363
from the resulting tensordict. If ``False``, it will be copied (and replaced)
365-
from the ``"next"`` entry (if present).
366-
Default is ``False``.
367-
exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will
364+
from the ``"next"`` entry (if present). Default is ``False``.
365+
exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will
368366
be discarded from the resulting tensordict. If ``False``, it will
369367
be kept in the root tensordict (since it should not be present in
370-
the ``"next"`` entry).
371-
Default is ``True``.
372-
reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults
368+
the ``"next"`` entry). Default is ``True``.
369+
reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults
373370
to "reward".
374-
done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults
371+
done_keys (NestedKey or list of NestedKey, optional): The keys where the done state is written. Defaults
375372
to "done".
376-
action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults
373+
action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults
377374
to "action".
378375
379376
Returns:
380-
A new tensordict (or next_tensordict) containing the tensors of the t+1 step.
377+
TensorDictBase: A new tensordict (or ``next_tensordict`` if provided) containing the tensors of the t+1 step.
378+
379+
.. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the
380+
key values to reduce the overhead of making a step in the MDP.
381381
382382
Examples:
383-
This funtion allows for this kind of loop to be used:
384383
>>> from tensordict import TensorDict
385384
>>> import torch
386385
>>> td = TensorDict({
@@ -784,8 +783,8 @@ def check_env_specs(
784783

785784
if _has_dynamic_specs(env.specs):
786785
for real, fake in zip(
787-
real_tensordict.filter_non_tensor_data().unbind(-1),
788-
fake_tensordict.filter_non_tensor_data().unbind(-1),
786+
real_tensordict_select.filter_non_tensor_data().unbind(-1),
787+
fake_tensordict_select.filter_non_tensor_data().unbind(-1),
789788
):
790789
fake = fake.apply(lambda x, y: x.expand_as(y), real)
791790
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():

0 commit comments

Comments
 (0)