Skip to content

Commit c79d66f

Browse files
committed
Fix the way the mask is calculated
1 parent 7ad95f4 commit c79d66f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -457,7 +457,7 @@ def compute_gae(
457457
)
458458
return storage, agent_state
459459

460-
def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
460+
def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, mask):
461461
newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
462462
logratio = newlogprob - logp
463463
ratio = jnp.exp(logratio)
@@ -468,10 +468,10 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
468468
pg_loss1 = -mb_advantages * ratio
469469
pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
470470
# mask truncated state
471-
pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - truncated)).sum() / (1 - truncated).sum()
471+
pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - mask)).sum() / (1 - mask).sum()
472472

473473
# Value loss
474-
v_loss = (((newvalue - mb_returns) * (1 - truncated)) ** 2).sum() / (1 - truncated).sum()
474+
v_loss = (((newvalue - mb_returns) * (1 - mask)) ** 2).sum() / (1 - mask).sum()
475475

476476
entropy_loss = entropy.mean()
477477
loss = pg_loss + v_loss * args.vf_coef
@@ -513,7 +513,7 @@ def update_minibatch(agent_state, minibatch):
513513
minibatch.logprobs,
514514
minibatch.advantages,
515515
minibatch.returns,
516-
minibatch.truncated,
516+
1 - (1 - minibatch.truncated) * (1 - minibatch.dones),
517517
)
518518
agent_state = agent_state.apply_gradients(grads=grads)
519519
return agent_state, (

0 commit comments

Comments
 (0)