@@ -457,7 +457,7 @@ def compute_gae(
     )
     return storage, agent_state

-def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
+def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, mask):
     newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
     logratio = newlogprob - logp
     ratio = jnp.exp(logratio)
@@ -468,10 +468,10 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
     pg_loss1 = -mb_advantages * ratio
     pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
     # mask truncated state
-    pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - truncated)).sum() / (1 - truncated).sum()
+    pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - mask)).sum() / (1 - mask).sum()

     # Value loss
-    v_loss = (((newvalue - mb_returns) * (1 - truncated)) ** 2).sum() / (1 - truncated).sum()
+    v_loss = (((newvalue - mb_returns) * (1 - mask)) ** 2).sum() / (1 - mask).sum()

     entropy_loss = entropy.mean()
     loss = pg_loss + v_loss * args.vf_coef
@@ -513,7 +513,7 @@ def update_minibatch(agent_state, minibatch):
         minibatch.logprobs,
         minibatch.advantages,
         minibatch.returns,
-        minibatch.truncated,
+        1 - (1 - minibatch.truncated) * (1 - minibatch.dones),
     )
     agent_state = agent_state.apply_gradients(grads=grads)
     return agent_state, (
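For reference, the `mask` argument now folds both episode-boundary flags into one: for 0/1 arrays, `1 - (1 - truncated) * (1 - dones)` is an element-wise logical OR, so a transition is excluded from the loss when its episode was either truncated or terminated at that step. Below is a minimal standalone sketch (toy arrays, not from this PR) of that mask and of the masked mean used in `ppo_loss`:

import jax.numpy as jnp

# Hypothetical 0/1 flags for four transitions.
truncated = jnp.array([0.0, 1.0, 0.0, 0.0])
dones = jnp.array([0.0, 0.0, 1.0, 0.0])

# De Morgan on 0/1 inputs: 1 - (1 - a) * (1 - b) == (a OR b).
mask = 1 - (1 - truncated) * (1 - dones)  # -> [0., 1., 1., 0.]

# Masked mean, mirroring pg_loss / v_loss: sum over unmasked steps,
# divided by the number of unmasked steps rather than the batch size.
per_step_loss = jnp.array([0.5, 9.0, 9.0, 0.7])
masked_mean = (per_step_loss * (1 - mask)).sum() / (1 - mask).sum()  # -> 0.6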