|
14 | 14 | import re
|
15 | 15 | import warnings
|
16 | 16 | from enum import Enum
|
17 |
| -from typing import Any, Dict, List, Union |
| 17 | +from typing import Any, Dict, List |
18 | 18 |
|
19 | 19 | import torch
|
20 | 20 |
|
@@ -339,48 +339,47 @@ def step_mdp(
|
339 | 339 | exclude_reward: bool = True,
|
340 | 340 | exclude_done: bool = False,
|
341 | 341 | exclude_action: bool = True,
|
342 |
| - reward_keys: Union[NestedKey, List[NestedKey]] = "reward", |
343 |
| - done_keys: Union[NestedKey, List[NestedKey]] = "done", |
344 |
| - action_keys: Union[NestedKey, List[NestedKey]] = "action", |
| 342 | + reward_keys: NestedKey | List[NestedKey] = "reward", |
| 343 | + done_keys: NestedKey | List[NestedKey] = "done", |
| 344 | + action_keys: NestedKey | List[NestedKey] = "action", |
345 | 345 | ) -> TensorDictBase:
|
346 | 346 | """Creates a new tensordict that reflects a step in time of the input tensordict.
|
347 | 347 |
|
348 | 348 | Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
|
349 |
| - The arguments allow for a precise control over what should be kept and what |
| 349 | + The arguments allow for precise control over what should be kept and what |
350 | 350 | should be copied from the ``"next"`` entry. The default behavior is:
|
351 |
| - move the observation entries, reward and done states to the root, exclude |
352 |
| - the current action and keep all extra keys (non-action, non-done, non-reward). |
| 351 | + move the observation entries, reward, and done states to the root, exclude |
| 352 | + the current action, and keep all extra keys (non-action, non-done, non-reward). |
353 | 353 |
|
354 | 354 | Args:
|
355 |
| - tensordict (TensorDictBase): tensordict with keys to be renamed |
356 |
| - next_tensordict (TensorDictBase, optional): destination tensordict |
357 |
| - keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept. |
| 355 | + tensordict (TensorDictBase): The tensordict with keys to be renamed. |
| 356 | + next_tensordict (TensorDictBase, optional): The destination tensordict. If ``None``, a new tensordict is created. |
| 357 | + keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. |
358 | 358 | Default is ``True``.
|
359 |
| - exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded |
| 359 | + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded |
360 | 360 | from the resulting tensordict. If ``False``, it will be copied (and replaced)
|
361 |
| - from the ``"next"`` entry (if present). |
362 |
| - Default is ``True``. |
363 |
| - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded |
| 361 | + from the ``"next"`` entry (if present). Default is ``True``. |
| 362 | + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded |
364 | 363 | from the resulting tensordict. If ``False``, it will be copied (and replaced)
|
365 |
| - from the ``"next"`` entry (if present). |
366 |
| - Default is ``False``. |
367 |
| - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will |
| 364 | + from the ``"next"`` entry (if present). Default is ``False``. |
| 365 | + exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will |
368 | 366 | be discarded from the resulting tensordict. If ``False``, it will
|
369 | 367 | be kept in the root tensordict (since it should not be present in
|
370 |
| - the ``"next"`` entry). |
371 |
| - Default is ``True``. |
372 |
| - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults |
| 368 | + the ``"next"`` entry). Default is ``True``. |
| 369 | + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults |
373 | 370 | to "reward".
|
374 |
| - done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults |
| 371 | + done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults |
375 | 372 | to "done".
|
376 |
| - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults |
| 373 | + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults |
377 | 374 | to "action".
|
378 | 375 |
|
379 | 376 | Returns:
|
380 |
| - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. |
| 377 | + TensorDictBase: A new tensordict (or ``next_tensordict`` if provided) containing the tensors of the t+1 step. |
| 378 | +
|
| 379 | + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the |
| 380 | + key values to reduce the overhead of making a step in the MDP. |
381 | 381 |
|
382 | 382 | Examples:
|
383 |
| - This funtion allows for this kind of loop to be used: |
384 | 383 | >>> from tensordict import TensorDict
|
385 | 384 | >>> import torch
|
386 | 385 | >>> td = TensorDict({
|
@@ -784,8 +783,8 @@ def check_env_specs(
|
784 | 783 |
|
785 | 784 | if _has_dynamic_specs(env.specs):
|
786 | 785 | for real, fake in zip(
|
787 |
| - real_tensordict.filter_non_tensor_data().unbind(-1), |
788 |
| - fake_tensordict.filter_non_tensor_data().unbind(-1), |
| 786 | + real_tensordict_select.filter_non_tensor_data().unbind(-1), |
| 787 | + fake_tensordict_select.filter_non_tensor_data().unbind(-1), |
789 | 788 | ):
|
790 | 789 | fake = fake.apply(lambda x, y: x.expand_as(y), real)
|
791 | 790 | if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
|
|
0 commit comments