11
11
from typing import Tuple
12
12
13
13
import torch
14
-
15
14
from tensordict .nn import dispatch , make_functional , repopulate_module , TensorDictModule
16
15
from tensordict .tensordict import TensorDict , TensorDictBase
17
- from tensordict .utils import NestedKey
16
+
17
+ from tensordict .utils import NestedKey , unravel_key
18
18
from torchrl .modules .tensordict_module .actors import ActorCriticWrapper
19
19
from torchrl .objectives .common import LossModule
20
20
from torchrl .objectives .utils import (
@@ -216,6 +216,9 @@ def __init__(
216
216
self .actor_critic .module [1 ] = self .value_network
217
217
218
218
self .actor_in_keys = actor_network .in_keys
219
+ self .value_exclusive_keys = set (self .value_network .in_keys ) - (
220
+ set (self .actor_in_keys ) | set (self .actor_network .out_keys )
221
+ )
219
222
220
223
self .loss_function = loss_function
221
224
@@ -233,14 +236,15 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
233
236
self ._set_in_keys ()
234
237
235
238
def _set_in_keys (self ):
236
- keys = [
237
- ( "next" , self .tensor_keys .reward ),
238
- ( "next" , self .tensor_keys .done ),
239
+ in_keys = {
240
+ unravel_key (( "next" , self .tensor_keys .reward ) ),
241
+ unravel_key (( "next" , self .tensor_keys .done ) ),
239
242
* self .actor_in_keys ,
240
- * [( "next" , key ) for key in self .actor_in_keys ],
243
+ * [unravel_key (( "next" , key ) ) for key in self .actor_in_keys ],
241
244
* self .value_network .in_keys ,
242
- ]
243
- self ._in_keys = list (set (keys ))
245
+ * [unravel_key (("next" , key )) for key in self .value_network .in_keys ],
246
+ }
247
+ self ._in_keys = sorted (in_keys , key = str )
244
248
245
249
@property
246
250
def in_keys (self ):
@@ -293,7 +297,9 @@ def _loss_actor(
293
297
self ,
294
298
tensordict : TensorDictBase ,
295
299
) -> torch .Tensor :
296
- td_copy = tensordict .select (* self .actor_in_keys ).detach ()
300
+ td_copy = tensordict .select (
301
+ * self .actor_in_keys , * self .value_exclusive_keys
302
+ ).detach ()
297
303
td_copy = self .actor_network (
298
304
td_copy ,
299
305
params = self .actor_network_params ,
0 commit comments