
Commit 3fd637f

Author: Vincent Moens
[Doc] TED format (#1836)
1 parent e679e71 commit 3fd637f

File tree

3 files changed: +175 -2 lines changed

docs/source/reference/data.rst

Lines changed: 169 additions & 0 deletions
@@ -298,6 +298,175 @@ before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback
of this method is that it will struggle to save big data structures, which is a
common setting when using replay buffers.

TorchRL Episode Data Format (TED)
---------------------------------

In TorchRL, sequential data is consistently presented in a specific format, known
as the TorchRL Episode Data Format (TED). This format is crucial for the seamless
integration and functioning of various components within TorchRL.

Some components, such as replay buffers, are somewhat indifferent to the data
format. However, others, particularly environments, heavily depend on it for
smooth operation.

Therefore, it is essential to understand TED, its purpose, and how to interact
with it. This guide provides a clear explanation of TED, why it is used, and how
to work with it effectively.

The Rationale Behind TED
~~~~~~~~~~~~~~~~~~~~~~~~

Formatting sequential data can be a complex task, especially in the realm of
Reinforcement Learning (RL). As practitioners, we often encounter situations
where data is delivered at reset time (though not always), and sometimes data
is provided or discarded at the final step of the trajectory.

This variability means that a dataset can contain elements of different lengths,
and it is not always immediately clear how to match each time step across the
various elements of this dataset. Consider the following ambiguous dataset
structure:

>>> observation.shape
[200, 3]
>>> action.shape
[199, 4]
>>> info.shape
[200, 3]

At first glance, it seems that the info and observation were delivered
together (one of each at reset plus one of each per call to ``step``), as
suggested by the action having one element fewer. If, however, info had one
element fewer, we would have to assume that it was either omitted at reset time
or not delivered or recorded for the last step of the trajectory. Without proper
documentation of the data structure, it is impossible to determine which info
corresponds to which time step.

Complicating matters further, some datasets provide inconsistent data formats,
where ``observations`` or ``infos`` are missing at the start or end of the
rollout, and this behavior is often not documented.

The primary aim of TED is to eliminate these ambiguities by providing a clear
and consistent data representation.

The structure of TED
~~~~~~~~~~~~~~~~~~~~

TED is built upon the canonical definition of a Markov Decision Process (MDP) in
RL contexts. At each step, an observation conditions an action that results in
(1) a new observation, (2) an indicator of task completion (terminated, truncated,
done), and (3) a reward signal.

Some elements may be missing (for example, the reward is optional in imitation
learning contexts), or additional information may be passed through a state or
info container. In some cases, additional information is required to get the
observation during a call to ``step`` (for instance, in stateless environment
simulators). Furthermore, in certain scenarios, an "action" (or any other data)
cannot be represented as a single tensor and needs to be organized differently.
For example, in Multi-Agent RL settings, actions, observations, rewards, and
completion signals may be composite.
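
As a rough illustration of what such a composite layout can look like, the sketch
below shows a hypothetical multi-agent tensordict. The ``("agents", ...)`` grouping,
the key names, and the shapes are assumptions inspired by TorchRL's MARL
conventions, not a prescription:

>>> print(marl_data)  # hypothetical data for 2 agents
TensorDict(
    fields={
        agents: TensorDict(  # per-agent entries are grouped in a sub-tensordict
            fields={
                action: Tensor(shape=torch.Size([2, 6]), ...),
                observation: Tensor(shape=torch.Size([2, 18]), ...)},
            batch_size=torch.Size([2]),  # leading dimension: the number of agents
            ...),
        done: Tensor(shape=torch.Size([1]), ...),  # completion signal shared by the group
        terminated: Tensor(shape=torch.Size([1]), ...)},
    batch_size=torch.Size([]),
    ...)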

TED accommodates all these scenarios with a single, uniform, and unambiguous
format. We distinguish what happens at time step ``t`` from what happens at
``t+1`` by setting the boundary at the moment the action is executed. In other
words, everything that was present before ``env.step`` was called belongs to
``t``, and everything that comes after belongs to ``t+1``.

The general rule is that everything that belongs to time step ``t`` is stored
at the root of the tensordict, while everything that belongs to ``t+1`` is stored
in the ``"next"`` entry of the tensordict. Here's an example:

>>> data = env.reset()
>>> data = policy(data)
>>> print(env.step(data))
TensorDict(
    fields={
        action: Tensor(...), # the action taken at time t
        done: Tensor(...), # the done state when the action was taken (at reset)
        next: TensorDict( # all of this content comes from the call to `step`
            fields={
                done: Tensor(...), # the done state after the action has been taken
                observation: Tensor(...), # the observation resulting from the action
                reward: Tensor(...), # the reward resulting from the action
                terminated: Tensor(...), # the terminated state after the action has been taken
                truncated: Tensor(...)}, # the truncated state after the action has been taken
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(...), # the observation at reset
        terminated: Tensor(...), # the terminated state at reset
        truncated: Tensor(...)}, # the truncated state at reset
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

During a rollout (either using :class:`~torchrl.envs.EnvBase` or
:class:`~torchrl.collectors.SyncDataCollector`), the content of the ``"next"``
tensordict is brought to the root through the :func:`~torchrl.envs.utils.step_mdp`
function when the agent resets its step count: ``t <- t+1``. You can read more
about the environment API :ref:`here <Environment-API>`.
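
As an illustration, here is a minimal sketch of a hand-written rollout loop relying
on :func:`~torchrl.envs.utils.step_mdp`. Note that, depending on the function's
default arguments, entries such as the reward and the action may not be carried
over to the root:

>>> from torchrl.envs.utils import step_mdp
>>> data = env.reset()
>>> for _ in range(10):
...     data = policy(data)
...     data = env.step(data)
...     # promote the content of "next" to the root: t <- t+1
...     data = step_mdp(data)
>>> assert "next" not in data.keys()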

In most cases, there is no ``True``-valued ``"done"`` state at the root, since any
done state triggers a (partial) reset, which turns the ``"done"`` back to ``False``.
However, this only holds as long as resets are performed automatically. When a
done state does not trigger a reset, we retain these entries; their memory
footprint should, in any case, be considerably lower than that of observations.
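
For instance, with automatic resets, a sanity check along these lines can be
expected to hold (a sketch, assuming the default rollout behavior):

>>> rollout = env.rollout(100, policy)
>>> # any True done is consumed by the reset, so the root flags stay False ...
>>> assert not rollout["done"].any()
>>> # ... while a termination, if one occurred, is recorded under "next"
>>> print(rollout["next", "done"].any())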

This format eliminates any ambiguity regarding the matching of an observation with
its action, info, or done state.

Dimensionality of the Tensordict
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

During a rollout, all collected tensordicts will be stacked along a new dimension
positioned at the end. Both collectors and environments will label this dimension
with the ``"time"`` name. Here's an example:

>>> rollout = env.rollout(10, policy)
>>> assert rollout.shape[-1] == 10
>>> assert rollout.names[-1] == "time"

This ensures that the time dimension is clearly marked and easily identifiable
in the data structure.
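
Because each step's ``"next"`` content becomes the root content of the following
step, consecutive entries along the ``"time"`` dimension line up with one another.
A sketch of this consistency check, assuming a single uninterrupted trajectory and
no multi-step post-processing:

>>> rollout = env.rollout(10, policy)
>>> assert (rollout[..., 1:]["observation"] == rollout[..., :-1]["next", "observation"]).all()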

Special cases and footnotes
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Multi-Agent data presentation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The multi-agent data formatting documentation can be accessed in the
:ref:`MARL environment API <MARL-environment-API>` section.

Memory-based policies (RNNs and Transformers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the examples provided above, only ``env.step(data)`` generates data that
needs to be read in the next step. In some cases, however, the policy also
outputs information that will be required in the next step. This is typically
the case for RNN-based policies, which output an action as well as a recurrent
state that needs to be used in the next step.

To accommodate this, we recommend that users adjust their RNN policy to write this
data under the ``"next"`` entry of the tensordict. This ensures that this content
will be brought to the root in the next step. More information can be found in
:class:`~torchrl.modules.GRUModule` and :class:`~torchrl.modules.LSTMModule`.
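
As a minimal sketch of this convention (the ``"hidden"`` key and the sizes below
are hypothetical, not the keys actually used by :class:`~torchrl.modules.GRUModule`):

>>> import torch
>>> cell = torch.nn.GRUCell(input_size=3, hidden_size=8)
>>> head = torch.nn.Linear(8, 4)
>>> def rnn_policy(data):
...     # read the recurrent state brought to the root by the previous step,
...     # defaulting to zeros at the beginning of a trajectory
...     hidden = data.get("hidden", torch.zeros(8))
...     hidden = cell(data["observation"].unsqueeze(0), hidden.unsqueeze(0)).squeeze(0)
...     data["action"] = head(hidden)
...     # writing under "next" ensures the state is brought to the root at the next step
...     data["next", "hidden"] = hidden
...     return data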

Multi-step
^^^^^^^^^^

Collectors allow users to skip steps when reading the data, accumulating the
reward over the upcoming n steps. This technique is popular in DQN-like
algorithms such as Rainbow. The :class:`~torchrl.data.postprocs.MultiStep` class
performs this data transformation on batches coming out of collectors. In these
cases, a check like the following will fail, since the next observation is
shifted by n steps:

>>> assert (data[..., 1:]["observation"] == data[..., :-1]["next", "observation"]).all()
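
A rough sketch of how this post-processing could be wired into a collector (the
constructor arguments shown here are indicative; refer to
:class:`~torchrl.data.postprocs.MultiStep` for the exact signature):

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data.postprocs import MultiStep
>>> collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=128,
...     total_frames=1280,
...     postproc=MultiStep(gamma=0.99, n_steps=4),
... )
>>> for batch in collector:
...     # batch["next", "observation"] is now the observation n steps ahead, and
...     # the reward entry accumulates the intermediate (discounted) rewards
...     break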

What about memory requirements?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Implemented naively, this data format consumes approximately twice the memory
that a flat representation would. In some memory-intensive settings
(for example, in the :class:`~torchrl.data.datasets.AtariDQNExperienceReplay` dataset),
we store only the ``T+1`` observation on disk and perform the formatting online at get time.
In other cases, we assume that the 2x memory cost is a small price to pay for a
clearer representation. However, generalizing the lazy representation for offline
datasets would certainly be a beneficial feature to have, and we welcome
contributions in this direction!

Datasets
--------

docs/source/reference/envs.rst

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,8 @@
 torchrl.envs package
 ====================
 
+.. _Environment-API:
+
 TorchRL offers an API to handle environments of different backends, such as gym,
 dm-control, dm-lab, model-based environments as well as custom environments.
 The goal is to be able to swap environments in an experiment with little or no effort,
@@ -333,6 +335,8 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w
 Multi-agent environments
 ------------------------
 
+.. _MARL-environment-API:
+
 .. currentmodule:: torchrl.envs
 
 TorchRL supports multi-agent learning out-of-the-box.

torchrl/modules/tensordict_module/rnn.py

Lines changed: 2 additions & 2 deletions
@@ -828,7 +828,8 @@ class GRU(GRUBase):
     """A PyTorch module for executing multiple steps of a multi-layer GRU. The module behaves exactly like :class:`torch.nn.GRU`, but this implementation is exclusively coded in Python.
 
     .. note::
-        This class is implemented without relying on CuDNN, which makes it compatible with :func:`torch.vmap` and :func:`torch.compile`.
+        This class is implemented without relying on CuDNN, which makes it
+        compatible with :func:`torch.vmap` and :func:`torch.compile`.
 
     Examples:
         >>> import torch
@@ -1031,7 +1032,6 @@ class GRUModule(ModuleBase):
         dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
             GRU layer except the last layer, with dropout probability equal to
             :attr:`dropout`. Default: 0
-        proj_size: If ``> 0``, will use GRU with projections of corresponding size. Default: 0
         python_based: If ``True``, will use a full Python implementation of the GRU cell. Default: ``False``
 
     Keyword Args:

0 commit comments
