Commit 00657f0

Authored by felixsittenauer (Felix Sittenauer) and co-author
[Feature] Adds per-head entropy coefficients to PPOLoss (#2972)
Co-authored-by: Felix Sittenauer <felix.sittenauer@helsing.ai>
1 parent 31bd542 commit 00657f0

2 files changed (+151, −19 lines)

test/test_cost.py

Lines changed: 75 additions & 1 deletion
@@ -112,6 +112,7 @@
 from torchrl.objectives.redq import REDQLoss
 from torchrl.objectives.reinforce import ReinforceLoss
 from torchrl.objectives.utils import (
+    _sum_td_features,
     _vmap_func,
     HardUpdate,
     hold_out_net,
@@ -9734,7 +9735,8 @@ def mixture_constructor(logits, loc, scale):
                 reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)
             ),
         )
-        ppo = cls(policy, value_operator)
+        scalar_entropy = 0.07
+        ppo = cls(policy, value_operator, entropy_coef=scalar_entropy)
         ppo.set_keys(
             action=[
                 ("agent0", "action"),
@@ -9748,8 +9750,50 @@ def mixture_constructor(logits, loc, scale):
             ],
         )
         loss = ppo(data)
+        composite_entropy = loss["composite_entropy"]
+        entropy = _sum_td_features(composite_entropy)
+        expected_loss = -(scalar_entropy * entropy).mean()  # batch mean
+        torch.testing.assert_close(
+            loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
+        )
         loss.sum(reduce=True)

+        # keep per-head entropies instead of the aggregated tensor
+        set_composite_lp_aggregate(False).set()
+        coef_map = {
+            "agent0": 0.10,
+            "agent1": 0.05,
+            "agent2": 0.02,
+        }
+        ppo_weighted = cls(policy, value_operator, entropy_coef=coef_map)
+        ppo_weighted.set_keys(
+            action=[
+                ("agent0", "action"),
+                ("agent1", "action"),
+                ("agent2", "action"),
+            ],
+            sample_log_prob=[
+                ("agent0", "action_log_prob"),
+                ("agent1", "action_log_prob"),
+                ("agent2", "action_log_prob"),
+            ],
+        )
+        loss = ppo_weighted(data)
+        composite_entropy = loss["composite_entropy"]
+
+        # sanity check: loss_entropy is scalar + finite
+        assert loss["loss_entropy"].ndim == 0
+        assert torch.isfinite(loss["loss_entropy"])
+        # Check individual loss is computed with the right weights
+        expected_loss = 0.0
+        for name, head_entropy in composite_entropy.items():
+            expected_loss -= (
+                coef_map[name] * _sum_td_features(head_entropy)
+            ).mean()
+        torch.testing.assert_close(
+            loss["loss_entropy"], expected_loss, rtol=1e-5, atol=1e-7
+        )
+
     def test_ppo_marl_aggregate(self):
         env = MARLEnv()


@@ -9791,6 +9835,36 @@ def primer(td):
         assert isinstance(ppo.tensor_keys.action, list)
         ppo(output)

+    def _make_entropy_loss(self, entropy_coef):
+        actor, critic = self._create_mock_actor_value()
+        return PPOLoss(actor, critic, entropy_coef=entropy_coef)
+
+    def test_weighted_entropy_scalar(self):
+        loss = self._make_entropy_loss(entropy_coef=0.5)
+        entropy = torch.tensor(2.0)
+        out = loss._weighted_loss_entropy(entropy)
+        torch.testing.assert_close(out, torch.tensor(-1.0))
+
+    def test_weighted_entropy_mapping(self):
+        coef = {"head_0": 0.3, "head_1": 0.7}
+        loss = self._make_entropy_loss(entropy_coef=coef)
+        entropy = TensorDict(
+            {
+                "head_0": {"action_log_prob": torch.tensor(1.0)},
+                "head_1": {"action_log_prob": torch.tensor(2.0)},
+            },
+            [],
+        )
+        out = loss._weighted_loss_entropy(entropy)
+        expected = -(coef["head_0"] * 1.0 + coef["head_1"] * 2.0)
+        torch.testing.assert_close(out, torch.tensor(expected))
+
+    def test_weighted_entropy_mapping_missing_key(self):
+        loss = self._make_entropy_loss(entropy_coef={"head_not_present": 0.5})
+        entropy = TensorDict({"head_0": {"action_log_prob": torch.tensor(1.0)}}, [])
+        with pytest.raises(KeyError):
+            loss._weighted_loss_entropy(entropy)
+

 class TestA2C(LossModuleTestBase):
     seed = 0
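Note: to exercise just the new entropy-coefficient tests from a development checkout, a standard pytest keyword filter such as the following should work:

    pytest test/test_cost.py -k "weighted_entropy"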

torchrl/objectives/ppo.py

Lines changed: 76 additions & 18 deletions
@@ -8,6 +8,7 @@
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
+from typing import Mapping

 import torch
 from tensordict import (
@@ -84,7 +85,9 @@ class PPOLoss(LossModule):
            ``samples_mc_entropy`` will control how many
            samples will be used to compute this estimate.
            Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
            Defaults to ``0.01``.
        critic_coef (scalar, optional): critic loss multiplier when computing the total
            loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
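For illustration, the two accepted forms look like this at construction time. This is a minimal sketch: actor and critic stand in for any compatible modules, and "head_0"/"head_1" are hypothetical head names, not names taken from this diff.

    # scalar: one coefficient for the summed entropy of all heads
    loss_module = PPOLoss(actor, critic, entropy_coef=0.01)

    # mapping: one coefficient per action head of a composite distribution
    loss_module = PPOLoss(
        actor, critic, entropy_coef={"head_0": 0.01, "head_1": 0.05}
    )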
@@ -330,7 +333,7 @@ def __init__(
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -408,7 +411,22 @@ def __init__(
             torch, "get_default_device", lambda: torch.device("cpu")
         )()

-        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
+        if isinstance(entropy_coef, Mapping):
+            # Store the mapping for per-head coefficients
+            self._entropy_coef_map = {str(k): float(v) for k, v in entropy_coef.items()}
+            # Register an empty buffer for compatibility
+            self.register_buffer("entropy_coef", torch.tensor(0.0))
+        elif isinstance(entropy_coef, (float, int, torch.Tensor)):
+            # Register the scalar entropy coefficient
+            coef = (
+                float(entropy_coef)
+                if not torch.is_tensor(entropy_coef)
+                else float(entropy_coef.item())
+            )
+            self.register_buffer("entropy_coef", torch.tensor(coef))
+            self._entropy_coef_map = None
+        else:
+            raise TypeError("entropy_coef must be a float or a Mapping[str, float]")
         if critic_coef is not None:
             self.register_buffer(
                 "critic_coef", torch.tensor(critic_coef, device=device)
@@ -540,7 +558,6 @@ def _get_entropy(
             return entropy.unsqueeze(-1)

     def _get_cur_log_prob(self, tensordict):
-
         if isinstance(
             self.actor_network,
             (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
@@ -589,7 +606,6 @@ def _get_cur_log_prob(self, tensordict):
     def _log_weight(
         self, tensordict: TensorDictBase, adv_shape: torch.Size
     ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-
         prev_log_prob = _maybe_get_or_select(
             tensordict,
             self.tensor_keys.sample_log_prob,
@@ -745,9 +761,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -814,6 +833,35 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         }
         self._value_estimator.set_keys(**tensor_keys)

+    def _weighted_loss_entropy(
+        self, entropy: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor:
+        """Compute the weighted entropy loss.
+
+        If `self._entropy_coef_map` is provided, apply per-head entropy coefficients.
+        Otherwise, use the scalar `self.entropy_coef`.
+        """
+        if self._entropy_coef_map is None:
+            if is_tensor_collection(entropy):
+                entropy = _sum_td_features(entropy)
+            return -self.entropy_coef * entropy
+
+        loss_term = None  # running sum over heads
+        for head_name, entropy_head in entropy.items():
+            try:
+                coeff = self._entropy_coef_map[head_name]
+            except KeyError as exc:
+                raise KeyError(f"Missing entropy coef for head '{head_name}'") from exc
+            coeff_t = torch.as_tensor(
+                coeff, dtype=entropy_head.dtype, device=entropy_head.device
+            )
+            head_loss_term = -coeff_t * _sum_td_features(entropy_head)
+            loss_term = (
+                head_loss_term if loss_term is None else loss_term + head_loss_term
+            )  # accumulate
+
+        return loss_term
+

 class ClipPPOLoss(PPOLoss):
     """Clipped PPO loss.
@@ -836,7 +884,9 @@ class ClipPPOLoss(PPOLoss):
            ``samples_mc_entropy`` will control how many
            samples will be used to compute this estimate.
            Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
            Defaults to ``0.01``.
        critic_coef (scalar, optional): critic loss multiplier when computing the total
            loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
@@ -939,7 +989,7 @@
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1064,9 +1114,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -1120,7 +1173,9 @@ class KLPENPPOLoss(PPOLoss):
            ``samples_mc_entropy`` will control how many
            samples will be used to compute this estimate.
            Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
            Defaults to ``0.01``.
        critic_coef (scalar, optional): critic loss multiplier when computing the total
            loss. Defaults to ``1.0``.
@@ -1224,7 +1279,7 @@
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1405,9 +1460,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)
