
Commit 6dc021b

Vincent Moens committed

Update

[ghstack-poisoned]

2 parents 399b618 + 1ee71e3 commit 6dc021b

File tree

test/test_cost.py
torchrl/objectives/ppo.py

2 files changed: +36, -20 lines

test/test_cost.py

Lines changed: 35 additions & 18 deletions
@@ -7918,14 +7918,13 @@ def _create_mock_actor(
         action_spec = Bounded(
             -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
         )
-        if composite_action_dist:
-            action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
             if action_key is None:
                 action_key = ("action", "action1")
             else:
                 action_key = (action_key, "action1")
+            action_spec = Composite({action_key: {"action1": action_spec}})
         distribution_class = functools.partial(
             CompositeDistribution,
             distribution_map={
@@ -8380,7 +8379,10 @@ def test_ppo_composite_no_aggregate(
             loss_critic_type="l2",
             functional=functional,
         )
-        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+        loss_fn.set_keys(
+            action=("action", "action1"),
+            sample_log_prob=[("action", "action1_log_prob")],
+        )
         if advantage is not None:
             advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
@@ -8495,7 +8497,10 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             advantage(td)
 
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
 
         loss_critic = loss["loss_critic"]
@@ -8607,8 +8612,14 @@ def test_ppo_shared_seq(
             advantage(td)
 
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
-            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            loss_fn2.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
 
         loss = loss_fn(td).exclude("entropy")
 
@@ -8701,7 +8712,10 @@ def zero_param(p):
             advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
 
         loss_critic = loss["loss_critic"]
@@ -8791,29 +8805,24 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    @pytest.mark.parametrize("composite_action_dist", [True, False])
-    def test_ppo_tensordict_keys_run(
-        self, loss_class, advantage, td_est, composite_action_dist
-    ):
+    def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
         """Test PPO loss module with non-default tensordict keys."""
         torch.manual_seed(self.seed)
         gradient_mode = True
         tensor_keys = {
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test",
-            "action": ("action_test", "action") if composite_action_dist else "action_test",
+            "sample_log_prob": "sample_log_prob_test",
+            "action": "action_test",
         }
 
         td = self._create_seq_mock_data_ppo(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
             action_key=tensor_keys["action"],
-            composite_action_dist=composite_action_dist,
         )
         actor = self._create_mock_actor(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
-            composite_action_dist=composite_action_dist,
             action_key=tensor_keys["action"],
         )
         value = self._create_mock_value(out_keys=[tensor_keys["value"]])
@@ -8851,8 +8860,6 @@ def test_ppo_tensordict_keys_run(
             raise NotImplementedError
 
         loss_fn = loss_class(actor, value, loss_critic_type="l2")
-        if composite_action_dist:
-            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
         loss_fn.set_keys(**tensor_keys)
         if advantage is not None:
             # collect tensordict key names for the advantage module
@@ -9030,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist):
             reduction=reduction,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
                 if key.startswith("loss_"):
-                    assert loss[key].shape == td.shape
+                    assert loss[key].shape == td.shape, key
         else:
             for key in loss.keys():
                 if not key.startswith("loss_"):
@@ -9082,6 +9094,11 @@ def test_ppo_value_clipping(
             clip_value=clip_value,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
 
         value = td.pop(loss_fn.tensor_keys.value)
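
For context on the pattern the updated tests converge on: whenever composite_action_dist is enabled, the loss module is now re-keyed after construction with loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")]), rather than the mock helpers pre-baking composite keys. The snippet below is a minimal, illustrative sketch of that nested-key layout, not a copy of the test helpers; it assumes a recent TorchRL where Bounded and Composite (the spec classes used in this test file) are importable from torchrl.data, and it does not build the actor/critic modules the real tests use.

import torch
from torchrl.data import Bounded, Composite

# Leaf spec for a single continuous sub-action, bounded in [-1, 1].
action_dim = 4
leaf_spec = Bounded(-torch.ones(action_dim), torch.ones(action_dim), (action_dim,))

# Nest the leaf spec under ("action", "action1"), mirroring the composite layout.
action_spec = Composite({"action": {"action1": leaf_spec}})
print(action_spec)

# Keys the tests point the loss at once the distribution is composite:
action_key = ("action", "action1")
sample_log_prob_keys = [("action", "action1_log_prob")]
print(action_key, sample_log_prob_keys)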

torchrl/objectives/ppo.py

Lines changed: 1 addition & 2 deletions
@@ -569,7 +569,6 @@ def _log_weight(
             log_prob = _sum_td_features(log_prob)
         log_prob.view_as(prev_log_prob)
 
-        print(log_prob , prev_log_prob)
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
@@ -946,7 +945,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
 
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         if is_tensor_collection(gain):
             gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
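
A note on the gain line: Tensor.min(dim=...) returns a (values, indices) named tuple, so indexing with [0] and reading .values select the same tensor; the new spelling only states the intent more clearly. A small plain-PyTorch sketch, independent of this repository:

import torch

# Two candidate PPO gains per sample, stacked on a trailing dim.
gain1 = torch.tensor([0.3, -0.2, 1.5])
gain2 = torch.tensor([0.1, 0.4, 1.0])
stacked = torch.stack([gain1, gain2], -1)

# [0] and .values pick the same element of the named tuple returned by min().
assert torch.equal(stacked.min(dim=-1)[0], stacked.min(dim=-1).values)
print(stacked.min(dim=-1).values)  # tensor([ 0.1000, -0.2000,  1.0000])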
