
Commit 14e639d

Author: Vincent Moens (committed)

Update
[ghstack-poisoned]

1 parent 86ab9b7 · commit 14e639d

File tree: 3 files changed, +86 -25 lines


examples/agents/composite_actor.py (6 additions, 0 deletions)
@@ -50,3 +50,9 @@ def forward(self, x):
 data = TensorDict({"x": torch.rand(10)}, [])
 module(data)
 print(actor(data))
+
+
+# TODO:
+# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
+# 2. Must multi-head require an action_key to be a list of keys (I guess so)
+# 3. Using maps in the Actor
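
For context on TODO item 1, the two candidate layouts differ only in where the leaf actions sit in the output tensordict. A minimal sketch (not part of the commit; the two Normal heads and their parameters are made up for illustration) of how each layout is spelled via CompositeDistribution's name_map:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CompositeDistribution
    from torch.distributions import Normal

    # Hypothetical parameters for two action heads.
    params = TensorDict(
        {
            "action0": {"loc": torch.zeros(3), "scale": torch.ones(3)},
            "action1": {"loc": torch.zeros(3), "scale": torch.ones(3)},
        },
        [],
    )

    # Layout A: both heads nested under a shared "action" root.
    dist_a = CompositeDistribution(
        params,
        distribution_map={"action0": Normal, "action1": Normal},
        name_map={"action0": ("action", "action0"), "action1": ("action", "action1")},
    )

    # Layout B: one "action" leaf per agent sub-tensordict.
    dist_b = CompositeDistribution(
        params,
        distribution_map={"action0": Normal, "action1": Normal},
        name_map={"action0": ("agent0", "action"), "action1": ("agent1", "action")},
    )

Sampling dist_a writes ("action", "action0") and ("action", "action1"); sampling dist_b writes ("agent0", "action") and ("agent1", "action").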

test/test_cost.py (52 additions, 8 deletions)
@@ -7908,11 +7908,11 @@ def _create_mock_actor(
         obs_dim=3,
         action_dim=4,
         device="cpu",
-        action_key="action",
+        action_key=None,
         observation_key="observation",
         sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
-        aggregate_probabilities=True,
+        aggregate_probabilities=None,
     ):
         # Actor
         action_spec = Bounded(
@@ -7922,13 +7922,17 @@ def _create_mock_actor(
             action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
+            if action_key is None:
+                action_key = ("action", "action1")
+            else:
+                action_key = (action_key, "action1")
             distribution_class = functools.partial(
                 CompositeDistribution,
                 distribution_map={
                     "action1": TanhNormal,
                 },
                 name_map={
-                    "action1": (action_key, "action1"),
+                    "action1": action_key,
                 },
                 log_prob_key=sample_log_prob_key,
                 aggregate_probabilities=aggregate_probabilities,
@@ -7939,6 +7943,8 @@ def _create_mock_actor(
             ]
             actor_in_keys = ["params"]
         else:
+            if action_key is None:
+                action_key = "action"
             distribution_class = TanhNormal
             module_out_keys = actor_in_keys = ["loc", "scale"]
         module = TensorDictModule(
@@ -8149,8 +8155,8 @@ def _create_seq_mock_data_ppo(
         action_dim=4,
         atoms=None,
         device="cpu",
-        sample_log_prob_key="sample_log_prob",
-        action_key="action",
+        sample_log_prob_key=None,
+        action_key=None,
         composite_action_dist=False,
     ):
         # create a tensordict
@@ -8172,6 +8178,17 @@ def _create_seq_mock_data_ppo(
         params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
+        if sample_log_prob_key is None:
+            if composite_action_dist:
+                sample_log_prob_key = ("action", "action1_log_prob")
+            else:
+                sample_log_prob_key = "sample_log_prob"
+
+        if action_key is None:
+            if composite_action_dist:
+                action_key = ("action", "action1")
+            else:
+                action_key = "action"
         td = TensorDict(
             batch_size=(batch, T),
             source={
@@ -8183,7 +8200,7 @@ def _create_seq_mock_data_ppo(
                     "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 },
                 "collector": {"mask": mask},
-                action_key: {"action1": action} if composite_action_dist else action,
+                action_key: action,
                 sample_log_prob_key: (
                     torch.randn_like(action[..., 1]) / 10
                 ).masked_fill_(~mask, 0.0),
@@ -8263,6 +8280,13 @@ def test_ppo(
            loss_critic_type="l2",
            functional=functional,
        )
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            if advantage is not None:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
        if advantage is not None:
            advantage(td)
        else:
@@ -8356,7 +8380,9 @@ def test_ppo_composite_no_aggregate(
            loss_critic_type="l2",
            functional=functional,
        )
+        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
        if advantage is not None:
+            advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
            advantage(td)
        else:
            if td_est is not None:
@@ -8464,7 +8490,12 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
        )

        if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
            advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
        loss = loss_fn(td)

        loss_critic = loss["loss_critic"]
@@ -8571,7 +8602,14 @@ def test_ppo_shared_seq(
        )

        if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
            advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+
        loss = loss_fn(td).exclude("entropy")

        sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
@@ -8659,7 +8697,11 @@ def zero_param(p):
        # assert len(list(floss_fn.parameters())) == 0
        with params.to_module(loss_fn):
            if advantage is not None:
+                if composite_action_dist:
+                    advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
                advantage(td)
+            if composite_action_dist:
+                loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
            loss = loss_fn(td)

        loss_critic = loss["loss_critic"]
@@ -8760,8 +8802,8 @@ def test_ppo_tensordict_keys_run(
            "advantage": "advantage_test",
            "value_target": "value_target_test",
            "value": "state_value_test",
-            "sample_log_prob": "sample_log_prob_test",
-            "action": "action_test",
+            "sample_log_prob": ("action_test", "action1_log_prob") if composite_action_dist else "sample_log_prob_test",
+            "action": ("action_test", "action") if composite_action_dist else "action_test",
        }

        td = self._create_seq_mock_data_ppo(
@@ -8809,6 +8851,8 @@ def test_ppo_tensordict_keys_run(
            raise NotImplementedError

        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        if composite_action_dist:
+            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
        loss_fn.set_keys(**tensor_keys)
        if advantage is not None:
            # collect tensordict key names for the advantage module
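
Condensed, the rewiring that the updated tests repeat looks like this (a sketch; loss_fn and advantage are assumed to be constructed as in the surrounding tests):

    # Pattern repeated across the updated tests: point the loss (and, when
    # present, the advantage module) at the nested composite keys.
    composite_keys = {
        "action": ("action", "action1"),
        "sample_log_prob": [("action", "action1_log_prob")],
    }
    if composite_action_dist:
        loss_fn.set_keys(**composite_keys)
        if advantage is not None:
            advantage.set_keys(sample_log_prob=composite_keys["sample_log_prob"])

Note that sample_log_prob is passed as a list of keys, matching the list-wrapping added in test_ppo_tensordict_keys_run.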

torchrl/objectives/ppo.py (28 additions, 17 deletions)
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import contextlib
+import warnings

 from copy import deepcopy
 from dataclasses import dataclass
@@ -531,26 +532,35 @@ def _log_weight(
            raise RuntimeError(
                f"tensordict stored {self.tensor_keys.action} requires grad."
            )
-        if isinstance(action, torch.Tensor):
+        if isinstance(dist, CompositeDistribution):
+            is_composite = True
+            aggregate = dist.aggregate_probabilities
+            if aggregate is None:
+                aggregate = False
+            include_sum = dist.include_sum
+            if include_sum is None:
+                include_sum = False
+            kwargs = {
+                "inplace": False,
+                "aggregate_probabilities": aggregate,
+                "include_sum": include_sum,
+            }
+        else:
+            is_composite = False
+            kwargs = {}
+        if not is_composite:
            log_prob = dist.log_prob(action)
        else:
-            if isinstance(dist, CompositeDistribution):
-                is_composite = True
-                aggregate = dist.aggregate_probabilities
-                if aggregate is None:
-                    aggregate = False
-                include_sum = dist.include_sum
-                if include_sum is None:
-                    include_sum = False
-                kwargs = {
-                    "inplace": False,
-                    "aggregate_probabilities": aggregate,
-                    "include_sum": include_sum,
-                }
-            else:
-                is_composite = False
-                kwargs = {}
            log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
+            if not is_tensor_collection(prev_log_prob):
+                # this isn't great, in general multihead actions should have a composite log-prob too
+                warnings.warn(
+                    "You are using a composite distribution, yet your log-probability is a tensor. "
+                    "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True "
+                    "or include_sum=True. These options should be avoided: leaf log-probs should be written "
+                    "independently and PPO will take care of the aggregation.",
+                    category=UserWarning,
+                )
        if (
            is_composite
            and not is_tensor_collection(prev_log_prob)
@@ -559,6 +569,7 @@ def _log_weight(
            log_prob = _sum_td_features(log_prob)
        log_prob.view_as(prev_log_prob)

+        print(log_prob , prev_log_prob)
        log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
        kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
        if is_tensor_collection(kl_approx):
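
For reference, when the distribution leaves both flags unset (None), the composite branch above resolves to a call equivalent to the following sketch, mirroring the kwargs dict built in _log_weight:

    # Equivalent call for a CompositeDistribution with
    # aggregate_probabilities=None and include_sum=None:
    log_prob = dist.log_prob(
        tensordict,
        inplace=False,                  # return a new tensordict of log-probs
        aggregate_probabilities=False,  # keep one log-prob per leaf action
        include_sum=False,              # no pre-summed aggregate entry
    )

The warning added above fires when prev_log_prob is still a plain tensor, i.e. when the sampling-time distribution aggregated its leaf log-probs instead of writing them independently.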
