
Commit cb37521

Author: Vincent Moens

[BugFix] adapt log-prob TD batch-size to advantage shape in PPO

ghstack-source-id: 8ccd12f
Pull Request resolved: #2756
1 parent 2f8c118 commit cb37521

File tree: 7 files changed (+272, -20 lines)

docs/source/reference/data.rst

Lines changed: 45 additions & 0 deletions
This format eliminates any ambiguity regarding the matching of an observation with
its action, info, or done state.

A note on singleton dimensions in TED
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _reward_done_singleton:

In TorchRL, the standard practice is that `done` states (including terminated and truncated) and rewards should have a
dimension that can be expanded to match the shape of observations, states, and actions without resorting to anything
other than repetition (i.e., the reward must have as many dimensions as the observation and/or action, or their
embeddings).

Essentially, this format is acceptable (though not strictly enforced):

>>> print(rollout[t])
... TensorDict(
...     fields={
...         action: Tensor(n_action),
...         done: Tensor(1), # The done state has a rightmost singleton dimension
...         next: TensorDict(
...             fields={
...                 done: Tensor(1),
...                 observation: Tensor(n_obs),
...                 reward: Tensor(1), # The reward has a rightmost singleton dimension
...                 terminated: Tensor(1),
...                 truncated: Tensor(1),
...             batch_size=torch.Size([]),
...             device=cpu,
...             is_shared=False),
...         observation: Tensor(n_obs), # the observation at reset
...         terminated: Tensor(1), # the terminated at reset
...         truncated: Tensor(1), # the truncated at reset
...     batch_size=torch.Size([]),
...     device=cpu,
...     is_shared=False)

The rationale behind this is to ensure that the results of operations (such as value estimation) on observations and/or
actions have the same number of dimensions as the reward and `done` state. This consistency allows subsequent operations
to proceed without issues:

>>> state_value = f(observation)
>>> next_state_value = state_value + reward

Without this singleton dimension at the end of the reward, broadcasting rules (which only work when tensors can be
expanded from the left) would try to expand the reward on the left. This could lead to failures (at best) or introduce
bugs (at worst).
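To make the broadcasting hazard concrete, here is a small plain-PyTorch sketch (an editorial illustration with made-up shapes, not part of the documentation text above):

import torch

T, n_obs = 10, 4                                 # arbitrary trajectory length and feature size
observation = torch.randn(T, n_obs)
state_value = observation.sum(-1, keepdim=True)  # value estimate, shape (T, 1)

reward_ok = torch.randn(T, 1)    # trailing singleton dimension, as recommended
reward_flat = torch.randn(T)     # same data without the singleton

assert (state_value + reward_ok).shape == (T, 1)    # broadcasts as intended
assert (state_value + reward_flat).shape == (T, T)  # silently broadcasts to (T, T): a bug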
Flattening TED to reduce memory consumption
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/reference/objectives.rst

Lines changed: 87 additions & 6 deletions
Several section-heading underlines are adjusted to match their title lengths: CrossQ (``----`` to ``------``),
IQL (``----`` to ``---``), CQL (``----`` to ``---``), DT (``----`` to ``--``), TD3 (``----`` to ``---``), and
TD3+BC (``----`` to ``------``).
A new subsection is added to the PPO section (after ``ClipPPOLoss`` and ``KLPENPPOLoss``):
Using PPO with multi-head action policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, we have a single advantage value but more than one action undertaken. Each action has its own
log-probability and its own shape. For instance, the action space may be structured as follows:

>>> action_td = TensorDict(
...     action0=Tensor(batch, n_agents, f0),
...     action1=Tensor(batch, n_agents, f1, f2),
...     batch_size=torch.Size((batch,))
... )

where `f0`, `f1` and `f2` are some arbitrary integers.

Note that, in TorchRL, the tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will
also have the shape of the replay buffer's `batch_size`. The `n_agents` dimension, although common to each action, does
not in general appear in the tensordict's batch-size.

There is a legitimate reason for this: the number of agents may condition some but not all of the specs of the
environment. For example, some environments have a done state shared among all agents. A more complete tensordict
would in this case look like

>>> action_td = TensorDict(
...     action0=Tensor(batch, n_agents, f0),
...     action1=Tensor(batch, n_agents, f1, f2),
...     done=Tensor(batch, 1),
...     observation=Tensor(batch, n_agents, f3),
...     [...] # etc
...     batch_size=torch.Size((batch,))
... )

Notice that `done` states and `reward` usually carry a rightmost singleton dimension. See this :ref:`part of the doc <reward_done_singleton>`
to learn more about this restriction.

The main tools to consider when building multi-head policies are :class:`~tensordict.nn.CompositeDistribution`,
:class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
be aggregated but rather written as leaves in the tensordict.
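For illustration, a two-head probabilistic module can be assembled roughly as follows, mirroring the test added in this commit (the ``dirich``/``categ`` head names and the nested ``params`` layout are just this example's conventions, not a fixed API):

import functools

import torch
from tensordict.nn import (
    CompositeDistribution,
    ProbabilisticTensorDictModule,
    set_composite_lp_aggregate,
)

# write per-head log-probabilities as separate leaves instead of aggregating them
set_composite_lp_aggregate(False).set()

dist_cstr = functools.partial(
    CompositeDistribution,
    distribution_map={
        "dirich": lambda concentration: torch.distributions.Independent(
            torch.distributions.Dirichlet(concentration), 1
        ),
        "categ": torch.distributions.Categorical,
    },
)
# expects a nested "params" entry, e.g.
# TensorDict(dirich=TensorDict(concentration=...), categ=TensorDict(logits=...))
policy_head = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=["dirich", "categ"],
    distribution_class=dist_cstr,
    return_log_prob=True,
)

With aggregation disabled, calling such a module writes per-head log-probabilities (here ``dirich_log_prob`` and ``categ_log_prob``) alongside the sampled actions, as the new test below checks.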
The log-probability of our actions under their respective distributions may look like

>>> action_td = TensorDict(
...     action0_log_prob=Tensor(batch, n_agents),
...     action1_log_prob=Tensor(batch, n_agents, f1),
...     batch_size=torch.Size((batch,))
... )

or

>>> action_td = TensorDict(
...     action0_log_prob=Tensor(batch, n_agents),
...     action1_log_prob=Tensor(batch, n_agents),
...     batch_size=torch.Size((batch,))
... )

i.e., the number of dimensions of a distribution's log-probability generally ranges from the sample's dimensionality
down to anything below it, e.g. when the distribution is multivariate (:class:`~torch.distributions.Dirichlet` for
instance) or an :class:`~torch.distributions.Independent` instance.
The dimension of the tensordict, on the contrary, still matches the env's / replay-buffer's batch-size.
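The shape contrast between the two layouts above can be reproduced with plain ``torch.distributions`` (an illustrative snippet; the sizes are arbitrary):

import torch
from torch.distributions import Dirichlet, Independent

batch, n_agents, f1, f2 = 8, 4, 10, 2
concentration = torch.ones(batch, n_agents, f1, f2)
sample = Dirichlet(concentration).sample()  # shape: (batch, n_agents, f1, f2)

# a plain Dirichlet treats only the last dim as the event dim:
print(Dirichlet(concentration).log_prob(sample).shape)                  # (batch, n_agents, f1)
# wrapping it in Independent folds one more dim into the event:
print(Independent(Dirichlet(concentration), 1).log_prob(sample).shape)  # (batch, n_agents)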
During a call to the PPO loss, the loss module will schematically execute the following set of operations:

>>> def ppo(tensordict):
...     prev_log_prob = tensordict.select(*log_prob_keys)
...     action = tensordict.select(*action_keys)
...     new_log_prob = dist.log_prob(action)
...     log_weight = new_log_prob - prev_log_prob
...     advantage = tensordict.get("advantage")  # computed by GAE earlier
...     # map the log-weight's batch-size onto the advantage's shape (minus its trailing singleton)
...     log_weight.batch_size = advantage.shape[:-1]
...     log_weight = sum(log_weight.sum(dim="feature").values(True, True))  # get a single tensor of log_weights
...     return minimum(log_weight.exp() * advantage, log_weight.exp().clamp(1 - eps, 1 + eps) * advantage)
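In plain tensor terms, the shape-mapping step boils down to summing every dimension beyond the advantage's batch dimensions before forming the clipped objective. A small sketch with made-up shapes (not the loss module's actual code):

import torch

batch, n_agents, f1, eps = 8, 4, 3, 0.2
advantage = torch.randn(batch, n_agents, 1)  # one advantage per agent, with a trailing singleton

lp0 = torch.randn(batch, n_agents)      # log-weight from a per-agent scalar action head
lp1 = torch.randn(batch, n_agents, f1)  # log-weight from a multi-dimensional action head

# sum the feature dims so each head matches the advantage's batch dims, then aggregate the heads
log_weight = (lp0 + lp1.sum(-1)).unsqueeze(-1)  # shape: (batch, n_agents, 1)
ratio = log_weight.exp()
gain = torch.minimum(ratio * advantage, ratio.clamp(1 - eps, 1 + eps) * advantage)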
To appreciate what a PPO pipeline looks like with multi-head policies, an example can be found in the library's
`example directory <https://github.com/pytorch/rl/blob/main/examples/agents/composite_ppo.py>`__.
The new subsection is inserted just before the existing ``A2C`` section. Two further hunks in the same file simply add
a blank line after the ``Multi-agent objectives`` and ``Utils`` headings, before their ``.. currentmodule::``
directives.

examples/agents/composite_ppo.py

Lines changed: 1 addition & 0 deletions
A blank line is added to the module docstring, between the title "Multi-head Agent and PPO Loss" and the sentence
"This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses."

test/test_cost.py

Lines changed: 100 additions & 0 deletions
Two imports are added: ``from typing import Optional`` (next to the other standard-library imports) and
``from torchrl.envs import EnvBase`` (next to the other ``torchrl`` imports).
A minimal multi-agent environment is added after the existing ``get_devices()`` helper:

class MARLEnv(EnvBase):
    def __init__(self):
        batch = self.batch = (3,)
        super().__init__(batch_size=batch)
        self.n_agents = n_agents = (4,)
        self.obs_feat = obs_feat = (5,)

        self.full_observation_spec = Composite(
            observation=Unbounded(batch + n_agents + obs_feat),
            batch_size=batch,
        )
        self.full_done_spec = Composite(
            done=Unbounded(batch + (1,), dtype=torch.bool),
            terminated=Unbounded(batch + (1,), dtype=torch.bool),
            truncated=Unbounded(batch + (1,), dtype=torch.bool),
            batch_size=batch,
        )

        self.act_feat_dirich = act_feat_dirich = (
            10,
            2,
        )
        self.act_feat_categ = act_feat_categ = (7,)
        self.full_action_spec = Composite(
            dirich=Unbounded(batch + n_agents + act_feat_dirich),
            categ=Unbounded(batch + n_agents + act_feat_categ),
            batch_size=batch,
        )

        self.full_reward_spec = Composite(
            reward=Unbounded(batch + n_agents + (1,)), batch_size=batch
        )

    @classmethod
    def make_composite_dist(cls):
        dist_cstr = functools.partial(
            CompositeDistribution,
            distribution_map={
                "dirich": lambda concentration: torch.distributions.Independent(
                    torch.distributions.Dirichlet(concentration), 1
                ),
                "categ": torch.distributions.Categorical,
            },
        )
        return ProbabilisticTensorDictModule(
            in_keys=["params"],
            out_keys=["dirich", "categ"],
            distribution_class=dist_cstr,
            return_log_prob=True,
        )

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        ...

    def _reset(self, tensordic):
        ...

    def _set_seed(self, seed: Optional[int]):
        ...
(The new class sits just above the existing ``LossModuleTestBase`` test class.)
Further down, a new test is added right after an existing test that ends with ``loss = ppo(data)`` and
``loss.sum(reduce=True)``:
    def test_ppo_marl_aggregate(self):
        env = MARLEnv()

        def primer(td):
            params = TensorDict(
                dirich=TensorDict(concentration=env.action_spec["dirich"].one()),
                categ=TensorDict(logits=env.action_spec["categ"].one()),
                batch_size=td.batch_size,
            )
            td.set("params", params)
            return td

        policy = ProbabilisticTensorDictSequential(
            primer,
            env.make_composite_dist(),
            # return_composite=True,
        )
        output = policy(env.fake_tensordict())
        assert output.shape == env.batch_size
        assert output["dirich_log_prob"].shape == env.batch_size + env.n_agents
        assert output["categ_log_prob"].shape == env.batch_size + env.n_agents

        output["advantage"] = output["next", "reward"].clone()
        output["value_target"] = output["next", "reward"].clone()
        critic = TensorDictModule(
            lambda obs: obs.new_zeros((*obs.shape[:-1], 1)),
            in_keys=list(env.full_observation_spec.keys(True, True)),
            out_keys=["state_value"],
        )
        ppo = ClipPPOLoss(actor_network=policy, critic_network=critic)
        ppo.set_keys(action=list(env.full_action_spec.keys(True, True)))
        assert isinstance(ppo.tensor_keys.action, list)
        ppo(output)
(The new test is followed by the existing ``class TestA2C(LossModuleTestBase):``, which begins with ``seed = 0``.)
