|
27 | 27 | from tensordict.utils import NestedKey
|
28 | 28 | from torch import distributions as d
|
29 | 29 |
|
| 30 | +from torchrl._utils import _replace_last |
30 | 31 | from torchrl.objectives.common import LossModule
|
31 | 32 |
|
32 | 33 | from torchrl.objectives.utils import (
|
@@ -1267,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
|
1267 | 1268 |
|
def reset(self) -> None:
    """Restore ``beta`` to the value stored in ``_beta_init``.

    Only ``self.beta`` is overwritten; ``self._beta_init`` is read and left
    untouched.
    """
    initial_beta = self._beta_init
    self.beta = initial_beta
|
| 1271 | + |
| 1272 | + |
def _make_lp_get_error(tensor_keys, log_prob, err):
    """Build (but do not raise) a ``KeyError`` for a missing sample log-prob key.

    The message always reports ``tensor_keys.sample_log_prob`` together with
    the keys returned by ``log_prob.keys(True, True)``. When
    ``tensor_keys.action`` is a list and every derived ``"action_log_prob"``
    key is found in ``log_prob``, the message suggests the exact
    ``loss.set_keys(sample_log_prob=...)`` call; otherwise it ends with a
    generic hint.

    Args:
        tensor_keys: object exposing ``sample_log_prob`` and ``action``
            attributes.
        log_prob: container supporting ``keys(True, True)`` and membership
            tests with ``in``.
        err: the originating exception. It is not used inside this function;
            presumably the caller chains it with ``raise ... from err``.

    Returns:
        KeyError: the exception instance carrying the assembled message.
    """
    sample_key = tensor_keys.sample_log_prob
    present_keys = list(log_prob.keys(True, True))
    message = (
        f"The sample log probability key (tensor_keys.sample_log_prob={sample_key}) does "
        f"not appear in the log-prob tensordict with keys {present_keys}. "
    )
    generic_hint = "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."

    actions = tensor_keys.action
    if not isinstance(actions, list):
        # Only a list of action keys is checked for substitutable log-prob keys.
        return KeyError(message + generic_hint)

    # Derive one candidate log-prob key per action key and stop at the first
    # candidate that is absent from ``log_prob``.
    # NOTE(review): ``_replace_last`` presumably swaps the last element of a
    # (possibly nested) key -- confirm against torchrl._utils.
    candidates = []
    for action_key in actions:
        candidate = _replace_last(action_key, "action_log_prob")
        candidates.append(candidate)
        if candidate not in log_prob:
            return KeyError(message + generic_hint)

    # Every candidate is present: point the user at the precise fix.
    specific_hint = (
        f"The action keys are {actions} and all log_prob keys {candidates} are present in the "
        f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={candidates})` should resolve "
        f"this error."
    )
    return KeyError(message + specific_hint)
0 commit comments