Skip to content

Commit 86ab9b7

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 5c03f9f commit 86ab9b7

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

torchrl/objectives/ppo.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensordict.utils import NestedKey
2828
from torch import distributions as d
2929

30+
from torchrl._utils import _replace_last
3031
from torchrl.objectives.common import LossModule
3132

3233
from torchrl.objectives.utils import (
@@ -1267,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
12671268

12681269
def reset(self) -> None:
12691270
self.beta = self._beta_init
1271+
1272+
1273+
def _make_lp_get_error(tensor_keys, log_prob, err):
1274+
result = (
1275+
f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does "
1276+
f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. "
1277+
)
1278+
# now check if we can substitute the actions with action_log_prob and retrieve the log-probs
1279+
action_keys = tensor_keys.action
1280+
if isinstance(action_keys, list):
1281+
has_all_log_probs = True
1282+
log_prob_keys = []
1283+
for action_key in action_keys:
1284+
log_prob_key = _replace_last(action_key, "action_log_prob")
1285+
log_prob_keys.append(log_prob_key)
1286+
if log_prob_key not in log_prob:
1287+
has_all_log_probs = False
1288+
break
1289+
if has_all_log_probs:
1290+
result += (
1291+
f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the "
1292+
f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve "
1293+
f"this error."
1294+
)
1295+
return KeyError(result)
1296+
result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
1297+
return KeyError(result)

0 commit comments

Comments
 (0)