
Commit 2fd4632

[Feature] PPO advantage normalization (#869)
1 parent 450a380 commit 2fd4632

2 files changed: +34 / -7 lines

torchrl/objectives/ppo.py

Lines changed: 30 additions & 5 deletions
@@ -47,6 +47,8 @@ class PPOLoss(LossModule):
             default: 1.0
         gamma (scalar): a discount factor for return computation.
         loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+        normalize_advantage (bool): if True, the advantage will be normalized before being used.
+            Defaults to True.
 
     """
 
@@ -62,6 +64,7 @@ def __init__(
         critic_coef: float = 1.0,
         gamma: float = 0.99,
         loss_critic_type: str = "smooth_l1",
+        normalize_advantage: bool = True,
     ):
         super().__init__()
         self.convert_to_functional(
@@ -82,6 +85,7 @@ def __init__(
         )
         self.register_buffer("gamma", torch.tensor(gamma, device=self.device))
         self.loss_critic_type = loss_critic_type
+        self.normalize_advantage = normalize_advantage
 
     def reset(self) -> None:
         pass
@@ -137,8 +141,13 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         return self.critic_coef * loss_value
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = tensordict.clone()
+        tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.advantage_key)
+        if self.normalize_advantage and advantage.numel() > 1:
+            loc = advantage.mean().item()
+            scale = advantage.std().clamp_min(1e-6).item()
+            advantage = (advantage - loc) / scale
+
         log_weight, dist = self._log_weight(tensordict)
         neg_loss = (log_weight.exp() * advantage).mean()
         td_out = TensorDict({"loss_objective": -neg_loss.mean()}, [])
@@ -176,6 +185,8 @@ class ClipPPOLoss(PPOLoss):
             default: 1.0
         gamma (scalar): a discount factor for return computation.
         loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+        normalize_advantage (bool): if True, the advantage will be normalized before being used.
+            Defaults to True.
 
     """
 
@@ -190,7 +201,8 @@ def __init__(
         entropy_coef: float = 0.01,
         critic_coef: float = 1.0,
         gamma: float = 0.99,
-        loss_critic_type: str = "l2",
+        loss_critic_type: str = "smooth_l1",
+        normalize_advantage: bool = True,
         **kwargs,
     ):
         super(ClipPPOLoss, self).__init__(
@@ -203,6 +215,7 @@ def __init__(
             critic_coef=critic_coef,
             gamma=gamma,
             loss_critic_type=loss_critic_type,
+            normalize_advantage=normalize_advantage,
             **kwargs,
         )
         self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
@@ -215,7 +228,7 @@ def _clip_bounds(self):
         )
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        tensordict = tensordict.clone()
+        tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.advantage_key)
         log_weight, dist = self._log_weight(tensordict)
         # ESS for logging
@@ -235,6 +248,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         gain1 = log_weight.exp() * advantage
 
         log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        if self.normalize_advantage and advantage.numel() > 1:
+            loc = advantage.mean().item()
+            scale = advantage.std().clamp_min(1e-6).item()
+            advantage = (advantage - loc) / scale
         gain2 = log_weight_clip.exp() * advantage
 
         gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
@@ -282,6 +299,8 @@ class KLPENPPOLoss(PPOLoss):
             default: 1.0
         gamma (scalar): a discount factor for return computation.
         loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
+        normalize_advantage (bool): if True, the advantage will be normalized before being used.
+            Defaults to True.
 
     """
 
@@ -300,7 +319,8 @@ def __init__(
         entropy_coef: float = 0.01,
         critic_coef: float = 1.0,
         gamma: float = 0.99,
-        loss_critic_type: str = "l2",
+        loss_critic_type: str = "smooth_l1",
+        normalize_advantage: bool = True,
         **kwargs,
     ):
         super(KLPENPPOLoss, self).__init__(
@@ -313,6 +333,7 @@ def __init__(
             critic_coef=critic_coef,
            gamma=gamma,
            loss_critic_type=loss_critic_type,
+            normalize_advantage=normalize_advantage,
            **kwargs,
        )
 
@@ -333,8 +354,12 @@ def __init__(
         self.samples_mc_kl = samples_mc_kl
 
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
-        tensordict = tensordict.clone()
+        tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.advantage_key)
+        if self.normalize_advantage and advantage.numel() > 1:
+            loc = advantage.mean().item()
+            scale = advantage.std().clamp_min(1e-6).item()
+            advantage = (advantage - loc) / scale
         log_weight, dist = self._log_weight(tensordict)
         neg_loss = log_weight.exp() * advantage
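In all three losses the added step is the same: when normalize_advantage is set (the default) and the advantage holds more than one element, it is shifted to zero mean and rescaled by its standard deviation, with the standard deviation clamped away from zero before the division. Because the statistics are pulled out with .item(), they are plain Python floats, so gradients flow only through the raw advantage values. The flag can be disabled by passing normalize_advantage=False to any of the three losses. Below is a minimal sketch of the step in isolation, in plain PyTorch; the free-standing helper name is illustrative, not part of the commit.

import torch

def normalize_advantage(advantage: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirror the logic added to PPOLoss.forward: skip degenerate inputs,
    # then shift to zero mean and scale by the clamped standard deviation.
    if advantage.numel() <= 1:
        return advantage
    loc = advantage.mean().item()
    scale = advantage.std().clamp_min(eps).item()
    return (advantage - loc) / scale

adv = torch.randn(64, 1)
adv_n = normalize_advantage(adv)
print(adv_n.mean().item(), adv_n.std().item())  # approximately 0.0 and 1.0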

torchrl/objectives/value/advantages.py

Lines changed: 4 additions & 2 deletions
@@ -562,8 +562,10 @@ def forward(
         )
 
         if self.average_gae:
-            adv = adv - adv.mean()
-            adv = adv / adv.std().clamp_min(1e-4)
+            loc = adv.mean()
+            scale = adv.std().clamp_min(1e-4)
+            adv = adv - loc
+            adv = adv / scale
 
         tensordict.set(self.advantage_key, adv)
         tensordict.set(self.value_target_key, value_target)
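The average_gae branch is a pure restructuring: both statistics are now computed from the original advantage before either is applied, which is mathematically equivalent to the old in-place pair of updates because the standard deviation is unchanged by subtracting the mean. A small self-contained check of that equivalence (the tensor below is arbitrary dummy data):

import torch

adv = torch.randn(128)

# Old formulation: centre first, then divide by the std of the centred tensor.
old = adv - adv.mean()
old = old / old.std().clamp_min(1e-4)

# New formulation: compute both statistics from the original tensor up front.
loc = adv.mean()
scale = adv.std().clamp_min(1e-4)
new = (adv - loc) / scale

print(torch.allclose(old, new))  # True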
