
Commit b8e0fb5

[BugFix] DDPG select also critic input for actor loss (#1563)
Signed-off-by: Matteo Bettini <matbet@meta.com>
Parent commit: 09e148b

1 file changed (+15 lines, −9 lines)

torchrl/objectives/ddpg.py

@@ -11,10 +11,10 @@
 from typing import Tuple
 
 import torch
-
 from tensordict.nn import dispatch, make_functional, repopulate_module, TensorDictModule
 from tensordict.tensordict import TensorDict, TensorDictBase
-from tensordict.utils import NestedKey
+
+from tensordict.utils import NestedKey, unravel_key
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import (
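
Note: unravel_key (newly imported here) normalizes nested tensordict keys into a single flat tuple, so equivalent spellings of the same key hash and compare equal. A minimal sketch of the behavior this change relies on, with illustrative keys:

from tensordict.utils import unravel_key

# Nested composite keys are flattened into one tuple.
assert unravel_key(("next", ("agents", "reward"))) == ("next", "agents", "reward")
# Plain string keys pass through unchanged.
assert unravel_key("observation") == "observation"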
@@ -216,6 +216,9 @@ def __init__(
         self.actor_critic.module[1] = self.value_network
 
         self.actor_in_keys = actor_network.in_keys
+        self.value_exclusive_keys = set(self.value_network.in_keys) - (
+            set(self.actor_in_keys) | set(self.actor_network.out_keys)
+        )
 
         self.loss_function = loss_function
 
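
Note: value_exclusive_keys collects every input the critic needs that the actor neither consumes (actor_in_keys) nor produces (actor_network.out_keys). A minimal sketch of the set arithmetic, using hypothetical key names:

# Hypothetical key sets for illustration only.
value_in_keys = {"observation", "extra_obs", "action"}
actor_in_keys = {"observation"}
actor_out_keys = {"action"}

value_exclusive_keys = value_in_keys - (actor_in_keys | actor_out_keys)
assert value_exclusive_keys == {"extra_obs"}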
@@ -233,14 +236,15 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
         self._set_in_keys()
 
     def _set_in_keys(self):
-        keys = [
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
+        in_keys = {
+            unravel_key(("next", self.tensor_keys.reward)),
+            unravel_key(("next", self.tensor_keys.done)),
             *self.actor_in_keys,
-            *[("next", key) for key in self.actor_in_keys],
+            *[unravel_key(("next", key)) for key in self.actor_in_keys],
             *self.value_network.in_keys,
-        ]
-        self._in_keys = list(set(keys))
+            *[unravel_key(("next", key)) for key in self.value_network.in_keys],
+        }
+        self._in_keys = sorted(in_keys, key=str)
 
     @property
     def in_keys(self):
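
Note: sorting with key=str is what makes the deduplicated set orderable: in_keys mixes plain strings with ("next", ...) tuples, which Python cannot compare to each other directly, so sorting by their string form yields a deterministic order instead of a TypeError. A quick sketch with illustrative keys:

keys = {"observation", ("next", "observation"), ("next", "reward")}
# sorted(keys) would raise TypeError: '<' not supported between tuple and str.
in_keys = sorted(keys, key=str)  # deterministic, mixed-type safe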
@@ -293,7 +297,9 @@ def _loss_actor(
         self,
         tensordict: TensorDictBase,
     ) -> torch.Tensor:
-        td_copy = tensordict.select(*self.actor_in_keys).detach()
+        td_copy = tensordict.select(
+            *self.actor_in_keys, *self.value_exclusive_keys
+        ).detach()
         td_copy = self.actor_network(
             td_copy,
             params=self.actor_network_params,
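
Note: in the actor loss, the actor's action is fed back through the critic, so any input the critic alone consumes must survive the select; previously it was dropped, and the critic call inside _loss_actor could not find it. A minimal sketch of the failure mode, assuming a hypothetical extra_obs entry that only the critic reads:

import torch
from tensordict import TensorDict

td = TensorDict(
    {"observation": torch.randn(4, 3), "extra_obs": torch.randn(4, 5)},
    batch_size=[4],
)

# Before the fix: selecting only the actor's in_keys drops extra_obs,
# so a critic that reads it would fail downstream.
actor_only = td.select("observation")
assert "extra_obs" not in actor_only.keys()

# After the fix: critic-exclusive keys survive the select.
both = td.select("observation", "extra_obs")
assert "extra_obs" in both.keys()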
