before calling :meth:`~torchrl.data.ReplayBuffer.load_state_dict`. The drawback
of this method is that it will struggle to save big data structures, which is a
common setting when using replay buffers.

TorchRL Episode Data Format (TED)
---------------------------------

In TorchRL, sequential data is consistently presented in a specific format, known
as the TorchRL Episode Data Format (TED). This format is crucial for the seamless
integration and functioning of various components within TorchRL.

Some components, such as replay buffers, are largely indifferent to the data
format. Others, particularly environments, depend on it heavily to operate smoothly.

It is therefore essential to understand what the TED is, why it is used, and how
to work with it effectively. This guide explains exactly that.

The Rationale Behind TED
~~~~~~~~~~~~~~~~~~~~~~~~

Formatting sequential data can be a complex task, especially in Reinforcement
Learning (RL). As practitioners, we often encounter situations where data is
delivered at reset time (though not always), and sometimes data is provided or
discarded at the final step of the trajectory.

This variability means that we can observe data of different lengths in a dataset,
and it is not always immediately clear how to match each time step across the
various elements of this dataset. Consider the following ambiguous dataset structure:

>>> observation.shape
[200, 3]
>>> action.shape
[199, 4]
>>> info.shape
[200, 3]

At first glance, it seems that the info and observation were delivered
together (one of each at reset + one of each at each step call), as suggested by
the action having one less element. However, if info had one less element, we
would have to assume that it was either omitted at reset time or not delivered or
recorded for the last step of the trajectory. Without proper documentation of the
data structure, it is impossible to determine which info corresponds to which time step.

Complicating matters further, some datasets provide inconsistent data formats,
where ``observations`` or ``infos`` are missing at the start or end of the
rollout, and this behavior is often not documented.
The primary aim of TED is to eliminate these ambiguities by providing a clear
and consistent data representation.

The structure of TED
~~~~~~~~~~~~~~~~~~~~

TED is built upon the canonical definition of a Markov Decision Process (MDP) in RL contexts.
At each step, an observation conditions an action that results in (1) a new
observation, (2) an indicator of task completion (terminated, truncated, done),
and (3) a reward signal.

Some elements may be missing (for example, the reward is optional in imitation
learning contexts), or additional information may be passed through a state or
info container. In some cases, additional information is required to get the
observation during a call to ``step`` (for instance, in stateless environment simulators). Furthermore,
in certain scenarios, an "action" (or any other data) cannot be represented as a
single tensor and needs to be organized differently. For example, in Multi-Agent RL
settings, actions, observations, rewards, and completion signals may be composite.

TED accommodates all these scenarios with a single, uniform, and unambiguous
format. We distinguish what happens at time step ``t`` and ``t+1`` by setting a
limit at the time the action is executed. In other words, everything that was
present before ``env.step`` was called belongs to ``t``, and everything that
comes after belongs to ``t+1``.

The general rule is that everything that belongs to time step ``t`` is stored
at the root of the tensordict, while everything that belongs to ``t+1`` is stored
in the ``"next"`` entry of the tensordict. Here's an example:

>>> data = env.reset()
>>> data = policy(data)
>>> print(env.step(data))
TensorDict(
    fields={
        action: Tensor(...),  # The action taken at time t
        done: Tensor(...),  # The done state when the action was taken (at reset)
        next: TensorDict(  # all of this content comes from the call to `step`
            fields={
                done: Tensor(...),  # The done state after the action has been taken
                observation: Tensor(...),  # The observation resulting from the action
                reward: Tensor(...),  # The reward resulting from the action
                terminated: Tensor(...),  # The terminated state after the action has been taken
                truncated: Tensor(...)},  # The truncated state after the action has been taken
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        observation: Tensor(...),  # the observation at reset
        terminated: Tensor(...),  # the terminated at reset
        truncated: Tensor(...)},  # the truncated at reset
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
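
In practice, this means that the observation that conditioned the action sits at
the root, while the one that resulted from it sits under ``"next"``; both can be
read from the same tensordict. A short usage sketch, assuming a generic ``env``
and ``policy`` as in the examples above:

>>> data = env.step(policy(env.reset()))
>>> obs_t = data["observation"]            # the observation that conditioned the action (time t)
>>> obs_tp1 = data["next", "observation"]  # the observation produced by the step (time t+1)
>>> reward = data["next", "reward"]        # the reward received for the action taken at time t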

During a rollout (either using :class:`~torchrl.envs.EnvBase` or
:class:`~torchrl.collectors.SyncDataCollector`), the content of the ``"next"``
tensordict is brought to the root through the :func:`~torchrl.envs.utils.step_mdp`
function when the agent resets its step count: ``t <- t+1``. You can read more
about the environment API :ref:`here <Environment-API>`.
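
For instance, here is a minimal sketch of that ``t <- t+1`` transition, assuming a
generic ``env`` and ``policy`` and the default arguments of
:func:`~torchrl.envs.utils.step_mdp`:

>>> from torchrl.envs.utils import step_mdp
>>> data = env.reset()
>>> data = policy(data)
>>> data = env.step(data)  # the step results are written under data["next"]
>>> data = step_mdp(data)  # the content of data["next"] becomes the new root: t <- t+1
>>> data = policy(data)    # the policy can now act on the observation at the new root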

In most cases, there is no ``True``-valued ``"done"`` state at the root, since any
done state will trigger a (partial) reset which turns the ``"done"`` back to ``False``.
However, this only holds as long as resets are performed automatically. In the cases
where a done state does not trigger a reset, these entries are retained at the root;
they should have a considerably lower memory footprint than observations, for instance.

This format eliminates any ambiguity regarding the matching of an observation with
its action, info, or done state.

Dimensionality of the Tensordict
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

During a rollout, all collected tensordicts will be stacked along a new dimension
positioned at the end. Both collectors and environments will label this dimension
with the ``"time"`` name. Here's an example:

>>> rollout = env.rollout(10, policy)
>>> assert rollout.shape[-1] == 10
>>> assert rollout.names[-1] == "time"

This ensures that the time dimension is clearly marked and easily identifiable
in the data structure.
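
Because the root entries at step ``t+1`` repeat the ``"next"`` entries at step ``t``,
the TED invariant can be checked directly on a rollout. The check below is only a
sketch and assumes no multi-step post-processing has been applied (see the
Multi-step section further down):

>>> rollout = env.rollout(10, policy)
>>> # the observation at the root of step t+1 matches the "next" observation of step t
>>> assert (rollout[..., 1:]["observation"] == rollout[..., :-1]["next", "observation"]).all()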

Special cases and footnotes
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Multi-Agent data presentation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The multi-agent data formatting documentation can be accessed in the :ref:`MARL environment API <MARL-environment-API>` section.

Memory-based policies (RNNs and Transformers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the examples provided above, only ``env.step(data)`` generates data that
needs to be read in the next step. However, in some cases, the policy also
outputs information that will be required in the next step. This is typically
the case for RNN-based policies, which output an action as well as a recurrent
state that needs to be used in the next step.
To accommodate this, we recommend that users adjust their RNN policy to write this
data under the ``"next"`` entry of the tensordict, as illustrated below. This ensures
that this content will be brought to the root in the next step. More information can
be found in :class:`~torchrl.modules.GRUModule` and :class:`~torchrl.modules.LSTMModule`.
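
As an illustration, here is a minimal sketch of a toy recurrent policy that follows
this convention. The module and the ``"recurrent_state"`` key are hypothetical and
only meant to show where the state is written; :class:`~torchrl.modules.GRUModule`
and :class:`~torchrl.modules.LSTMModule` take care of this bookkeeping for you:

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>>
>>> class ToyRecurrentPolicy(nn.Module):
...     def __init__(self, obs_dim=3, hidden_dim=8, action_dim=4):
...         super().__init__()
...         self.cell = nn.GRUCell(obs_dim, hidden_dim)
...         self.head = nn.Linear(hidden_dim, action_dim)
...     def forward(self, data: TensorDict) -> TensorDict:
...         # read the recurrent state stored at the root at time t (zeros after a reset)
...         hidden = data.get(
...             "recurrent_state",
...             torch.zeros(*data.batch_size, self.cell.hidden_size),
...         )
...         hidden = self.cell(data["observation"], hidden)
...         data["action"] = self.head(hidden)
...         # write the updated state under "next" so that it is brought back
...         # to the root at time t+1, alongside the next observation
...         data["next", "recurrent_state"] = hidden
...         return data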

Multi-step
^^^^^^^^^^

Collectors allow users to skip steps when reading the data, accumulating the reward
for the upcoming n steps. This technique is popular in DQN-like algorithms such as Rainbow.
The :class:`~torchrl.data.postprocs.MultiStep` class performs this data transformation
on batches coming out of collectors. In these cases, a check like the following
will fail, since the next observation is shifted by n steps:

>>> assert (data[..., 1:]["observation"] == data[..., :-1]["next", "observation"]).all()
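
Below is a sketch of how such a transform is typically plugged into a collector.
The arguments shown (a discount factor of 0.99 and an aggregation over 3 steps) are
illustrative; refer to the :class:`~torchrl.data.postprocs.MultiStep` documentation
for the exact signature:

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data.postprocs import MultiStep
>>>
>>> collector = SyncDataCollector(
...     env,
...     policy,
...     frames_per_batch=128,
...     total_frames=1024,
...     # rewrites each batch so that rewards are accumulated over the next 3 steps
...     postproc=MultiStep(gamma=0.99, n_steps=3),
... )
>>> for data in collector:
...     # the "next" observation is now the one observed n steps ahead,
...     # so the contiguity check shown above no longer holds
...     ...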

What about memory requirements?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Implemented naively, this data format consumes approximately twice the memory
of a flat representation. In some memory-intensive settings
(for example, in the :class:`~torchrl.data.datasets.AtariDQNExperienceReplay` dataset),
we store only the ``T+1`` observations on disk and perform the formatting online at get time.
In other cases, we assume that the 2x memory cost is a small price to pay for a
clearer representation. However, generalizing the lazy representation for offline
datasets would certainly be a beneficial feature to have, and we welcome
contributions in this direction!

Datasets
--------