Skip to content

Commit b7a0d11

Browse files
matteobettiniVincent Moens
andauthored
[Feature] multiagent data standardization: PPO advantages (#2677)
Co-authored-by: Vincent Moens <vmoens@meta.com>
1 parent 50011dc commit b7a0d11

File tree

3 files changed

+128
-12
lines changed

3 files changed

+128
-12
lines changed

test/test_cost.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020

2121
from packaging import version, version as pack_version
22-
2322
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
2423
from tensordict._C import unravel_keys
2524
from tensordict.nn import (
@@ -38,6 +37,7 @@
3837
from tensordict.nn.utils import Buffer
3938
from tensordict.utils import unravel_key
4039
from torch import autograd, nn
40+
from torchrl._utils import _standardize
4141
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
4242
from torchrl.data.postprocs.postprocs import MultiStep
4343
from torchrl.envs.model_based.dreamer import DreamerEnv
@@ -15848,6 +15848,18 @@ class _AcceptedKeys:
1584815848

1584915849

1585015850
class TestUtils:
15851+
def test_standardization(self):
15852+
t = torch.arange(3 * 4 * 5 * 6, dtype=torch.float32).view(3, 4, 5, 6)
15853+
std_t0 = _standardize(t, exclude_dims=(1, 3))
15854+
std_t1 = (t - t.mean((0, 2), keepdim=True)) / t.std((0, 2), keepdim=True).clamp(
15855+
1 - 6
15856+
)
15857+
torch.testing.assert_close(std_t0, std_t1)
15858+
std_t = _standardize(t, (), -1, 2)
15859+
torch.testing.assert_close(std_t, (t + 1) / 2)
15860+
std_t = _standardize(t, ())
15861+
torch.testing.assert_close(std_t, (t - t.mean()) / t.std())
15862+
1585115863
@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
1585215864
@pytest.mark.parametrize("T", [1, 10])
1585315865
@pytest.mark.parametrize("device", get_default_devices())

torchrl/_utils.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@
2424
from distutils.util import strtobool
2525
from functools import wraps
2626
from importlib import import_module
27-
from typing import Any, Callable, cast, Dict, TypeVar, Union
27+
from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union
2828

2929
import numpy as np
3030
import torch
3131
from packaging.version import parse
3232
from tensordict import unravel_key
3333

3434
from tensordict.utils import NestedKey
35-
from torch import multiprocessing as mp
35+
from torch import multiprocessing as mp, Tensor
3636

3737
try:
3838
from torch.compiler import is_compiling
@@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
872872
self._mode = type
873873

874874

875+
def _standardize(
876+
input: Tensor,
877+
exclude_dims: Tuple[int] = (),
878+
mean: Tensor | None = None,
879+
std: Tensor | None = None,
880+
eps: float | None = None,
881+
):
882+
"""Standardizes the input tensor with the possibility of excluding specific dims from the statistics.
883+
884+
Useful when processing multi-agent data to keep the agent dimensions independent.
885+
886+
Args:
887+
input (Tensor): the input tensor to be standardized.
888+
exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: ().
889+
mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None.
890+
std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None.
891+
eps (float): epsilon to be used for numerical stability. Default: float32 resolution.
892+
893+
"""
894+
if eps is None:
895+
if input.dtype.is_floating_point:
896+
eps = torch.finfo(torch.float).resolution
897+
else:
898+
eps = 1e-6
899+
900+
len_exclude_dims = len(exclude_dims)
901+
if not len_exclude_dims:
902+
if mean is None:
903+
mean = input.mean()
904+
else:
905+
# Assume dtypes are compatible
906+
mean = torch.as_tensor(mean, device=input.device)
907+
if std is None:
908+
std = input.std()
909+
else:
910+
# Assume dtypes are compatible
911+
std = torch.as_tensor(std, device=input.device)
912+
return (input - mean) / std.clamp_min(eps)
913+
914+
input_shape = input.shape
915+
exclude_dims = [
916+
d if d >= 0 else d + len(input_shape) for d in exclude_dims
917+
] # Make negative dims positive
918+
919+
if len(set(exclude_dims)) != len_exclude_dims:
920+
raise ValueError("Exclude dims has repeating elements")
921+
if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
922+
raise ValueError(
923+
f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
924+
)
925+
if len_exclude_dims == len(input_shape):
926+
warnings.warn(
927+
"_standardize called but all dims were excluded from the statistics, returning unprocessed input"
928+
)
929+
return input
930+
931+
included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
932+
if mean is None:
933+
mean = torch.mean(input, keepdim=True, dim=included_dims)
934+
if std is None:
935+
std = torch.std(input, keepdim=True, dim=included_dims)
936+
return (input - mean) / std.clamp_min(eps)
937+
938+
875939
@wraps(torch.compile)
876940
def compile_with_warmup(*args, warmup: int = 1, **kwargs):
877941
"""Compile a model with warm-up.

torchrl/objectives/ppo.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tensordict.utils import NestedKey
2828
from torch import distributions as d
2929

30+
from torchrl._utils import _standardize
3031
from torchrl.objectives.common import LossModule
3132

3233
from torchrl.objectives.utils import (
@@ -46,6 +47,7 @@
4647
TDLambdaEstimator,
4748
VTrace,
4849
)
50+
from yaml import warnings
4951

5052

5153
class PPOLoss(LossModule):
@@ -87,6 +89,9 @@ class PPOLoss(LossModule):
8789
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
8890
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
8991
before being used. Defaults to ``False``.
92+
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
93+
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
94+
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
9095
separate_losses (bool, optional): if ``True``, shared parameters between
9196
policy and critic will only be trained on the policy loss.
9297
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -311,6 +316,7 @@ def __init__(
311316
critic_coef: float = 1.0,
312317
loss_critic_type: str = "smooth_l1",
313318
normalize_advantage: bool = False,
319+
normalize_advantage_exclude_dims: Tuple[int] = (),
314320
gamma: float = None,
315321
separate_losses: bool = False,
316322
advantage_key: str = None,
@@ -381,6 +387,8 @@ def __init__(
381387
self.critic_coef = None
382388
self.loss_critic_type = loss_critic_type
383389
self.normalize_advantage = normalize_advantage
390+
self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
391+
384392
if gamma is not None:
385393
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
386394
self._set_deprecated_ctor_keys(
@@ -606,9 +614,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
606614
)
607615
advantage = tensordict.get(self.tensor_keys.advantage)
608616
if self.normalize_advantage and advantage.numel() > 1:
609-
loc = advantage.mean()
610-
scale = advantage.std().clamp_min(1e-6)
611-
advantage = (advantage - loc) / scale
617+
if advantage.numel() > tensordict.batch_size.numel() and not len(
618+
self.normalize_advantage_exclude_dims
619+
):
620+
warnings.warn(
621+
"You requested advantage normalization and the advantage key has more dimensions"
622+
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
623+
"if you want to keep any dimension independent while computing normalization statistics. "
624+
"If you are working in multi-agent/multi-objective settings this is highly suggested."
625+
)
626+
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
612627

613628
log_weight, dist, kl_approx = self._log_weight(tensordict)
614629
if is_tensor_collection(log_weight):
@@ -711,6 +726,9 @@ class ClipPPOLoss(PPOLoss):
711726
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
712727
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
713728
before being used. Defaults to ``False``.
729+
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
730+
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
731+
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
714732
separate_losses (bool, optional): if ``True``, shared parameters between
715733
policy and critic will only be trained on the policy loss.
716734
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -802,6 +820,7 @@ def __init__(
802820
critic_coef: float = 1.0,
803821
loss_critic_type: str = "smooth_l1",
804822
normalize_advantage: bool = False,
823+
normalize_advantage_exclude_dims: Tuple[int] = (),
805824
gamma: float = None,
806825
separate_losses: bool = False,
807826
reduction: str = None,
@@ -821,6 +840,7 @@ def __init__(
821840
critic_coef=critic_coef,
822841
loss_critic_type=loss_critic_type,
823842
normalize_advantage=normalize_advantage,
843+
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
824844
gamma=gamma,
825845
separate_losses=separate_losses,
826846
reduction=reduction,
@@ -871,9 +891,16 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
871891
)
872892
advantage = tensordict.get(self.tensor_keys.advantage)
873893
if self.normalize_advantage and advantage.numel() > 1:
874-
loc = advantage.mean()
875-
scale = advantage.std().clamp_min(1e-6)
876-
advantage = (advantage - loc) / scale
894+
if advantage.numel() > tensordict.batch_size.numel() and not len(
895+
self.normalize_advantage_exclude_dims
896+
):
897+
warnings.warn(
898+
"You requested advantage normalization and the advantage key has more dimensions"
899+
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
900+
"if you want to keep any dimension independent while computing normalization statistics. "
901+
"If you are working in multi-agent/multi-objective settings this is highly suggested."
902+
)
903+
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
877904

878905
log_weight, dist, kl_approx = self._log_weight(tensordict)
879906
# ESS for logging
@@ -955,6 +982,9 @@ class KLPENPPOLoss(PPOLoss):
955982
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
956983
normalize_advantage (bool, optional): if ``True``, the advantage will be normalized
957984
before being used. Defaults to ``False``.
985+
normalize_advantage_exclude_dims (Tuple[int], optional): dimensions to exclude from the advantage
986+
standardization. Negative dimensions are valid. This is useful in multiagent (or multiobjective) settings
987+
where the agent (or objective) dimension may be excluded from the reductions. Default: ().
958988
separate_losses (bool, optional): if ``True``, shared parameters between
959989
policy and critic will only be trained on the policy loss.
960990
Defaults to ``False``, i.e., gradients are propagated to shared
@@ -1048,6 +1078,7 @@ def __init__(
10481078
critic_coef: float = 1.0,
10491079
loss_critic_type: str = "smooth_l1",
10501080
normalize_advantage: bool = False,
1081+
normalize_advantage_exclude_dims: Tuple[int] = (),
10511082
gamma: float = None,
10521083
separate_losses: bool = False,
10531084
reduction: str = None,
@@ -1063,6 +1094,7 @@ def __init__(
10631094
critic_coef=critic_coef,
10641095
loss_critic_type=loss_critic_type,
10651096
normalize_advantage=normalize_advantage,
1097+
normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
10661098
gamma=gamma,
10671099
separate_losses=separate_losses,
10681100
reduction=reduction,
@@ -1151,9 +1183,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
11511183
)
11521184
advantage = tensordict_copy.get(self.tensor_keys.advantage)
11531185
if self.normalize_advantage and advantage.numel() > 1:
1154-
loc = advantage.mean()
1155-
scale = advantage.std().clamp_min(1e-6)
1156-
advantage = (advantage - loc) / scale
1186+
if advantage.numel() > tensordict.batch_size.numel() and not len(
1187+
self.normalize_advantage_exclude_dims
1188+
):
1189+
warnings.warn(
1190+
"You requested advantage normalization and the advantage key has more dimensions"
1191+
" than the tensordict batch. Make sure to pass `normalize_advantage_exclude_dims` "
1192+
"if you want to keep any dimension independent while computing normalization statistics. "
1193+
"If you are working in multi-agent/multi-objective settings this is highly suggested."
1194+
)
1195+
advantage = _standardize(advantage, self.normalize_advantage_exclude_dims)
1196+
11571197
log_weight, dist, kl_approx = self._log_weight(tensordict_copy)
11581198
neg_loss = log_weight.exp() * advantage
11591199

0 commit comments

Comments
 (0)