Commit 5a2d9e2
Author: Vincent Moens

[Doc] Fix broken links and formatting issues in doc

ghstack-source-id: 4e3f84f
Pull Request resolved: #2574
Parent: 83a7a57

49 files changed: +402 / -397 lines. (Large commits have some content hidden by default; only a subset of the changed files appears below.)

docs/source/index.rst

Lines changed: 4 additions & 4 deletions
@@ -22,15 +22,15 @@ TorchRL provides pytorch and python-first, low and high level abstractions for R
 The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.
 
 This repo attempts to align with the existing pytorch ecosystem libraries in that it has a "dataset pillar"
-:doc:`(environments) <reference/envs>`,
-:ref:`transforms <reference/envs:Transforms>`,
-:doc:`models <reference/modules>`,
+:ref:`(environments) <Environment-API>`,
+:ref:`transforms <transforms>`,
+:ref:`models <ref_modules>`,
 data utilities (e.g. collectors and containers), etc.
 TorchRL aims at having as few dependencies as possible (python standard library, numpy and pytorch).
 Common environment libraries (e.g. OpenAI gym) are only optional.
 
 On the low-level end, torchrl comes with a set of highly re-usable functionals
-for :doc:`cost functions <reference/objectives>`, :ref:`returns <reference/objectives:Returns>` and data processing.
+for :ref:`cost functions <ref_objectives>`, :ref:`returns <ref_returns>` and data processing.
 
 TorchRL aims at a high modularity and good runtime performance.
docs/source/reference/data.rst

Lines changed: 2 additions & 2 deletions
@@ -944,7 +944,7 @@ not predictable.
 MultiCategorical
 MultiOneHot
 NonTensor
-OneHotDiscrete
+OneHot
 Stacked
 StackedComposite
 Unbounded
@@ -1050,7 +1050,7 @@ and the tree can be expanded for each of these. The following figure shows how t
 
 BinaryToDecimal
 HashToInt
-MCTSForeset
+MCTSForest
 QueryModule
 RandomProjectionHash
 SipHash

docs/source/reference/objectives.rst

Lines changed: 3 additions & 0 deletions
@@ -274,6 +274,9 @@ QMixer
 
 Returns
 -------
+
+.. _ref_returns:
+
 .. currentmodule:: torchrl.objectives.value
 
 .. autosummary::

sota-implementations/decision_transformer/lamb.py

Lines changed: 4 additions & 4 deletions
@@ -15,15 +15,15 @@ class Lamb(Optimizer):
 LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
 Arguments:
 params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-lr (float, optional): learning rate. (default: 1e-3)
+lr (:obj:`float`, optional): learning rate. (default: 1e-3)
 betas (Tuple[float, float], optional): coefficients used for computing
 running averages of gradient and its norm. (default: (0.9, 0.999))
-eps (float, optional): term added to the denominator to improve
+eps (:obj:`float`, optional): term added to the denominator to improve
 numerical stability. (default: 1e-8)
-weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+weight_decay (:obj:`float`, optional): weight decay (L2 penalty) (default: 0)
 grad_averaging (bool, optional): whether apply (1-beta2) to grad when
 calculating running averages of gradient. (default: True)
-max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
+max_grad_norm (:obj:`float`, optional): value used to clip global grad norm (default: 1.0)
 trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
 always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
 weight decay parameter (default: False)
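
These edits all apply the same pattern: wrapping type names in the Sphinx :obj: role so rendered docs cross-link them. A minimal sketch of the convention, using a hypothetical helper (the function and its parameters are made up for illustration):

def scale_lr(base_lr=1e-3, factor=0.1):
    """Scale a learning rate (hypothetical example of the docstring style).

    Arguments:
        base_lr (:obj:`float`, optional): learning rate to scale. (default: 1e-3)
        factor (:obj:`float`, optional): multiplicative factor. (default: 0.1)
    """
    return base_lr * factor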

torchrl/collectors/collectors.py

Lines changed: 10 additions & 6 deletions
@@ -1384,11 +1384,13 @@ class _MultiDataCollector(DataCollectorBase):
 instances) it will be wrapped in a `nn.Module` first.
 Then, the collector will try to assess if these
 modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
 - If the policy forward signature matches any of ``forward(self, tensordict)``,
 ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
 any typing with a single argument typed as a subclass of ``TensorDictBase``)
 then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
-- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+- In all other cases an attempt to wrap it will be undergone as such:
+``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
 Keyword Args:
 frames_per_batch (int): A keyword-only argument representing the
@@ -1476,7 +1478,7 @@ class _MultiDataCollector(DataCollectorBase):
 update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
 will be called before (sync) or after (async) each data collection.
 Defaults to ``False``.
-preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
 that will be allowed to finished collecting their rollout before the rest are forced to end early.
 num_threads (int, optional): number of threads for this process.
 Defaults to the number of workers.
@@ -2093,11 +2095,13 @@ class MultiSyncDataCollector(_MultiDataCollector):
 trajectory and the start of the next collection.
 This class can be safely used with online RL sota-implementations.
 
-.. note:: Python requires multiprocessed code to be instantiated within a main guard:
+.. note::
+    Python requires multiprocessed code to be instantiated within a main guard:
 
 >>> from torchrl.collectors import MultiSyncDataCollector
 >>> if __name__ == "__main__":
 ...     # Create your collector here
+...     collector = MultiSyncDataCollector(...)
 
 See https://docs.python.org/3/library/multiprocessing.html for more info.
 
@@ -2125,8 +2129,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
 ...     if i == 2:
 ...         print(data)
 ...         break
-...     collector.shutdown()
-...     del collector
+>>> collector.shutdown()
+>>> del collector
 TensorDict(
 fields={
 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -2753,7 +2757,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
 update_at_each_batch (boolm optional): if ``True``, :meth:`~.update_policy_weight_()`
 will be called before (sync) or after (async) each data collection.
 Defaults to ``False``.
-preemptive_threshold (float, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
 that will be allowed to finished collecting their rollout before the rest are forced to end early.
 num_threads (int, optional): number of threads for this process.
 Defaults to the number of workers.
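
The wrapping rule documented above can be sketched in a few lines. This is an illustrative reconstruction, not the collector's actual code path: a module whose forward takes a single TensorDictBase argument is used as-is, while a plain tensor-to-tensor module gets wrapped roughly as ``TensorDictModule(policy, in_keys=..., out_keys=...)`` (here "observation"/"action" stand in for ``env_obs_key``/``env.action_keys``):

import torch
from torch import nn
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule

class TensorDictPolicy(nn.Module):
    # forward takes a single TensorDictBase argument: left unwrapped
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        tensordict["action"] = self.net(tensordict["observation"])
        return tensordict

# a plain tensor-to-tensor policy: the collector would wrap it like this
wrapped = TensorDictModule(nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"])

td = TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3])
assert TensorDictPolicy()(td)["action"].shape == wrapped(td)["action"].shape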

torchrl/data/map/tree.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class Tree(TensorClass["nocast"]):
 If there are multiple actions taken at this node, subtrees are stored in the corresponding
 entry. Rollouts can be reconstructed using the :meth:`~.rollout_from_path` method.
 node (TensorDict): Data defining this node (e.g., observations) before the next branching.
-Entries usually matches the ``in_keys`` in ``MCTSForeset.node_map``.
+Entries usually matches the ``in_keys`` in ``MCTSForest.node_map``.
 subtree (Tree): A stack of subtrees produced when actions are taken.
 num_children (int): The number of child nodes (read-only).
 is_terminal (bool): whether the tree has children nodes (read-only).

torchrl/data/postprocs/postprocs.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ class MultiStep(nn.Module):
 It is an identity transform whenever :attr:`n_steps` is 0.
 
 Args:
-gamma (float): Discount factor for return computation
+gamma (:obj:`float`): Discount factor for return computation
 n_steps (integer): maximum look-ahead steps.
 
 .. note:: This class is meant to be used within a ``DataCollector``.
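
Per the note above, MultiStep is typically plugged into a collector. A hedged sketch of that usage (assuming gym/gymnasium is installed; the environment name and collector settings are arbitrary):

from torchrl.collectors import SyncDataCollector
from torchrl.data.postprocs import MultiStep
from torchrl.envs.libs.gym import GymEnv

collector = SyncDataCollector(
    lambda: GymEnv("CartPole-v1"),
    policy=None,  # with no policy, a random policy is used
    frames_per_batch=64,
    total_frames=128,
    # n-step discounted returns computed on each collected batch
    postproc=MultiStep(gamma=0.99, n_steps=4),
)
for data in collector:
    print(data.batch_size)
collector.shutdown()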

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 8 additions & 10 deletions
@@ -897,16 +897,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 
 All arguments are keyword-only arguments.
 
-Presented in
-"Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
-Prioritized experience replay."
-(https://arxiv.org/abs/1511.05952)
+Presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015.
+Prioritized experience replay." (https://arxiv.org/abs/1511.05952)
 
 Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
 with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`): delta added to the priorities to ensure that the buffer
 does not contain null priorities.
 storage (Storage, optional): the storage to be used. If none is provided
 a default :class:`~torchrl.data.replay_buffers.ListStorage` with
@@ -1366,10 +1364,10 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
 tensordict to be passed to it with its new priority value.
 
 Keyword Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
 with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`): delta added to the priorities to ensure that the buffer
 does not contain null priorities.
 storage (Storage, optional): the storage to be used. If none is provided
 a default :class:`~torchrl.data.replay_buffers.ListStorage` with
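
Since the α/β/ε parameters recur throughout these docstrings, a usage sketch may help (assuming a recent torchrl; sizes and priority values are arbitrary): α skews sampling toward high-priority entries (α = 0 is uniform), β shapes the importance-sampling weights, and ε keeps every priority strictly positive.

import torch
from torchrl.data import PrioritizedReplayBuffer
from torchrl.data.replay_buffers import ListStorage

rb = PrioritizedReplayBuffer(
    alpha=0.7, beta=0.9, eps=1e-8, storage=ListStorage(max_size=100)
)
for i in range(10):
    rb.add(torch.tensor([float(i)]))

# return_info=True also yields the sampled indices and IS weights
sample, info = rb.sample(4, return_info=True)
# after computing TD errors, write them back as the new priorities
rb.update_priority(info["index"], priority=torch.rand(4) + 1e-8)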

torchrl/data/replay_buffers/samplers.py

Lines changed: 6 additions & 6 deletions
@@ -298,10 +298,10 @@ class PrioritizedSampler(Sampler):
 
 Args:
 max_capacity (int): maximum capacity of the buffer.
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
 with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float, optional): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
 does not contain null priorities. Defaults to 1e-8.
 reduction (str, optional): the reduction method for multidimensional
 tensordicts (ie stored trajectory). Can be one of "max", "min",
@@ -1652,10 +1652,10 @@ class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler):
 :meth:`~.update_priority`.
 
 Args:
-alpha (float): exponent α determines how much prioritization is used,
+alpha (:obj:`float`): exponent α determines how much prioritization is used,
 with α = 0 corresponding to the uniform case.
-beta (float): importance sampling negative exponent.
-eps (float, optional): delta added to the priorities to ensure that the buffer
+beta (:obj:`float`): importance sampling negative exponent.
+eps (:obj:`float`, optional): delta added to the priorities to ensure that the buffer
 does not contain null priorities. Defaults to 1e-8.
 reduction (str, optional): the reduction method for multidimensional
 tensordicts (i.e., stored trajectory). Can be one of "max", "min",
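
PrioritizedSampler is the modular counterpart of the buffer above: the same α/β/ε logic composed with a generic ReplayBuffer. A minimal composition sketch (max_capacity should match the storage size; values are arbitrary):

from torchrl.data import ReplayBuffer
from torchrl.data.replay_buffers import ListStorage, PrioritizedSampler

rb = ReplayBuffer(
    storage=ListStorage(100),
    sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9, eps=1e-8),
)

Keeping the sampler separate from the buffer is what lets PrioritizedSliceSampler reuse the same priority machinery for trajectory slices.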

torchrl/data/rlhf/utils.py

Lines changed: 5 additions & 5 deletions
@@ -41,7 +41,7 @@ class ConstantKLController(KLControllerBase):
 with.
 
 Keyword Arguments:
-kl_coef (float): The coefficient to multiply KL with when calculating the
+kl_coef (:obj:`float`): The coefficient to multiply KL with when calculating the
 reward.
 model (nn.Module, optional): wrapped model that needs to be controlled.
 Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
@@ -73,8 +73,8 @@ class AdaptiveKLController(KLControllerBase):
 """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".
 
 Keyword Arguments:
-init_kl_coef (float): The starting value of the coefficient.
-target (float): The target KL value. When the observed KL is smaller, the
+init_kl_coef (:obj:`float`): The starting value of the coefficient.
+target (:obj:`float`): The target KL value. When the observed KL is smaller, the
 coefficient is decreased, thereby relaxing the KL penalty in the training
 objective and allowing the model to stray further from the reference model.
 When the observed KL is greater than the target, the KL coefficient is
@@ -146,10 +146,10 @@ class RolloutFromModel:
 reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
 ``input_ids`` and ``attention_mask``, calculates rewards for each token and
 end_scores (the reward for the final token in each sequence).
-kl_coef: (float, optional): initial kl coefficient.
+kl_coef: (:obj:`float`, optional): initial kl coefficient.
 max_new_tokens (int, optional): the maximum length of the sequence.
 Defaults to 50.
-score_clip (float, optional): Scores from the reward model are clipped to the
+score_clip (:obj:`float`, optional): Scores from the reward model are clipped to the
 range ``(-score_clip, score_clip)``. Defaults to 10.
 kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler.
 num_steps (int, optional): number of steps between two optimization.
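
The adaptive behavior described in the AdaptiveKLController docstring (the coefficient shrinks when observed KL is below target, grows when above) follows the proportional update of Ziegler et al. A standalone sketch of that rule, not torchrl's implementation (the class and its names are illustrative):

class AdaptiveKLCoef:
    """Illustrative proportional KL controller in the style of Ziegler et al."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.coef = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, observed_kl: float, n_steps: int) -> float:
        # clipped proportional error: negative when KL undershoots the target
        error = min(max(observed_kl / self.target - 1.0, -0.2), 0.2)
        # shrink the coefficient below target, grow it above
        self.coef *= 1.0 + error * n_steps / self.horizon
        return self.coef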
