@@ -7918,14 +7918,13 @@ def _create_mock_actor(
         action_spec = Bounded(
             -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
         )
-        if composite_action_dist:
-            action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
             if action_key is None:
                 action_key = ("action", "action1")
             else:
                 action_key = (action_key, "action1")
+            action_spec = Composite({action_key: {"action1": action_spec}})
             distribution_class = functools.partial(
                 CompositeDistribution,
                 distribution_map={
@@ -8380,7 +8379,10 @@ def test_ppo_composite_no_aggregate(
             loss_critic_type="l2",
             functional=functional,
         )
-        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+        loss_fn.set_keys(
+            action=("action", "action1"),
+            sample_log_prob=[("action", "action1_log_prob")],
+        )
         if advantage is not None:
             advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
@@ -8495,7 +8497,10 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             advantage(td)

         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)

         loss_critic = loss["loss_critic"]
@@ -8607,8 +8612,14 @@ def test_ppo_shared_seq(
             advantage(td)

         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
-            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            loss_fn2.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )

         loss = loss_fn(td).exclude("entropy")

@@ -8701,7 +8712,10 @@ def zero_param(p):
                 advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)

         loss_critic = loss["loss_critic"]
@@ -8791,29 +8805,24 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    @pytest.mark.parametrize("composite_action_dist", [True, False])
-    def test_ppo_tensordict_keys_run(
-        self, loss_class, advantage, td_est, composite_action_dist
-    ):
+    def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
         """Test PPO loss module with non-default tensordict keys."""
         torch.manual_seed(self.seed)
         gradient_mode = True
         tensor_keys = {
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test",
-            "action": ("action_test", "action") if composite_action_dist else "action_test",
+            "sample_log_prob": "sample_log_prob_test",
+            "action": "action_test",
         }

         td = self._create_seq_mock_data_ppo(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
             action_key=tensor_keys["action"],
-            composite_action_dist=composite_action_dist,
         )
         actor = self._create_mock_actor(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
-            composite_action_dist=composite_action_dist,
             action_key=tensor_keys["action"],
         )
         value = self._create_mock_value(out_keys=[tensor_keys["value"]])
@@ -8851,8 +8860,6 @@ def test_ppo_tensordict_keys_run(
                 raise NotImplementedError

         loss_fn = loss_class(actor, value, loss_critic_type="l2")
-        if composite_action_dist:
-            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
         loss_fn.set_keys(**tensor_keys)
         if advantage is not None:
             # collect tensordict key names for the advantage module
@@ -9030,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist):
             reduction=reduction,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
                 if key.startswith("loss_"):
-                    assert loss[key].shape == td.shape
+                    assert loss[key].shape == td.shape, key
         else:
             for key in loss.keys():
                 if not key.startswith("loss_"):
@@ -9082,6 +9094,11 @@ def test_ppo_value_clipping(
             clip_value=clip_value,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )

         value = td.pop(loss_fn.tensor_keys.value)
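For context on the nested-key convention these hunks rely on, here is a minimal, self-contained sketch. It uses only `tensordict` with illustrative shapes, and it reuses the same `("action", "action1")` and `("action", "action1_log_prob")` names as the tests; it is not part of the diff itself.

```python
import torch
from tensordict import TensorDict

# With a composite action distribution, the sampled action and its log-prob
# live under nested keys inside the "action" sub-tensordict.
td = TensorDict(
    {
        "observation": torch.randn(2, 3),
        "action": TensorDict(
            {
                "action1": torch.randn(2, 4),
                "action1_log_prob": torch.randn(2),
            },
            batch_size=[2],
        ),
    },
    batch_size=[2],
)

# Nested entries are addressed with key tuples, which is why the tests call
# loss_fn.set_keys(action=("action", "action1"),
#                  sample_log_prob=[("action", "action1_log_prob")]).
assert td["action", "action1"].shape == (2, 4)
assert ("action", "action1_log_prob") in td.keys(include_nested=True)
```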