@@ -5,6 +5,7 @@
 import math
 import warnings
 from dataclasses import dataclass
+from functools import wraps
 from numbers import Number
 from typing import Dict, Optional, Tuple, Union
 
@@ -43,6 +44,15 @@
 FUNCTORCH_ERROR = err
 
 
+def _delezify(func):
+    @wraps(func)
+    def new_func(self, *args, **kwargs):
+        self.target_entropy
+        return func(self, *args, **kwargs)
+
+    return new_func
+
+
 class SACLoss(LossModule):
     """TorchRL implementation of the SAC loss.
 
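For reference: `_delezify` makes serialization safe against the lazy initialization introduced below. It touches the `target_entropy` property, which registers the `_target_entropy` buffer on first access, before delegating to the wrapped method, so `state_dict()` and `load_state_dict()` always see a fully materialized module. A minimal, self-contained sketch of the mechanism, with a hypothetical `LazyEntropy` module standing in for `SACLoss`:

import torch
from functools import wraps
from torch import nn


def _delezify(func):
    @wraps(func)
    def new_func(self, *args, **kwargs):
        self.target_entropy  # force the lazy buffer into existence first
        return func(self, *args, **kwargs)

    return new_func


class LazyEntropy(nn.Module):  # hypothetical stand-in for SACLoss
    @property
    def target_entropy(self):
        te = self._buffers.get("_target_entropy")
        if te is None:
            self.register_buffer("_target_entropy", torch.tensor(-1.0))
            te = self._target_entropy
        return te

    # same trick as the last hunk of this diff, but against nn.Module
    state_dict = _delezify(nn.Module.state_dict)


module = LazyEntropy()
print(module.state_dict())  # OrderedDict with '_target_entropy', even before any explicit access

Assigning `state_dict = _delezify(...)` at class scope replaces the inherited method with the wrapped version for all instances.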
@@ -371,7 +381,6 @@ def __init__(
 
         self._target_entropy = target_entropy
         self._action_spec = action_spec
-        self.target_entropy_buffer = None
        if self._version == 1:
             self.actor_critic = ActorCriticWrapper(
                 self.actor_network, self.value_network
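The deleted `self.target_entropy_buffer = None` line was a sentinel marking the buffer as not yet computed, but a plain attribute also blocks any later `register_buffer` call under the same name, which is why the old property had to `delattr` before registering (see the next hunk). A small repro of that PyTorch constraint, with illustrative names:

import torch
from torch import nn


class Sentinel(nn.Module):
    def __init__(self):
        super().__init__()
        self.target_entropy_buffer = None  # plain attribute, not a buffer


m = Sentinel()
try:
    m.register_buffer("target_entropy_buffer", torch.tensor(-1.0))
except KeyError as err:
    print(err)  # "attribute 'target_entropy_buffer' already exists"
delattr(m, "target_entropy_buffer")  # the old code's workaround
m.register_buffer("target_entropy_buffer", torch.tensor(-1.0))  # now fine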
@@ -384,48 +393,54 @@ def __init__(
         if self._version == 1:
             self._vmap_qnetwork00 = vmap(qvalue_network)
 
+    @property
+    def target_entropy_buffer(self):
+        return self.target_entropy
+
     @property
     def target_entropy(self):
-        target_entropy = self.target_entropy_buffer
-        if target_entropy is None:
-            delattr(self, "target_entropy_buffer")
-            target_entropy = self._target_entropy
-            action_spec = self._action_spec
-            actor_network = self.actor_network
-            device = next(self.parameters()).device
-            if target_entropy == "auto":
-                action_spec = (
-                    action_spec
-                    if action_spec is not None
-                    else getattr(actor_network, "spec", None)
-                )
-                if action_spec is None:
-                    raise RuntimeError(
-                        "Cannot infer the dimensionality of the action. Consider providing "
-                        "the target entropy explicitely or provide the spec of the "
-                        "action tensor in the actor network."
-                    )
-                if not isinstance(action_spec, CompositeSpec):
-                    action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
-                if (
-                    isinstance(self.tensor_keys.action, tuple)
-                    and len(self.tensor_keys.action) > 1
-                ):
-                    action_container_shape = action_spec[
-                        self.tensor_keys.action[:-1]
-                    ].shape
-                else:
-                    action_container_shape = action_spec.shape
-                target_entropy = -float(
-                    action_spec[self.tensor_keys.action]
-                    .shape[len(action_container_shape) :]
-                    .numel()
+        target_entropy = self._buffers.get("_target_entropy", None)
+        if target_entropy is not None:
+            return target_entropy
+        target_entropy = self._target_entropy
+        action_spec = self._action_spec
+        actor_network = self.actor_network
+        device = next(self.parameters()).device
+        if target_entropy == "auto":
+            action_spec = (
+                action_spec
+                if action_spec is not None
+                else getattr(actor_network, "spec", None)
+            )
+            if action_spec is None:
+                raise RuntimeError(
+                    "Cannot infer the dimensionality of the action. Consider providing "
+                    "the target entropy explicitly or provide the spec of the "
+                    "action tensor in the actor network."
                 )
-            self.register_buffer(
-                "target_entropy_buffer", torch.tensor(target_entropy, device=device)
+            if not isinstance(action_spec, CompositeSpec):
+                action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+            if (
+                isinstance(self.tensor_keys.action, tuple)
+                and len(self.tensor_keys.action) > 1
+            ):
+
+                action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
+            else:
+                action_container_shape = action_spec.shape
+            target_entropy = -float(
+                action_spec[self.tensor_keys.action]
+                .shape[len(action_container_shape) :]
+                .numel()
             )
-            return self.target_entropy_buffer
-        return target_entropy
+        delattr(self, "_target_entropy")
+        self.register_buffer(
+            "_target_entropy", torch.tensor(target_entropy, device=device)
+        )
+        return self._target_entropy
+
+    state_dict = _delezify(LossModule.state_dict)
+    load_state_dict = _delezify(LossModule.load_state_dict)
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
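In sum, the rewritten property keeps the constructor argument as a plain `_target_entropy` attribute and, on first read, resolves `"auto"` to minus the number of action dimensions (the usual SAC default), replaces the attribute with a buffer of the same name via `delattr` plus `register_buffer`, and returns the buffer from then on; `target_entropy_buffer` survives only as a read-only alias. A condensed, runnable sketch of that life cycle, using a hypothetical module and an illustrative action dimensionality:

import torch
from torch import nn


class LazyTargetEntropy(nn.Module):
    def __init__(self, target_entropy="auto", action_dim=6):
        super().__init__()
        self._target_entropy = target_entropy  # plain attribute for now
        self._action_dim = action_dim

    @property
    def target_entropy(self):
        te = self._buffers.get("_target_entropy")
        if te is not None:  # already materialized
            return te
        te = self._target_entropy
        if te == "auto":
            te = -float(self._action_dim)  # SAC default: minus the action dimensionality
        delattr(self, "_target_entropy")  # make room for the buffer of the same name
        self.register_buffer("_target_entropy", torch.tensor(te))
        return self._target_entropy


m = LazyTargetEntropy()
print(m.target_entropy)  # tensor(-6.)
print(m.state_dict())    # now contains '_target_entropy'

Because the value lands in the module's buffers, it is serialized by `state_dict()` and follows `.to(device)` like any other state; the `_delezify` wrappers above guarantee the buffer exists before either serialization path runs.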