
Commit 6a3e9f8

Vincent Moens and matteobettini authored
[BugFix] Patch SAC to allow state_dict manipulation before exec (#1607)
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent 37c01cc commit 6a3e9f8
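
Summary: before this patch, SACLoss only registered its target-entropy buffer the first time the target_entropy property was accessed, so calling state_dict() or load_state_dict() on a loss that had never been executed could run into a buffer mismatch. The patch stores the buffer under _target_entropy and wraps state_dict/load_state_dict so the lazy initialization happens first. A minimal sketch of the now-supported workflow, adapted from the test added below in test/test_cost.py (import paths and constructor arguments are illustrative and may differ across TorchRL versions):

import torch
from tensordict.nn import TensorDictModule
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.modules import ProbabilisticActor, TanhDelta, ValueOperator
from torchrl.objectives import SACLoss

model = torch.nn.Linear(3, 4)
actor = ProbabilisticActor(
    module=TensorDictModule(model, in_keys=["obs"], out_keys=["logits"]),
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=TanhDelta,
)
value = ValueOperator(module=model, in_keys=["obs"], out_keys=["value"])

# Serialize a loss that has never been executed ...
loss = SACLoss(
    actor_network=actor,
    qvalue_network=value,
    action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
state = loss.state_dict()

# ... and load it into another, equally fresh instance.
loss2 = SACLoss(
    actor_network=actor,
    qvalue_network=value,
    action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
)
loss2.load_state_dict(state)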

File tree

2 files changed: +97 additions, -39 deletions


test/test_cost.py

Lines changed: 43 additions & 0 deletions
@@ -3260,6 +3260,49 @@ def test_sac_notensordict(
         assert loss_actor == loss_val_td["loss_actor"]
         assert loss_alpha == loss_val_td["loss_alpha"]
 
+    def test_state_dict(self, version):
+        if version == 1:
+            pytest.skip("Test not implemented for version 1.")
+        model = torch.nn.Linear(3, 4)
+        actor_module = TensorDictModule(model, in_keys=["obs"], out_keys=["logits"])
+        policy = ProbabilisticActor(
+            module=actor_module,
+            in_keys=["logits"],
+            out_keys=["action"],
+            distribution_class=TanhDelta,
+        )
+        value = ValueOperator(module=model, in_keys=["obs"], out_keys="value")
+
+        loss = SACLoss(
+            actor_network=policy,
+            qvalue_network=value,
+            action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+        )
+        state = loss.state_dict()
+
+        loss = SACLoss(
+            actor_network=policy,
+            qvalue_network=value,
+            action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+        )
+        loss.load_state_dict(state)
+
+        # with an access in between
+        loss = SACLoss(
+            actor_network=policy,
+            qvalue_network=value,
+            action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+        )
+        loss.target_entropy
+        state = loss.state_dict()
+
+        loss = SACLoss(
+            actor_network=policy,
+            qvalue_network=value,
+            action_spec=UnboundedContinuousTensorSpec(shape=(2,)),
+        )
+        loss.load_state_dict(state)
+
 
 @pytest.mark.skipif(
     not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"

torchrl/objectives/sac.py

Lines changed: 54 additions & 39 deletions
@@ -5,6 +5,7 @@
 import math
 import warnings
 from dataclasses import dataclass
+from functools import wraps
 from numbers import Number
 from typing import Dict, Optional, Tuple, Union
 
@@ -43,6 +44,15 @@
     FUNCTORCH_ERROR = err
 
 
+def _delezify(func):
+    @wraps(func)
+    def new_func(self, *args, **kwargs):
+        self.target_entropy
+        return func(self, *args, **kwargs)
+
+    return new_func
+
+
 class SACLoss(LossModule):
     """TorchRL implementation of the SAC loss.
 
@@ -371,7 +381,6 @@ def __init__(
 
         self._target_entropy = target_entropy
         self._action_spec = action_spec
-        self.target_entropy_buffer = None
         if self._version == 1:
             self.actor_critic = ActorCriticWrapper(
                 self.actor_network, self.value_network
@@ -384,48 +393,54 @@ def __init__(
         if self._version == 1:
             self._vmap_qnetwork00 = vmap(qvalue_network)
 
+    @property
+    def target_entropy_buffer(self):
+        return self.target_entropy
+
     @property
     def target_entropy(self):
-        target_entropy = self.target_entropy_buffer
-        if target_entropy is None:
-            delattr(self, "target_entropy_buffer")
-            target_entropy = self._target_entropy
-            action_spec = self._action_spec
-            actor_network = self.actor_network
-            device = next(self.parameters()).device
-            if target_entropy == "auto":
-                action_spec = (
-                    action_spec
-                    if action_spec is not None
-                    else getattr(actor_network, "spec", None)
-                )
-                if action_spec is None:
-                    raise RuntimeError(
-                        "Cannot infer the dimensionality of the action. Consider providing "
-                        "the target entropy explicitely or provide the spec of the "
-                        "action tensor in the actor network."
-                    )
-                if not isinstance(action_spec, CompositeSpec):
-                    action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
-                if (
-                    isinstance(self.tensor_keys.action, tuple)
-                    and len(self.tensor_keys.action) > 1
-                ):
-                    action_container_shape = action_spec[
-                        self.tensor_keys.action[:-1]
-                    ].shape
-                else:
-                    action_container_shape = action_spec.shape
-                target_entropy = -float(
-                    action_spec[self.tensor_keys.action]
-                    .shape[len(action_container_shape) :]
-                    .numel()
+        target_entropy = self._buffers.get("_target_entropy", None)
+        if target_entropy is not None:
+            return target_entropy
+        target_entropy = self._target_entropy
+        action_spec = self._action_spec
+        actor_network = self.actor_network
+        device = next(self.parameters()).device
+        if target_entropy == "auto":
+            action_spec = (
+                action_spec
+                if action_spec is not None
+                else getattr(actor_network, "spec", None)
+            )
+            if action_spec is None:
+                raise RuntimeError(
+                    "Cannot infer the dimensionality of the action. Consider providing "
+                    "the target entropy explicitely or provide the spec of the "
+                    "action tensor in the actor network."
                 )
-            self.register_buffer(
-                "target_entropy_buffer", torch.tensor(target_entropy, device=device)
+            if not isinstance(action_spec, CompositeSpec):
+                action_spec = CompositeSpec({self.tensor_keys.action: action_spec})
+            if (
+                isinstance(self.tensor_keys.action, tuple)
+                and len(self.tensor_keys.action) > 1
+            ):
+                action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
+            else:
+                action_container_shape = action_spec.shape
+            target_entropy = -float(
+                action_spec[self.tensor_keys.action]
+                .shape[len(action_container_shape) :]
+                .numel()
             )
-            return self.target_entropy_buffer
-        return target_entropy
+        delattr(self, "_target_entropy")
+        self.register_buffer(
+            "_target_entropy", torch.tensor(target_entropy, device=device)
+        )
+        return self._target_entropy
+
+    state_dict = _delezify(LossModule.state_dict)
+    load_state_dict = _delezify(LossModule.load_state_dict)
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
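
The mechanism above, in brief: the buffer is now registered under the name "_target_entropy" the first time the target_entropy property is read, and state_dict/load_state_dict are wrapped by _delezify so that this access always happens before (de)serialization. A self-contained sketch of the same lazy-buffer pattern on a plain nn.Module follows; class, attribute, and decorator names here are illustrative, not TorchRL API:

import torch
from functools import wraps


def _materialize_first(func):
    # Touch the lazily created buffer before serializing/deserializing.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self.scale  # property access registers the buffer if it is missing
        return func(self, *args, **kwargs)

    return wrapper


class LazyScale(torch.nn.Module):
    def __init__(self, scale="auto"):
        super().__init__()
        self._scale_init = scale  # plain attribute until first access

    @property
    def scale(self):
        buf = self._buffers.get("_scale", None)
        if buf is not None:
            return buf
        # Resolve the "auto" default lazily, then promote it to a buffer.
        value = 1.0 if self._scale_init == "auto" else float(self._scale_init)
        self.register_buffer("_scale", torch.tensor(value))
        return self._scale

    state_dict = _materialize_first(torch.nn.Module.state_dict)
    load_state_dict = _materialize_first(torch.nn.Module.load_state_dict)


# Works even though neither module has been run or inspected yet:
src, dst = LazyScale(), LazyScale()
dst.load_state_dict(src.state_dict())

Without the wrappers, src.state_dict() would omit "_scale" while dst might already have registered it (or the other way around), producing missing/unexpected-key errors on load; forcing the property access on both sides keeps the buffer sets consistent.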
