@@ -280,32 +280,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
280
280
a TensorDict with an empty batch size holding the DDPG losses under the "loss_actor" and "loss_value" keys, plus auxiliary value statistics.
281
281
282
282
"""
283
- loss_value , td_error , pred_val , target_value = self ._loss_value (tensordict )
284
- td_error = td_error .detach ()
285
- if tensordict .device is not None :
286
- td_error = td_error .to (tensordict .device )
287
- tensordict .set (
288
- self .tensor_keys .priority ,
289
- td_error ,
290
- inplace = True ,
291
- )
292
- loss_actor = self ._loss_actor (tensordict )
283
+ loss_value , metadata = self .loss_value (tensordict )
284
+ loss_actor , metadata_actor = self .loss_actor (tensordict )
285
+ metadata .update (metadata_actor )
293
286
return TensorDict (
294
- source = {
295
- "loss_actor" : loss_actor .mean (),
296
- "loss_value" : loss_value .mean (),
297
- "pred_value" : pred_val .mean ().detach (),
298
- "target_value" : target_value .mean ().detach (),
299
- "pred_value_max" : pred_val .max ().detach (),
300
- "target_value_max" : target_value .max ().detach (),
301
- },
287
+ source = {"loss_actor" : loss_actor , "loss_value" : loss_value , ** metadata },
302
288
batch_size = [],
303
289
)
304
290
305
- def _loss_actor (
291
+ def loss_actor (
306
292
self ,
307
293
tensordict : TensorDictBase ,
308
- ) -> torch .Tensor :
294
+ ) -> [ torch .Tensor , dict ] :
309
295
td_copy = tensordict .select (
310
296
* self .actor_in_keys , * self .value_exclusive_keys
311
297
).detach ()
@@ -317,12 +303,14 @@ def _loss_actor(
317
303
td_copy ,
318
304
params = self ._cached_detached_value_params ,
319
305
)
320
- return - td_copy .get (self .tensor_keys .state_action_value )
306
+ loss_actor = - td_copy .get (self .tensor_keys .state_action_value )
307
+ metadata = {}
308
+ return loss_actor .mean (), metadata
321
309
322
- def _loss_value (
310
+ def loss_value (
323
311
self ,
324
312
tensordict : TensorDictBase ,
325
- ) -> Tuple [torch .Tensor , torch . Tensor , torch . Tensor , torch . Tensor ]:
313
+ ) -> Tuple [torch .Tensor , dict ]:
326
314
# value loss
327
315
td_copy = tensordict .select (* self .value_network .in_keys ).detach ()
328
316
self .value_network (
@@ -340,7 +328,24 @@ def _loss_value(
340
328
pred_val , target_value , loss_function = self .loss_function
341
329
)
342
330
343
- return loss_value , (pred_val - target_value ).pow (2 ), pred_val , target_value
331
+ td_error = (pred_val - target_value ).pow (2 )
332
+ td_error = td_error .detach ()
333
+ if tensordict .device is not None :
334
+ td_error = td_error .to (tensordict .device )
335
+ tensordict .set (
336
+ self .tensor_keys .priority ,
337
+ td_error ,
338
+ inplace = True ,
339
+ )
340
+ with torch .no_grad ():
341
+ metadata = {
342
+ "td_error" : td_error .mean (),
343
+ "pred_value" : pred_val .mean (),
344
+ "target_value" : target_value .mean (),
345
+ "target_value_max" : target_value .max (),
346
+ "pred_value_max" : pred_val .max (),
347
+ }
348
+ return loss_value .mean (), metadata
344
349
345
350
def make_value_estimator (self , value_type : ValueEstimators = None , ** hyperparams ):
346
351
if value_type is None :
0 commit comments