@@ -42,11 +42,15 @@ class PPOLoss(LossModule):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is
             expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -120,6 +124,7 @@ def __init__(
         *,
         advantage_key: str = "advantage",
         value_target_key: str = "value_target",
+        value_key: str = "state_value",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
@@ -142,6 +147,7 @@ def __init__(
         self.convert_to_functional(critic, "critic", compare_against=policy_params)
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
+        self.value_key = value_key
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus
         self.separate_losses = separate_losses
@@ -193,15 +199,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             tensordict = tensordict.detach()
         try:
             target_return = tensordict.get(self.value_target_key)
-            state_value = self.critic(
-                tensordict,
-                params=self.critic_params,
-            ).get("state_value")
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
         except KeyError:
             raise KeyError(
                 f"the key {self.value_target_key} was not found in the input tensordict. "
@@ -210,6 +207,25 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
                 f"can be used for the value loss."
             )
+
+        state_value_td = self.critic(
+            tensordict,
+            params=self.critic_params,
+        )
+
+        try:
+            state_value = state_value_td.get(self.value_key)
+        except KeyError:
+            raise KeyError(
+                f"the key {self.value_key} was not found in the input tensordict. "
+                f"Make sure that the value_key passed to PPO is accurate."
+            )
+
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
         return self.critic_coef * loss_value
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
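For context, a minimal usage sketch of the new keyword (not part of the diff): the names n_obs, n_act, value_net, policy_net and "my_value" are illustrative, and the actor/critic construction simply mirrors the usual torchrl modules. The point is that value_key should match the out_key under which the critic writes its prediction, since loss_critic now reads that key instead of the hard-coded "state_value".

import torch
from torch import nn
from torch.distributions import Categorical
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import PPOLoss

n_obs, n_act = 4, 2  # illustrative sizes

# Critic that writes its prediction under a custom key instead of "state_value".
value_net = ValueOperator(
    module=nn.Linear(n_obs, 1),
    in_keys=["observation"],
    out_keys=["my_value"],
)

# Minimal categorical policy wrapped as a probabilistic actor.
policy_net = TensorDictModule(
    nn.Linear(n_obs, n_act), in_keys=["observation"], out_keys=["logits"]
)
actor = ProbabilisticActor(
    policy_net, in_keys=["logits"], distribution_class=Categorical
)

# value_key tells loss_critic() which entry of the critic output to read.
loss_module = PPOLoss(actor, value_net, value_key="my_value")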
@@ -277,10 +293,14 @@ class ClipPPOLoss(PPOLoss):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
             default: 0.2
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
@@ -353,6 +373,7 @@ def __init__(
         critic: TensorDictModule,
         *,
         advantage_key: str = "advantage",
+        value_key: str = "state_value",
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
@@ -369,6 +390,7 @@ def __init__(
             critic,
             advantage_key=advantage_key,
             entropy_bonus=entropy_bonus,
+            value_key=value_key,
             samples_mc_entropy=samples_mc_entropy,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
@@ -447,10 +469,14 @@ class KLPENPPOLoss(PPOLoss):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         dtarg (scalar, optional): target KL divergence. Defaults to ``0.01``.
         samples_mc_kl (int, optional): number of samples used to compute the KL divergence
             if no analytical formula can be found. Defaults to ``1``.
@@ -532,6 +558,7 @@ def __init__(
         *,
         advantage_key="advantage",
         dtarg: float = 0.01,
+        value_key: str = "state_value",
         beta: float = 1.0,
         increment: float = 2,
         decrement: float = 0.5,
@@ -558,6 +585,7 @@ def __init__(
             normalize_advantage=normalize_advantage,
             gamma=gamma,
             separate_losses=separate_losses,
+            value_key=value_key,
             **kwargs,
         )
 
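Continuing the sketch above (again illustrative, not part of the diff): the subclasses forward the same keyword to PPOLoss, so a non-default value key is specified the same way.

from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss

# Both subclasses pass value_key through to PPOLoss.__init__.
clip_loss = ClipPPOLoss(actor, value_net, value_key="my_value")
kl_loss = KLPENPPOLoss(actor, value_net, value_key="my_value")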