
Commit 3141723

[Doc] More doc on trainers (#663)
1 parent d28a8c3 commit 3141723

File tree

3 files changed: +174 −33 lines changed


README.md

Lines changed: 7 additions & 8 deletions
@@ -11,14 +11,6 @@
 
 # TorchRL
 
-## Disclaimer
-
-This library is not officially released yet and is subject to change.
-
-The features are available before an official release so that users and collaborators can get early access and provide feedback. No guarantee of stability, robustness or backward compatibility is provided.
-
----
-
 **TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
 
 It provides pytorch and **python-first**, low and high level abstractions for RL that are intended to be **efficient**, **modular**, **documented** and properly **tested**.
@@ -536,5 +528,12 @@ In the near future, we plan to:
 We welcome any contribution, should you want to contribute to these new features
 or any other, listed or not, in the issues section of this repository.
 
+
+## Disclaimer
+
+This library is not officially released yet and is subject to change.
+
+The features are available before an official release so that users and collaborators can get early access and provide feedback. No guarantee of stability, robustness or backward compatibility is provided.
+
 # License
 TorchRL is licensed under the MIT License. See [LICENSE](LICENSE) for details.

docs/source/reference/trainers.rst

Lines changed: 125 additions & 0 deletions
@@ -3,6 +3,131 @@
torchrl.trainers package
========================

The trainer package provides utilities to write re-usable training scripts. The core idea is to use a
trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner
loop the optimization steps. We believe this fits multiple RL training schemes, such as
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
More particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.

The :obj:`trainer.train()` method can be sketched as follows:

.. code-block::
    :caption: Trainer loops

    >>> for batch in collector:
    ...     batch = self._process_batch_hook(batch)  # "batch_process"
    ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
    ...     self._pre_optim_hook()  # "pre_optim_steps"
    ...     for j in range(self.optim_steps_per_batch):
    ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
    ...         losses = self.loss_module(sub_batch)
    ...         self._post_loss_hook(sub_batch)  # "post_loss"
    ...         self.optimizer.step()
    ...         self.optimizer.zero_grad()
    ...         self._post_optim_hook()  # "post_optim"
    ...         self._post_optim_log(sub_batch)  # "post_optim_log"
    ...     self._post_steps_hook()  # "post_steps"
    ...     self._post_steps_log_hook(batch)  # "post_steps_log"

There are 9 hooks that can be used in a trainer loop: :obj:`"batch_process"`, :obj:`"pre_optim_steps"`,
:obj:`"process_optim_batch"`, :obj:`"post_loss"`, :obj:`"post_steps"`, :obj:`"post_optim"`, :obj:`"pre_steps_log"`,
:obj:`"post_steps_log"` and :obj:`"post_optim_log"`. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"` and :obj:`"process_optim_batch"`),
**logging** (:obj:`"pre_steps_log"`, :obj:`"post_optim_log"` and :obj:`"post_steps_log"`) and **operation** hooks
(:obj:`"pre_optim_steps"`, :obj:`"post_loss"`, :obj:`"post_optim"` and :obj:`"post_steps"`).

- **Data processing** hooks update a tensordict of data. A hook's :obj:`__call__` method should accept
  a :obj:`TensorDict` object as input and update it given some strategy.
  Examples of such hooks include replay buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
  constants update), data subsampling (:obj:`BatchSubSampler`) and such. A minimal sketch is given after this list.

- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write to the logger
  some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
  logger (:obj:`LogReward`) and such. Hooks should return a dictionary (or a None value) containing the
  data to log. The key :obj:`"log_pbar"` is reserved for boolean values indicating whether the logged value
  should be displayed on the progress bar printed on the training log.

- **Operation** hooks are hooks that execute specific operations over the models, data collectors,
  target network updates and such. For instance, syncing the weights of the collectors using :obj:`UpdateWeights`
  or updating the priority of the replay buffer using :obj:`ReplayBufferTrainer.update_priority` are examples
  of operation hooks. They are data-independent (they do not require a :obj:`TensorDict`
  input); they are simply expected to be executed once at every iteration (or every N iterations).

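For illustration, here is a minimal sketch of a data-processing hook registered directly on the
:obj:`"batch_process"` entry point. The :obj:`"observation"` key, the scaling constant and the
:obj:`trainer` instance are assumptions made for the example, not part of the library:

.. code-block::

    >>> def scale_observation(batch):
    ...     # a data-processing hook: receives the collected TensorDict and
    ...     # returns the updated batch (key and constant are illustrative)
    ...     batch.set("observation", batch.get("observation") / 255.0)
    ...     return batch
    ...
    >>> trainer.register_op("batch_process", scale_observation)
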
The hooks provided by TorchRL usually inherit from a common abstract class :obj:`TrainerHookBase`,
and all implement three base methods: a :obj:`state_dict` and a :obj:`load_state_dict` method for
checkpointing, and a :obj:`register` method that registers the hook under its default name in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
hook is executed every 10 calls to :obj:`"post_optim_log"`:

.. code-block::

    >>> class LoggingHook(TrainerHookBase):
    ...     def __init__(self):
    ...         self.counter = 0
    ...
    ...     def register(self, trainer, name):
    ...         trainer.register_module(name, self)
    ...         trainer.register_op("post_optim_log", self)
    ...
    ...     def state_dict(self):
    ...         return {"counter": self.counter}
    ...
    ...     def load_state_dict(self, state_dict):
    ...         self.counter = state_dict["counter"]
    ...
    ...     def __call__(self, batch):
    ...         if self.counter % 10 == 0:
    ...             self.counter += 1
    ...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
    ...         else:
    ...             out = None
    ...             self.counter += 1
    ...         return out

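Once a trainer has been built (see the checkpointing example below for a full constructor call),
attaching the hook amounts to the following sketch; the module name is arbitrary:

.. code-block::

    >>> logging_hook = LoggingHook()
    >>> logging_hook.register(trainer, name="logging_hook")
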
Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot/>`_ backend or
the regular torch backend. This can be controlled via the global variable :obj:`CKPT_BACKEND`:

.. code-block::

    $ CKPT_BACKEND=torch python script.py

which defaults to :obj:`torchsnapshot`. The advantage of torchsnapshot over pytorch
is that it is a more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk to a tensor with a
physical storage (which pytorch currently does not support). This allows, for instance,
loading tensors to and from a replay buffer that would otherwise not fit in memory.

When building a trainer, one can provide a path where the checkpoints are to
be written. With the :obj:`torchsnapshot` backend, a directory path is expected,
whereas the :obj:`torch` backend expects a file path (typically a :obj:`.pt` file).

.. code-block::

    >>> filepath = "path/to/dir/"
    >>> trainer = Trainer(
    ...     collector=collector,
    ...     total_frames=total_frames,
    ...     frame_skip=frame_skip,
    ...     loss_module=loss_module,
    ...     optimizer=optimizer,
    ...     save_trainer_file=filepath,
    ... )
    >>> select_keys = SelectKeys(["action", "observation"])
    >>> select_keys.register(trainer)
    >>> # to save to a path
    >>> trainer.save_trainer(True)
    >>> # to load from a path
    >>> trainer.load_from_file(filepath)

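With the :obj:`torch` backend the same construction applies, except that :obj:`save_trainer_file`
should point to a single file; a sketch (the filename below is arbitrary):

.. code-block::

    >>> trainer = Trainer(
    ...     collector=collector,
    ...     total_frames=total_frames,
    ...     frame_skip=frame_skip,
    ...     loss_module=loss_module,
    ...     optimizer=optimizer,
    ...     save_trainer_file="path/to/checkpoint.pt",
    ... )
    >>> trainer.save_trainer(True)
    >>> trainer.load_from_file("path/to/checkpoint.pt")
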
The :obj:`Trainer.train()` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class only for its checkpointing capability
is also a perfectly valid use.


Trainer and hooks
-----------------

torchrl/trainers/trainers.py

Lines changed: 42 additions & 25 deletions
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import abc
 import pathlib
 import warnings
 from collections import OrderedDict, defaultdict
@@ -60,6 +61,22 @@
 TYPE_DESCR = {float: "4.4f", int: ""}
 
 
+class TrainerHookBase:
+    """An abstract hooking class for torchrl Trainer class."""
+
+    @abc.abstractmethod
+    def state_dict(self) -> Dict[str, Any]:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def register(self, trainer: Trainer, name: str):
+        raise NotImplementedError
+
+
 class Trainer:
     """A generic Trainer class.
 
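A minimal concrete hook built on this base class might look like the following sketch (names are
hypothetical; the built-in hooks updated in the hunks below follow the same pattern):

class OptimStepCounter(TrainerHookBase):
    """Illustrative hook: counts optimization passes via the "post_optim" hook."""

    def __init__(self):
        self.count = 0

    def __call__(self):
        # "post_optim" hooks are called without arguments, once per optim step
        self.count += 1

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]

    def register(self, trainer, name: str = "optim_step_counter"):
        trainer.register_module(name, self)
        trainer.register_op("post_optim", self)
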
@@ -540,7 +557,7 @@ def _load_list_state_dict(list_state_dict, hook_list):
             hook_list[i] = (item, kwargs)
 
 
-class SelectKeys:
+class SelectKeys(TrainerHookBase):
     """Selects keys in a TensorDict batch.
 
     Args:
@@ -580,12 +597,12 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         pass
 
-    def register(self, trainer) -> None:
+    def register(self, trainer, name="select_keys") -> None:
         trainer.register_op("batch_process", self)
-        trainer.register_module("select_keys", self)
+        trainer.register_module(name, self)
 
 
-class ReplayBufferTrainer:
+class ReplayBufferTrainer(TrainerHookBase):
     """Replay buffer hook provider.
 
     Args:
@@ -673,14 +690,14 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict) -> None:
         self.replay_buffer.load_state_dict(state_dict["replay_buffer"])
 
-    def register(self, trainer: Trainer):
+    def register(self, trainer: Trainer, name: str = "replay_buffer"):
         trainer.register_op("batch_process", self.extend)
         trainer.register_op("process_optim_batch", self.sample)
         trainer.register_op("post_loss", self.update_priority)
-        trainer.register_module("replay_buffer", self)
+        trainer.register_module(name, self)
 
 
-class ClearCudaCache:
+class ClearCudaCache(TrainerHookBase):
     """Clears cuda cache at a given interval.
 
     Examples:
@@ -699,7 +716,7 @@ def __call__(self, *args, **kwargs):
         torch.cuda.empty_cache()
 
 
-class LogReward:
+class LogReward(TrainerHookBase):
     """Reward logger hook.
 
     Args:
@@ -730,12 +747,12 @@ def __call__(self, batch: TensorDictBase) -> Dict:
             "log_pbar": self.log_pbar,
         }
 
-    def register(self, trainer: Trainer):
+    def register(self, trainer: Trainer, name: str = "log_reward"):
         trainer.register_op("pre_steps_log", self)
-        trainer.register_module("log_reward", self)
+        trainer.register_module(name, self)
 
 
-class RewardNormalizer:
+class RewardNormalizer(TrainerHookBase):
     """Reward normalizer hook.
 
     Args:
@@ -822,10 +839,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         for key, value in state_dict.items():
             setattr(self, key, value)
 
-    def register(self, trainer: Trainer):
+    def register(self, trainer: Trainer, name: str = "reward_normalizer"):
         trainer.register_op("batch_process", self.update_reward_stats)
         trainer.register_op("process_optim_batch", self.normalize_reward)
-        trainer.register_module("reward_normalizer", self)
+        trainer.register_module(name, self)
 
 
 def mask_batch(batch: TensorDictBase) -> TensorDictBase:
@@ -849,7 +866,7 @@ def mask_batch(batch: TensorDictBase) -> TensorDictBase:
     return batch
 
 
-class BatchSubSampler:
+class BatchSubSampler(TrainerHookBase):
     """Data subsampler for online RL algorithms.
 
     This class subsamples a part of a whole batch of data just collected from the
@@ -969,15 +986,15 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         pass
 
-    def register(self, trainer):
+    def register(self, trainer: Trainer, name: str = "batch_subsampler"):
         trainer.register_op(
             "process_optim_batch",
             self,
         )
-        trainer.register_module("batch_subsampler", self)
+        trainer.register_module(name, self)
 
 
-class Recorder:
+class Recorder(TrainerHookBase):
     """Recorder hook for Trainer.
 
     Args:
@@ -1092,15 +1109,15 @@ def load_state_dict(self, state_dict: Dict) -> None:
         self._count = state_dict["_count"]
         self.recorder.load_state_dict(state_dict["recorder_state_dict"])
 
-    def register(self, trainer: Trainer):
-        trainer.register_module("recorder", self)
+    def register(self, trainer: Trainer, name: str = "recorder"):
+        trainer.register_module(name, self)
         trainer.register_op(
             "post_steps_log",
             self,
         )
 
 
-class UpdateWeights:
+class UpdateWeights(TrainerHookBase):
     """A collector weights update hook class.
 
     This hook must be used whenever the collector policy weights sit on a
@@ -1130,8 +1147,8 @@ def __call__(self):
         if self.counter % self.update_weights_interval == 0:
             self.collector.update_policy_weights_()
 
-    def register(self, trainer: Trainer):
-        trainer.register_module("update_weights", self)
+    def register(self, trainer: Trainer, name: str = "update_weights"):
+        trainer.register_module(name, self)
         trainer.register_op(
             "post_steps",
             self,
@@ -1144,7 +1161,7 @@ def load_state_dict(self, state_dict) -> None:
         return
 
 
-class CountFramesLog:
+class CountFramesLog(TrainerHookBase):
     """A frame counter hook.
 
     Args:
@@ -1178,8 +1195,8 @@ def __call__(self, batch: TensorDictBase) -> Dict:
         self.frame_count += current_frames
         return {"n_frames": self.frame_count, "log_pbar": self.log_pbar}
 
-    def register(self, trainer: Trainer):
-        trainer.register_module("count_frames_log", self)
+    def register(self, trainer: Trainer, name: str = "count_frames_log"):
+        trainer.register_module(name, self)
         trainer.register_op(
             "pre_steps_log",
             self,