Commit 90c8e40

Author: Vincent Moens
[BugFix] Better account of composite distributions in PPO
ghstack-source-id: 3d86f99
Pull Request resolved: #2622
1 parent: d537dcb

File tree: 3 files changed, +58 -22 lines

  torchrl/objectives/ppo.py
  torchrl/objectives/utils.py
  torchrl/objectives/value/advantages.py

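For background: tensordict's CompositeDistribution wraps several leaf distributions whose parameters and samples live in a TensorDict. With aggregate_probabilities=False, its log_prob returns a TensorDict of per-head log-probabilities instead of a single tensor, which is the shape mismatch this commit accounts for. A minimal sketch under that assumption (the head names and shapes below are illustrative, not taken from the commit):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution
    from torch import distributions as d

    # Two illustrative action heads, parametrized through a TensorDict.
    params = TensorDict(
        {
            "cont": {"loc": torch.zeros(4, 3), "scale": torch.ones(4, 3)},
            "disc": {"logits": torch.randn(4, 5)},
        },
        batch_size=[4],
    )
    dist = CompositeDistribution(
        params, distribution_map={"cont": d.Normal, "disc": d.Categorical}
    )
    sample = dist.sample()  # a TensorDict holding one sample per head
    # Non-aggregated log_prob yields per-head entries (e.g. "cont_log_prob")
    # rather than one tensor -- the case the PPO loss must now reduce itself.
    log_probs = dist.log_prob(
        sample, aggregate_probabilities=False, inplace=False, include_sum=False
    )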

torchrl/objectives/ppo.py: 36 additions & 17 deletions
@@ -18,6 +18,7 @@
     TensorDictParams,
 )
 from tensordict.nn import (
+    CompositeDistribution,
     dispatch,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
@@ -33,6 +34,7 @@
     _clip_value_loss,
     _GAMMA_LMBDA_DEPREC_ERROR,
     _reduce,
+    _sum_td_features,
     default_value_kwargs,
     distance_loss,
     ValueEstimators,
@@ -462,9 +464,13 @@ def reset(self) -> None:
 
     def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
         try:
-            entropy = dist.entropy()
+            if isinstance(dist, CompositeDistribution):
+                kwargs = {"aggregate_probabilities": False, "include_sum": False}
+            else:
+                kwargs = {}
+            entropy = dist.entropy(**kwargs)
             if is_tensor_collection(entropy):
-                entropy = entropy.get(dist.entropy_key)
+                entropy = _sum_td_features(entropy)
         except NotImplementedError:
             x = dist.rsample((self.samples_mc_entropy,))
             log_prob = dist.log_prob(x)
@@ -497,13 +503,20 @@ def _log_weight(
         if isinstance(action, torch.Tensor):
             log_prob = dist.log_prob(action)
         else:
-            maybe_log_prob = dist.log_prob(tensordict)
-            if not isinstance(maybe_log_prob, torch.Tensor):
-                # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
-                # be a tensor
-                log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
+            if isinstance(dist, CompositeDistribution):
+                is_composite = True
+                kwargs = {
+                    "inplace": False,
+                    "aggregate_probabilities": False,
+                    "include_sum": False,
+                }
             else:
-                log_prob = maybe_log_prob
+                is_composite = False
+                kwargs = {}
+            log_prob = dist.log_prob(tensordict, **kwargs)
+            if is_composite and not isinstance(prev_log_prob, TensorDict):
+                log_prob = _sum_td_features(log_prob)
+                log_prob.view_as(prev_log_prob)
 
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             advantage = (advantage - loc) / scale
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
+        if is_tensor_collection(log_weight):
+            log_weight = _sum_td_features(log_weight)
+            log_weight = log_weight.view(advantage.shape)
         neg_loss = log_weight.exp() * advantage
         td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
         if self.entropy_bonus:
@@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
         except NotImplementedError:
             x = previous_dist.sample((self.samples_mc_kl,))
-            previous_log_prob = previous_dist.log_prob(x)
-            current_log_prob = current_dist.log_prob(x)
+            if isinstance(previous_dist, CompositeDistribution):
+                kwargs = {
+                    "aggregate_probabilities": False,
+                    "inplace": False,
+                    "include_sum": False,
+                }
+            else:
+                kwargs = {}
+            previous_log_prob = previous_dist.log_prob(x, **kwargs)
+            current_log_prob = current_dist.log_prob(x, **kwargs)
             if is_tensor_collection(current_log_prob):
-                previous_log_prob = previous_log_prob.get(
-                    self.tensor_keys.sample_log_prob
-                )
-                current_log_prob = current_log_prob.get(
-                    self.tensor_keys.sample_log_prob
-                )
-
+                previous_log_prob = _sum_td_features(previous_log_prob)
+                current_log_prob = _sum_td_features(current_log_prob)
             kl = (previous_log_prob - current_log_prob).mean(0)
             kl = kl.unsqueeze(-1)
             neg_loss = neg_loss - self.beta * kl
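The pattern repeated across these hunks: when the distribution is composite, request the non-aggregated per-head log-probs or entropies, then reduce them with the new _sum_td_features helper. A sketch of the entropy path, assuming `dist` is the CompositeDistribution sketched earlier:

    # `dist` as constructed in the illustrative example above.
    entropy = dist.entropy(aggregate_probabilities=False, include_sum=False)
    # `entropy` is a TensorDict with one entry per head; summing the feature
    # dims of each entry and then the entries themselves yields the single
    # per-sample tensor the PPO entropy bonus expects.
    entropy = entropy.sum(dim="feature", reduce=True)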

torchrl/objectives/utils.py: 5 additions & 0 deletions
@@ -615,3 +615,8 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimizer:
             raise ValueError("Cannot group optimizers of different type.")
         params.extend(optimizer.param_groups)
     return cls(params)
+
+
+def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
+    # Sum all features and return a tensor
+    return data.sum(dim="feature", reduce=True)
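A usage sketch for the new helper (entry names and shapes are illustrative): dim="feature" sums over each leaf's non-batch dims, and reduce=True adds the leaves together, returning a plain tensor with the batch shape.

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"head0_log_prob": torch.randn(4, 3), "head1_log_prob": torch.randn(4)},
        batch_size=[4],
    )
    out = td.sum(dim="feature", reduce=True)  # what _sum_td_features(td) returns
    assert out.shape == torch.Size([4])  # one scalar per batch element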

torchrl/objectives/value/advantages.py: 17 additions & 5 deletions
@@ -15,11 +15,14 @@
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import (
+    CompositeDistribution,
     dispatch,
+    ProbabilisticTensorDictModule,
     set_skip_existing,
     TensorDictModule,
     TensorDictModuleBase,
 )
+from tensordict.nn.probabilistic import interaction_type
 from tensordict.utils import NestedKey
 from torch import Tensor
 
@@ -74,14 +77,22 @@ def new_func(self, *args, **kwargs):
 
 
 def _call_actor_net(
-    actor_net: TensorDictModuleBase,
+    actor_net: ProbabilisticTensorDictModule,
     data: TensorDictBase,
     params: TensorDictBase,
     log_prob_key: NestedKey,
 ):
-    # TODO: extend to handle time dimension (and vmap?)
-    log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key)
-    return log_pi
+    dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
+    if isinstance(dist, CompositeDistribution):
+        kwargs = {
+            "aggregate_probabilities": True,
+            "inplace": False,
+            "include_sum": False,
+        }
+    else:
+        kwargs = {}
+    s = actor_net._dist_sample(dist, interaction_type=interaction_type())
+    return dist.log_prob(s, **kwargs)
 
 
 class ValueEstimatorBase(TensorDictModuleBase):
@@ -1771,7 +1782,8 @@ def forward(
             data=tensordict,
             params=None,
             log_prob_key=self.tensor_keys.sample_log_prob,
-        ).view_as(value)
+        )
+        log_pi = log_pi.view_as(value)
 
         # Compute the V-Trace correction
         done = tensordict.get(("next", self.tensor_keys.done))
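The rewritten _call_actor_net builds the distribution explicitly instead of reading a pre-computed log-prob entry, so a composite actor can aggregate its per-head log-probs into the single tensor V-Trace needs. A sketch of the equivalent call sequence, assuming `actor` is a ProbabilisticTensorDictModule over a composite head and `data` is a TensorDict (note that `_dist_sample` is the private hook the diff relies on):

    from tensordict.nn.probabilistic import interaction_type

    # `actor` and `data` are assumed to exist, as in the diff above.
    dist = actor.get_dist(data.select(*actor.in_keys, strict=False))
    sample = actor._dist_sample(dist, interaction_type=interaction_type())
    # aggregate_probabilities=True makes CompositeDistribution.log_prob return
    # the summed log-prob as one tensor, so the caller can .view_as(value).
    log_pi = dist.log_prob(
        sample, aggregate_probabilities=True, inplace=False, include_sum=False
    )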
