@@ -27,6 +27,7 @@
 from tensordict.utils import NestedKey
 from torch import distributions as d

+from torchrl._utils import _standardize
 from torchrl.objectives.common import LossModule

 from torchrl.objectives.utils import (
@@ -46,6 +47,7 @@
     TDLambdaEstimator,
     VTrace,
 )
+import warnings


 class PPOLoss(LossModule):
@@ -87,6 +89,9 @@ class PPOLoss(LossModule):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -311,6 +316,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         advantage_key: str = None,
@@ -381,6 +387,8 @@ def __init__(
             self.critic_coef = None
         self.loss_critic_type = loss_critic_type
         self.normalize_advantage = normalize_advantage
+        self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
+
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._set_deprecated_ctor_keys(
@@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

         log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
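The heavy lifting now happens in `torchrl._utils._standardize`, which only appears here as a call site. As a rough sketch of the intended semantics (an assumption based on this diff, not torchrl's actual implementation), excluding a dimension means the mean/std reductions run over every other dimension, so each slice along the excluded dimension keeps its own statistics; `standardize_excluding` below is a made-up illustrative helper:

```python
import torch


def standardize_excluding(x: torch.Tensor, exclude_dims: tuple = ()) -> torch.Tensor:
    """Illustrative stand-in for torchrl._utils._standardize (assumed semantics).

    Mean and std are computed over every dimension *except* those in
    ``exclude_dims``, so each slice along an excluded dimension keeps its own
    normalization statistics.
    """
    excluded = {d % x.ndim for d in exclude_dims}  # resolve negative dims
    reduce_dims = tuple(d for d in range(x.ndim) if d not in excluded)
    loc = x.mean(dim=reduce_dims, keepdim=True)
    scale = x.std(dim=reduce_dims, keepdim=True).clamp_min(1e-6)
    return (x - loc) / scale


# Advantage shaped [batch, time, n_agents, 1] with very different per-agent scales.
adv = torch.randn(8, 32, 3, 1) * torch.tensor([1.0, 10.0, 100.0]).view(1, 1, 3, 1)
global_norm = standardize_excluding(adv)                    # one mean/std overall
per_agent = standardize_excluding(adv, exclude_dims=(-2,))  # one mean/std per agent
print(per_agent.select(-2, 2).std())  # ~1.0: the high-variance agent is rescaled too
```

Compared to the removed global `mean()`/`std()` normalization, this keeps, for example, a per-agent scale when agents collect rewards of very different magnitudes.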
@@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -802,6 +820,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -821,6 +840,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
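For orientation, a minimal construction sketch showing how the new keyword travels from user code through `ClipPPOLoss` into `PPOLoss`. The observation/action sizes, the actor and critic modules, and the excluded dimension are placeholders invented for the example, not part of this diff:

```python
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss

obs_dim, act_dim = 4, 2  # hypothetical sizes

# Placeholder actor: observation -> (loc, scale) -> TanhNormal action.
actor_net = nn.Sequential(nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor())
policy = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
# Placeholder critic producing "state_value".
value = ValueOperator(nn.Linear(obs_dim, 1), in_keys=["observation"])

loss_module = ClipPPOLoss(
    actor_network=policy,
    critic_network=value,
    normalize_advantage=True,
    # Hypothetical: in a multi-agent layout where the advantage carries an agent
    # dimension (here assumed at -2), excluding it keeps per-agent statistics.
    normalize_advantage_exclude_dims=(-2,),
)
```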
@@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             )
         advantage = tensordict.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)

         log_weight, dist, kl_approx = self._log_weight(tensordict)
         # ESS for logging
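The warning branch is driven by an element-count comparison: when the advantage carries more elements than the tensordict's batch shape and no dimensions were excluded, the normalization statistics would silently mix those extra dimensions. A small sketch of that condition under an assumed multi-agent layout:

```python
import torch
from tensordict import TensorDict

# Assumed layout: advantage shaped [batch, time, n_agents, 1] in a [batch, time] tensordict.
td = TensorDict({"advantage": torch.randn(8, 32, 3, 1)}, batch_size=[8, 32])
adv = td.get("advantage")

# Same check as in forward(): 768 elements vs. a batch of 256 -> the warning fires
# unless normalize_advantage_exclude_dims (e.g. (-2,) for the agent dim) is passed.
print(adv.numel() > td.batch_size.numel())  # True
```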
@@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
         normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
             before being used. Defaults to ``False``.
+        normalize_advantage_exclude_dims (Tuple[int, ...], optional): dimensions to exclude from the advantage
+            standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
+            where the agent (or objective) dimension may be excluded from the reductions. Default: ().
         separate_losses (bool, optional): if ``True``, shared parameters between
             policy and critic will only be trained on the policy loss.
             Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1048,6 +1078,7 @@ def __init__(
         critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
+        normalize_advantage_exclude_dims: Tuple[int, ...] = (),
         gamma: float = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -1063,6 +1094,7 @@ def __init__(
             critic_coef=critic_coef,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
+            normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
             gamma=gamma,
             separate_losses=separate_losses,
             reduction=reduction,
@@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             )
         advantage = tensordict_copy.get(self.tensor_keys.advantage)
         if self.normalize_advantage and advantage.numel() > 1:
-            loc = advantage.mean()
-            scale = advantage.std().clamp_min(1e-6)
-            advantage = (advantage - loc) / scale
+            if advantage.numel() > tensordict.batch_size.numel() and not len(
+                self.normalize_advantage_exclude_dims
+            ):
+                warnings.warn(
+                    "You requested advantage normalization and the advantage key has more dimensions"
+                    " than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
+                    "if you want to keep any dimension independent while computing normalization statistics. "
+                    "If you are working in multi-agent/multi-objective settings this is highly suggested."
+                )
+            advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
+
         log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
         neg_loss = log_weight.exp() * advantage