
Commit 0452133

vmoens and xmaples authored
[BugFix] Instruct the value key to PPOLoss (#1124)
Co-authored-by: xmaples <5900204+xmaples@users.noreply.github.com>
1 parent 714d645 commit 0452133


torchrl/objectives/ppo.py

Lines changed: 37 additions & 9 deletions
@@ -42,11 +42,15 @@ class PPOLoss(LossModule):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is
             expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
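
Note: the hunk above documents the new ``value_key`` keyword. For orientation, here is a minimal usage sketch (not part of this commit; the actor/critic construction assumes the usual torchrl wrappers, and the key name ``my_state_value`` is purely illustrative):

    import torch
    from tensordict.nn import TensorDictModule
    from tensordict.nn.distributions import NormalParamExtractor
    from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import PPOLoss

    obs_dim, act_dim = 4, 2

    # Policy: observation -> (loc, scale) -> TanhNormal distribution over actions.
    policy_net = torch.nn.Sequential(
        torch.nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor()
    )
    actor = ProbabilisticActor(
        TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"]),
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )

    # Critic that writes its estimate under a custom key rather than "state_value".
    critic = ValueOperator(
        torch.nn.Linear(obs_dim, 1),
        in_keys=["observation"],
        out_keys=["my_state_value"],
    )

    # The new keyword tells the loss where to read the critic's output.
    loss_module = PPOLoss(actor, critic, value_key="my_state_value")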
@@ -120,6 +124,7 @@ def __init__(
         *,
         advantage_key: str = "advantage",
         value_target_key: str = "value_target",
+        value_key: str = "state_value",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
@@ -142,6 +147,7 @@ def __init__(
         self.convert_to_functional(critic, "critic", compare_against=policy_params)
         self.advantage_key = advantage_key
         self.value_target_key = value_target_key
+        self.value_key = value_key
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_bonus = entropy_bonus
         self.separate_losses = separate_losses
@@ -193,15 +199,6 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         tensordict = tensordict.detach()
         try:
             target_return = tensordict.get(self.value_target_key)
-            state_value = self.critic(
-                tensordict,
-                params=self.critic_params,
-            ).get("state_value")
-            loss_value = distance_loss(
-                target_return,
-                state_value,
-                loss_function=self.loss_critic_type,
-            )
         except KeyError:
             raise KeyError(
                 f"the key {self.value_target_key} was not found in the input tensordict. "
@@ -210,6 +207,25 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
                 f"TDLambdaEstimate and TDEstimate all return a 'value_target' entry that "
                 f"can be used for the value loss."
             )
+
+        state_value_td = self.critic(
+            tensordict,
+            params=self.critic_params,
+        )
+
+        try:
+            state_value = state_value_td.get(self.value_key)
+        except KeyError:
+            raise KeyError(
+                f"the key {self.value_key} was not found in the input tensordict. "
+                f"Make sure that the value_key passed to PPO is accurate."
+            )
+
+        loss_value = distance_loss(
+            target_return,
+            state_value,
+            loss_function=self.loss_critic_type,
+        )
         return self.critic_coef * loss_value
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
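
With these two hunks, the value-target lookup and the critic evaluation are guarded separately: a missing critic output is no longer misreported as a missing ``value_target``, and the critic's estimate is read from the configurable ``self.value_key`` instead of a hard-coded ``"state_value"``. A simplified, framework-free sketch of the new control flow (names and signatures are illustrative, not the torchrl API):

    def loss_critic_flow(tensordict, critic, value_target_key, value_key, distance_loss):
        # Illustrative re-statement of the new PPOLoss.loss_critic logic.
        try:
            # 1) The value estimator must already have written the target return.
            target_return = tensordict[value_target_key]
        except KeyError:
            raise KeyError(f"'{value_target_key}' not found; run the value estimator first")

        # 2) Run the critic, then read its output under the configurable key,
        #    so a missing entry is reported as a value_key problem.
        critic_output = critic(tensordict)
        try:
            state_value = critic_output[value_key]
        except KeyError:
            raise KeyError(f"'{value_key}' not found; check the value_key passed to PPO")

        # 3) Only then regress the critic estimate onto the target.
        return distance_loss(target_return, state_value)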
@@ -277,10 +293,14 @@ class ClipPPOLoss(PPOLoss):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
             default: 0.2
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
@@ -353,6 +373,7 @@ def __init__(
         critic: TensorDictModule,
         *,
         advantage_key: str = "advantage",
+        value_key: str = "state_value",
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
@@ -369,6 +390,7 @@ def __init__(
             critic,
             advantage_key=advantage_key,
             entropy_bonus=entropy_bonus,
+            value_key=value_key,
             samples_mc_entropy=samples_mc_entropy,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
@@ -447,10 +469,14 @@ class KLPENPPOLoss(PPOLoss):
     Args:
         actor (ProbabilisticTensorDictSequential): policy operator.
         critic (ValueOperator): value operator.
+
+    Keyword Args:
         advantage_key (str, optional): the input tensordict key where the advantage is expected to be written.
             Defaults to ``"advantage"``.
         value_target_key (str, optional): the input tensordict key where the target state
             value is expected to be written. Defaults to ``"value_target"``.
+        value_key (str, optional): the input tensordict key where the state
+            value is expected to be written. Defaults to ``"state_value"``.
         dtarg (scalar, optional): target KL divergence. Defaults to ``0.01``.
         samples_mc_kl (int, optional): number of samples used to compute the KL divergence
             if no analytical formula can be found. Defaults to ``1``.
@@ -532,6 +558,7 @@ def __init__(
         *,
         advantage_key="advantage",
         dtarg: float = 0.01,
+        value_key: str = "state_value",
         beta: float = 1.0,
         increment: float = 2,
         decrement: float = 0.5,
@@ -558,6 +585,7 @@ def __init__(
             normalize_advantage=normalize_advantage,
             gamma=gamma,
             separate_losses=separate_losses,
+            value_key=value_key,
             **kwargs,
         )

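Both subclasses forward the new keyword unchanged to ``PPOLoss.__init__``, so they are configured the same way. A brief, hypothetical continuation of the earlier sketch (reusing that ``actor`` and ``critic``):

    from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss

    # Either subclass accepts value_key and passes it to the parent PPOLoss.
    clip_loss = ClipPPOLoss(actor, critic, clip_epsilon=0.2, value_key="my_state_value")
    kl_loss = KLPENPPOLoss(actor, critic, dtarg=0.01, value_key="my_state_value")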