
Commit c8a3eeb

Authored by Louis Faury (louisfaury)
[Feature] ClipPPOLoss can handle composite value networks (#3031)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
1 parent 773c366 commit c8a3eeb

File tree

3 files changed: +78 -48 lines changed

    test/test_cost.py
    torchrl/objectives/ppo.py
    torchrl/objectives/utils.py
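At a glance, the feature lets a critic expose several value heads under nested keys, so that state_value is a sub-tensordict rather than a single tensor, and ClipPPOLoss.loss_critic handles it leaf-wise. Below is a minimal sketch of such a composite critic, reusing the module and key names from the new test; they are illustrative, not a prescribed API.

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # Toy critic with two value heads. The nested out_keys make it write a
    # sub-tensordict under "state_value" rather than a single tensor.
    class CompositeValueNetwork(nn.Module):
        def forward(self, state):
            return torch.zeros_like(state), torch.zeros_like(state)

    critic = TensorDictModule(
        CompositeValueNetwork(),
        in_keys=["state"],
        out_keys=[("state_value", "value_0"), ("state_value", "value_1")],
    )

    td = critic(TensorDict({"state": torch.tensor([0.0])}, batch_size=(1,)))
    print(td["state_value", "value_0"])  # tensor([0.])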

test/test_cost.py

Lines changed: 41 additions & 0 deletions
@@ -9865,6 +9865,47 @@ def test_weighted_entropy_mapping_missing_key(self):
         with pytest.raises(KeyError):
             loss._weighted_loss_entropy(entropy)
 
+    def test_critic_loss_tensordict(self):
+        # Creates a dummy actor.
+        actor, _ = self._create_mock_actor_value()
+
+        # Creates a critic that produces a tensordict of values.
+        class CompositeValueNetwork(nn.Module):
+            def forward(self, _) -> tuple[torch.Tensor, torch.Tensor]:
+                return torch.tensor([0.0]), torch.tensor([0.0])
+
+        critic = TensorDictModule(
+            CompositeValueNetwork(),
+            in_keys=["state"],
+            out_keys=[("state_value", "value_0"), ("state_value", "value_1")],
+        )
+
+        # Creates the loss and its input tensordict.
+        loss = ClipPPOLoss(actor, critic, loss_critic_type="l2", clip_value=0.1)
+        td = TensorDict(
+            {
+                "state": torch.tensor([0.0]),
+                "value_target": TensorDict(
+                    {"value_0": torch.tensor([-1.0]), "value_1": torch.tensor([2.0])}
+                ),
+                # Log an existing 'state_value' for the 'clip_fraction'
+                "state_value": TensorDict(
+                    {"value_0": torch.tensor([0.0]), "value_1": torch.tensor([0.0])}
+                ),
+            },
+            batch_size=(1,),
+        )
+
+        critic_loss, clip_fraction, explained_variance = loss.loss_critic(td)
+
+        assert isinstance(critic_loss, TensorDict)
+        assert "value_0" in critic_loss.keys() and "value_1" in critic_loss.keys()
+        torch.testing.assert_close(critic_loss["value_0"], torch.tensor([1.0]))
+        torch.testing.assert_close(critic_loss["value_1"], torch.tensor([4.0]))
+
+        assert isinstance(clip_fraction, TensorDict)
+        assert isinstance(explained_variance, TensorDict)
+
 
 class TestA2C(LossModuleTestBase):
     seed = 0
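The expected values in the assertions follow from the "l2" distance: both heads predict 0.0 while the targets are -1.0 and 2.0, so the unreduced squared errors are 1.0 and 4.0 per head. A quick sanity check of that arithmetic:

    import torch

    pred = torch.tensor([0.0])
    # Per-head squared errors match the assertions in the test above.
    torch.testing.assert_close((pred - torch.tensor([-1.0])) ** 2, torch.tensor([1.0]))
    torch.testing.assert_close((pred - torch.tensor([2.0])) ** 2, torch.tensor([4.0]))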

torchrl/objectives/ppo.py

Lines changed: 11 additions & 12 deletions
@@ -690,7 +690,9 @@ def _log_weight(
 
         return log_weight, dist, kl_approx
 
-    def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
+    def loss_critic(
+        self, tensordict: TensorDictBase
+    ) -> tuple[torch.Tensor | TensorDict, ...]:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
         # TODO: if the advantage is gathered by forward, this introduces an
         # overhead that we could easily reduce.
@@ -709,28 +711,24 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         )
 
         if self.clip_value:
-            old_state_value = tensordict.get(
-                self.tensor_keys.value, None
-            )  # TODO: None soon to be removed
+            old_state_value = tensordict.get(self.tensor_keys.value)
             if old_state_value is None:
                 raise KeyError(
                     f"clip_value is set to {self.clip_value}, but "
                     f"the key {self.tensor_keys.value} was not found in the input tensordict. "
-                    f"Make sure that the value_key passed to PPO exists in the input tensordict."
+                    f"Make sure that the 'value_key' passed to PPO exists in the input tensordict."
                 )
 
         with self.critic_network_params.to_module(
             self.critic_network
         ) if self.functional else contextlib.nullcontext():
             state_value_td = self.critic_network(tensordict)
 
-        state_value = state_value_td.get(
-            self.tensor_keys.value, None
-        )  # TODO: None soon to be removed
+        state_value = state_value_td.get(self.tensor_keys.value)
         if state_value is None:
             raise KeyError(
                 f"the key {self.tensor_keys.value} was not found in the critic output tensordict. "
-                f"Make sure that the value_key passed to PPO is accurate."
+                f"Make sure that the 'value_key' passed to PPO is accurate."
             )
 
         loss_value = distance_loss(
@@ -756,8 +754,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
             tgt = target_return.detach()
             pred = state_value.detach()
             eps = torch.finfo(tgt.dtype).eps
-            resid = torch.var(tgt - pred, unbiased=False, dim=0)
-            total = torch.var(tgt, unbiased=False, dim=0)
+
+            resid = torch.var(tgt - pred, correction=0, dim=0)
+            total = torch.var(tgt, correction=0, dim=0)
             explained_variance = 1.0 - resid / (total + eps)
 
         self._clear_weakrefs(
@@ -954,7 +953,7 @@ class ClipPPOLoss(PPOLoss):
             ``samples_mc_entropy`` will control how many
             samples will be used to compute this estimate.
             Defaults to ``1``.
-        entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
+        entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
             * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
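Two things to note in this file: loss_critic now advertises a tuple return (the new test unpacks the loss, the clip fraction, and the explained variance), and the variance calls move from the legacy unbiased=False spelling to correction=0, which computes the same population variance. A standalone sketch of the explained-variance step, with made-up numbers:

    import torch

    tgt = torch.tensor([-1.0, 2.0, 0.5])    # hypothetical detached value targets
    pred = torch.tensor([0.0, 1.5, 0.25])   # hypothetical detached predictions
    eps = torch.finfo(tgt.dtype).eps
    resid = torch.var(tgt - pred, correction=0, dim=0)
    total = torch.var(tgt, correction=0, dim=0)
    # 1 - Var(residual) / Var(target): 1.0 means the critic explains all variance.
    explained_variance = 1.0 - resid / (total + eps)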

torchrl/objectives/utils.py

Lines changed: 26 additions & 36 deletions
@@ -9,7 +9,7 @@
 import warnings
 from copy import copy
 from enum import Enum
-from typing import Any, Callable, Iterable
+from typing import Any, Callable, Iterable, TypeVar
 
 import torch
 from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
@@ -101,54 +101,44 @@ def decorate_context(*args, **kwargs):
     return decorate_context
 
 
+TensorLike = TypeVar("TensorLike", Tensor, TensorDict)
+
+
 def distance_loss(
-    v1: torch.Tensor,
-    v2: torch.Tensor,
+    v1: TensorLike,
+    v2: TensorLike,
     loss_function: str,
     strict_shape: bool = True,
-) -> torch.Tensor:
+) -> TensorLike:
     """Computes a distance loss between two tensors.
 
     Args:
-        v1 (Tensor): a tensor with a shape compatible with v2
-        v2 (Tensor): a tensor with a shape compatible with v1
+        v1 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v2.
+        v2 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v1.
         loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used.
         strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
             Default is ``True``.
 
     Returns:
-        A tensor of the shape v1.view_as(v2) or v2.view_as(v1) with values equal to the distance loss between the
-        two.
+        A tensor or tensordict of the shape v1.view_as(v2) or v2.view_as(v1)
+        with values equal to the distance loss between the two.
 
     """
     if v1.shape != v2.shape and strict_shape:
         raise RuntimeError(
-            f"The input tensors have shapes {v1.shape} and {v2.shape} which are incompatible."
+            f"The input tensors or tensordicts have shapes {v1.shape} and {v2.shape} which are incompatible."
        )
 
     if loss_function == "l2":
-        value_loss = F.mse_loss(
-            v1,
-            v2,
-            reduction="none",
-        )
+        return F.mse_loss(v1, v2, reduction="none")
 
-    elif loss_function == "l1":
-        value_loss = F.l1_loss(
-            v1,
-            v2,
-            reduction="none",
-        )
+    if loss_function == "l1":
+        return F.l1_loss(v1, v2, reduction="none")
 
-    elif loss_function == "smooth_l1":
-        value_loss = F.smooth_l1_loss(
-            v1,
-            v2,
-            reduction="none",
-        )
-    else:
-        raise NotImplementedError(f"Unknown loss {loss_function}")
-    return value_loss
+    if loss_function == "smooth_l1":
+        return F.smooth_l1_loss(v1, v2, reduction="none")
+
+    raise NotImplementedError(f"Unknown loss {loss_function}.")
 
 
 class TargetNetUpdater:
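With the constrained TypeVar, distance_loss keeps a single code path for plain tensors and tensordicts alike; the early returns rely on F.mse_loss and friends accepting tensordict inputs (presumably via tensordict's torch-function dispatch, which the new test exercises end to end). A sketch reproducing the test's numbers through distance_loss directly:

    import torch
    from tensordict import TensorDict
    from torchrl.objectives.utils import distance_loss

    pred = TensorDict(
        {"value_0": torch.tensor([0.0]), "value_1": torch.tensor([0.0])}, batch_size=(1,)
    )
    target = TensorDict(
        {"value_0": torch.tensor([-1.0]), "value_1": torch.tensor([2.0])}, batch_size=(1,)
    )
    # Applied leaf-wise: one unreduced l2 loss per value head.
    loss = distance_loss(pred, target, loss_function="l2")
    # loss["value_0"] == tensor([1.]), loss["value_1"] == tensor([4.])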
@@ -620,13 +610,13 @@ def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
 
 
 def _clip_value_loss(
-    old_state_value: torch.Tensor,
-    state_value: torch.Tensor,
-    clip_value: torch.Tensor,
-    target_return: torch.Tensor,
-    loss_value: torch.Tensor,
+    old_state_value: torch.Tensor | TensorDict,
+    state_value: torch.Tensor | TensorDict,
+    clip_value: torch.Tensor | TensorDict,
+    target_return: torch.Tensor | TensorDict,
+    loss_value: torch.Tensor | TensorDict,
     loss_critic_type: str,
-):
+) -> tuple[torch.Tensor | TensorDict, torch.Tensor]:
     """Value clipping method for loss computation.
 
     This method computes a clipped state value from the old state value and the state value,
@@ -644,7 +634,7 @@ def _clip_value_loss(
         loss_function=loss_critic_type,
     )
     # Chose the most pessimistic value prediction between clipped and non-clipped
-    loss_value = torch.max(loss_value, loss_value_clipped)
+    loss_value = torch.maximum(loss_value, loss_value_clipped)
     return loss_value, clip_fraction
 
 
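The switch from torch.max to torch.maximum is what makes the pessimistic bound composite-friendly: torch.maximum is the dedicated element-wise binary op, which is the form a tensordict-valued loss_value can be expected to support (an inference from this diff, not verified here), and for plain tensors the semantics are unchanged:

    import torch

    unclipped = torch.tensor([1.0, 4.0])
    clipped = torch.tensor([2.0, 3.0])
    # Element-wise pessimistic choice between the two loss estimates.
    print(torch.maximum(unclipped, clipped))  # tensor([2., 4.])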
