@@ -47,6 +47,8 @@ class PPOLoss(LossModule):
47
47
default: 1.0
48
48
gamma (scalar): a discount factor for return computation.
49
49
loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
50
+ normalize_advantage (bool): if True, the advantage will be normalized before being used.
51
+ Defaults to True.
50
52
51
53
"""
52
54
@@ -62,6 +64,7 @@ def __init__(
62
64
critic_coef : float = 1.0 ,
63
65
gamma : float = 0.99 ,
64
66
loss_critic_type : str = "smooth_l1" ,
67
+ normalize_advantage : bool = True ,
65
68
):
66
69
super ().__init__ ()
67
70
self .convert_to_functional (
@@ -82,6 +85,7 @@ def __init__(
82
85
)
83
86
self .register_buffer ("gamma" , torch .tensor (gamma , device = self .device ))
84
87
self .loss_critic_type = loss_critic_type
88
+ self .normalize_advantage = normalize_advantage
85
89
86
90
def reset (self ) -> None :
87
91
pass
@@ -137,8 +141,13 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
137
141
return self .critic_coef * loss_value
138
142
139
143
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
140
- tensordict = tensordict .clone ()
144
+ tensordict = tensordict .clone (False )
141
145
advantage = tensordict .get (self .advantage_key )
146
+ if self .normalize_advantage and advantage .numel () > 1 :
147
+ loc = advantage .mean ().item ()
148
+ scale = advantage .std ().clamp_min (1e-6 ).item ()
149
+ advantage = (advantage - loc ) / scale
150
+
142
151
log_weight , dist = self ._log_weight (tensordict )
143
152
neg_loss = (log_weight .exp () * advantage ).mean ()
144
153
td_out = TensorDict ({"loss_objective" : - neg_loss .mean ()}, [])
@@ -176,6 +185,8 @@ class ClipPPOLoss(PPOLoss):
176
185
default: 1.0
177
186
gamma (scalar): a discount factor for return computation.
178
187
loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
188
+ normalize_advantage (bool): if True, the advantage will be normalized before being used.
189
+ Defaults to True.
179
190
180
191
"""
181
192
@@ -190,7 +201,8 @@ def __init__(
190
201
entropy_coef : float = 0.01 ,
191
202
critic_coef : float = 1.0 ,
192
203
gamma : float = 0.99 ,
193
- loss_critic_type : str = "l2" ,
204
+ loss_critic_type : str = "smooth_l1" ,
205
+ normalize_advantage : bool = True ,
194
206
** kwargs ,
195
207
):
196
208
super (ClipPPOLoss , self ).__init__ (
@@ -203,6 +215,7 @@ def __init__(
203
215
critic_coef = critic_coef ,
204
216
gamma = gamma ,
205
217
loss_critic_type = loss_critic_type ,
218
+ normalize_advantage = normalize_advantage ,
206
219
** kwargs ,
207
220
)
208
221
self .register_buffer ("clip_epsilon" , torch .tensor (clip_epsilon ))
@@ -215,7 +228,7 @@ def _clip_bounds(self):
215
228
)
216
229
217
230
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
218
- tensordict = tensordict .clone ()
231
+ tensordict = tensordict .clone (False )
219
232
advantage = tensordict .get (self .advantage_key )
220
233
log_weight , dist = self ._log_weight (tensordict )
221
234
# ESS for logging
@@ -235,6 +248,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
235
248
gain1 = log_weight .exp () * advantage
236
249
237
250
log_weight_clip = log_weight .clamp (* self ._clip_bounds )
251
+ if self .normalize_advantage and advantage .numel () > 1 :
252
+ loc = advantage .mean ().item ()
253
+ scale = advantage .std ().clamp_min (1e-6 ).item ()
254
+ advantage = (advantage - loc ) / scale
238
255
gain2 = log_weight_clip .exp () * advantage
239
256
240
257
gain = torch .stack ([gain1 , gain2 ], - 1 ).min (dim = - 1 )[0 ]
@@ -282,6 +299,8 @@ class KLPENPPOLoss(PPOLoss):
282
299
default: 1.0
283
300
gamma (scalar): a discount factor for return computation.
284
301
loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
302
+ normalize_advantage (bool): if True, the advantage will be normalized before being used.
303
+ Defaults to True.
285
304
286
305
"""
287
306
@@ -300,7 +319,8 @@ def __init__(
300
319
entropy_coef : float = 0.01 ,
301
320
critic_coef : float = 1.0 ,
302
321
gamma : float = 0.99 ,
303
- loss_critic_type : str = "l2" ,
322
+ loss_critic_type : str = "smooth_l1" ,
323
+ normalize_advantage : bool = True ,
304
324
** kwargs ,
305
325
):
306
326
super (KLPENPPOLoss , self ).__init__ (
@@ -313,6 +333,7 @@ def __init__(
313
333
critic_coef = critic_coef ,
314
334
gamma = gamma ,
315
335
loss_critic_type = loss_critic_type ,
336
+ normalize_advantage = normalize_advantage ,
316
337
** kwargs ,
317
338
)
318
339
@@ -333,8 +354,12 @@ def __init__(
333
354
self .samples_mc_kl = samples_mc_kl
334
355
335
356
def forward (self , tensordict : TensorDictBase ) -> TensorDict :
336
- tensordict = tensordict .clone ()
357
+ tensordict = tensordict .clone (False )
337
358
advantage = tensordict .get (self .advantage_key )
359
+ if self .normalize_advantage and advantage .numel () > 1 :
360
+ loc = advantage .mean ().item ()
361
+ scale = advantage .std ().clamp_min (1e-6 ).item ()
362
+ advantage = (advantage - loc ) / scale
338
363
log_weight , dist = self ._log_weight (tensordict )
339
364
neg_loss = log_weight .exp () * advantage
340
365
0 commit comments