 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
+from typing import Mapping
 
 import torch
 from tensordict import (
@@ -84,7 +85,9 @@ class PPOLoss(LossModule):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
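
For context, a minimal constructor sketch of the two accepted forms. The toy actor/critic below use standard torchrl modules with illustrative sizes; the head name `"action"` in the mapping is hypothetical and must match a key reported under `"composite_entropy"` (the mapping form is only meaningful for composite, multi-head policies):

```python
import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import PPOLoss

# Toy 4-dim-observation / 2-dim-action actor and critic (illustrative sizes).
actor = ProbabilisticActor(
    module=TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(4, 4), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
critic = ValueOperator(torch.nn.Linear(4, 1), in_keys=["observation"])

# Scalar form: one coefficient applied to the (summed) entropy bonus.
loss_fn = PPOLoss(actor, critic, entropy_coef=0.01)

# Mapping form: one coefficient per action head. The head name "action" is
# illustrative and must match the keys reported under "composite_entropy".
loss_fn_per_head = PPOLoss(actor, critic, entropy_coef={"action": 0.01})
```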
@@ -330,7 +333,7 @@ def __init__(
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -408,7 +411,22 @@ def __init__(
             torch, "get_default_device", lambda: torch.device("cpu")
         )()
 
-        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
+        if isinstance(entropy_coef, Mapping):
+            # Store the mapping for per-head coefficients
+            self._entropy_coef_map = {str(k): float(v) for k, v in entropy_coef.items()}
+            # Register an empty buffer for compatibility
+            self.register_buffer("entropy_coef", torch.tensor(0.0))
+        elif isinstance(entropy_coef, (float, int, torch.Tensor)):
+            # Register the scalar entropy coefficient
+            coef = (
+                float(entropy_coef)
+                if not torch.is_tensor(entropy_coef)
+                else float(entropy_coef.item())
+            )
+            self.register_buffer("entropy_coef", torch.tensor(coef))
+            self._entropy_coef_map = None
+        else:
+            raise TypeError("entropy_coef must be a float or a Mapping[str, float]")
         if critic_coef is not None:
             self.register_buffer(
                 "critic_coef", torch.tensor(critic_coef, device=device)
@@ -540,7 +558,6 @@ def _get_entropy(
         return entropy.unsqueeze(-1)
 
     def _get_cur_log_prob(self, tensordict):
-
         if isinstance(
             self.actor_network,
             (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
@@ -589,7 +606,6 @@ def _get_cur_log_prob(self, tensordict):
     def _log_weight(
         self, tensordict: TensorDictBase, adv_shape: torch.Size
     ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-
         prev_log_prob = _maybe_get_or_select(
             tensordict,
             self.tensor_keys.sample_log_prob,
@@ -745,9 +761,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -814,6 +833,35 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         }
         self._value_estimator.set_keys(**tensor_keys)
 
+    def _weighted_loss_entropy(
+        self, entropy: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor:
+        """Compute the weighted entropy loss.
+
+        If `self._entropy_coef_map` is provided, apply per-head entropy coefficients.
+        Otherwise, use the scalar `self.entropy_coef`.
+        """
+        if self._entropy_coef_map is None:
+            if is_tensor_collection(entropy):
+                entropy = _sum_td_features(entropy)
+            return -self.entropy_coef * entropy
+
+        loss_term = None  # running sum over heads
+        for head_name, entropy_head in entropy.items():
+            try:
+                coeff = self._entropy_coef_map[head_name]
+            except KeyError as exc:
+                raise KeyError(f"Missing entropy coef for head '{head_name}'") from exc
+            coeff_t = torch.as_tensor(
+                coeff, dtype=entropy_head.dtype, device=entropy_head.device
+            )
+            head_loss_term = -coeff_t * _sum_td_features(entropy_head)
+            loss_term = (
+                head_loss_term if loss_term is None else loss_term + head_loss_term
+            )  # accumulate
+
+        return loss_term
+
 
 class ClipPPOLoss(PPOLoss):
     """Clipped PPO loss.
@@ -836,7 +884,9 @@ class ClipPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
@@ -939,7 +989,7 @@ def __init__(
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1064,9 +1114,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
@@ -1120,7 +1173,9 @@ class KLPENPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coef (scalar, optional): entropy multiplier when computing the total loss.
+        entropy_coef (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+            * **Scalar**: one value applied to the summed entropy of every action head.
+            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action head's entropy.
             Defaults to ``0.01``.
         critic_coef (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``.
@@ -1224,7 +1279,7 @@ def __init__(
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coef: float = 0.01,
+        entropy_coef: float | Mapping[str, float] = 0.01,
         critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
@@ -1405,9 +1460,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             if is_tensor_collection(entropy):
                 # Reports the entropy of each action head.
                 td_out.set("composite_entropy", entropy.detach())
-                entropy = _sum_td_features(entropy)
-            td_out.set("entropy", entropy.detach().mean())  # for logging
-            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+                td_out.set(
+                    "entropy", _sum_td_features(entropy).detach().mean()
+                )  # for logging
+            else:
+                td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
             td_out.set("loss_critic", loss_critic)