@@ -7908,11 +7908,11 @@ def _create_mock_actor(
         obs_dim=3,
         action_dim=4,
         device="cpu",
-        action_key="action",
+        action_key=None,
         observation_key="observation",
         sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
-        aggregate_probabilities=True,
+        aggregate_probabilities=None,
     ):
         # Actor
         action_spec = Bounded(
@@ -7922,13 +7922,17 @@ def _create_mock_actor(
             action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
+            if action_key is None:
+                action_key = ("action", "action1")
+            else:
+                action_key = (action_key, "action1")
             distribution_class = functools.partial(
                 CompositeDistribution,
                 distribution_map={
                     "action1": TanhNormal,
                 },
                 name_map={
-                    "action1": (action_key, "action1"),
+                    "action1": action_key,
                 },
                 log_prob_key=sample_log_prob_key,
                 aggregate_probabilities=aggregate_probabilities,
@@ -7939,6 +7943,8 @@ def _create_mock_actor(
             ]
             actor_in_keys = ["params"]
         else:
+            if action_key is None:
+                action_key = "action"
             distribution_class = TanhNormal
             module_out_keys = actor_in_keys = ["loc", "scale"]
         module = TensorDictModule(
@@ -8149,8 +8155,8 @@ def _create_seq_mock_data_ppo(
         action_dim=4,
         atoms=None,
         device="cpu",
-        sample_log_prob_key="sample_log_prob",
-        action_key="action",
+        sample_log_prob_key=None,
+        action_key=None,
         composite_action_dist=False,
     ):
         # create a tensordict
@@ -8172,6 +8178,17 @@ def _create_seq_mock_data_ppo(
         params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
+        if sample_log_prob_key is None:
+            if composite_action_dist:
+                sample_log_prob_key = ("action", "action1_log_prob")
+            else:
+                sample_log_prob_key = "sample_log_prob"
+
+        if action_key is None:
+            if composite_action_dist:
+                action_key = ("action", "action1")
+            else:
+                action_key = "action"
         td = TensorDict(
             batch_size=(batch, T),
             source={
@@ -8183,7 +8200,7 @@ def _create_seq_mock_data_ppo(
                     "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 },
                 "collector": {"mask": mask},
-                action_key: {"action1": action} if composite_action_dist else action,
+                action_key: action,
                 sample_log_prob_key: (
                     torch.randn_like(action[..., 1]) / 10
                 ).masked_fill_(~mask, 0.0),
@@ -8263,6 +8280,13 @@ def test_ppo(
             loss_critic_type="l2",
             functional=functional,
         )
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            if advantage is not None:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
         if advantage is not None:
             advantage(td)
         else:
@@ -8356,7 +8380,9 @@ def test_ppo_composite_no_aggregate(
             loss_critic_type="l2",
             functional=functional,
         )
+        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
         if advantage is not None:
+            advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
         else:
             if td_est is not None:
@@ -8464,7 +8490,12 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
         )

         if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
         loss = loss_fn(td)

         loss_critic = loss["loss_critic"]
@@ -8571,7 +8602,14 @@ def test_ppo_shared_seq(
         )

         if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+
         loss = loss_fn(td).exclude("entropy")

         sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
@@ -8659,7 +8697,11 @@ def zero_param(p):
         # assert len(list(floss_fn.parameters())) == 0
         with params.to_module(loss_fn):
             if advantage is not None:
+                if composite_action_dist:
+                    advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
                 advantage(td)
+            if composite_action_dist:
+                loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
             loss = loss_fn(td)

         loss_critic = loss["loss_critic"]
@@ -8760,8 +8802,8 @@ def test_ppo_tensordict_keys_run(
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": "sample_log_prob_test",
-            "action": "action_test",
+            "sample_log_prob": ("action_test", "action1_log_prob") if composite_action_dist else "sample_log_prob_test",
+            "action": ("action_test", "action") if composite_action_dist else "action_test",
         }

         td = self._create_seq_mock_data_ppo(
@@ -8809,6 +8851,8 @@ def test_ppo_tensordict_keys_run(
             raise NotImplementedError

         loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        if composite_action_dist:
+            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
         loss_fn.set_keys(**tensor_keys)
         if advantage is not None:
             # collect tensordict key names for the advantage module
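
Below is a minimal usage sketch of the composite-key wiring these hunks exercise: pointing a PPO loss and a GAE advantage module at a nested action key and its per-component log-prob key. The `actor`, `critic`, and collected tensordict `data` are assumed to be built along the lines of the helpers above; those names, the hyperparameters, and the choice of ClipPPOLoss are illustrative placeholders rather than part of the patch.

    # Sketch only: `actor` is assumed to wrap a CompositeDistribution whose sampled
    # action lands at ("action", "action1") and whose log-prob lands at
    # ("action", "action1_log_prob"); `critic` and `data` are assumed to exist too.
    from torchrl.objectives import ClipPPOLoss
    from torchrl.objectives.value import GAE

    advantage = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
    loss_fn = ClipPPOLoss(actor, critic, loss_critic_type="l2")

    # Redirect both modules from the default "action"/"sample_log_prob" keys to the
    # nested keys, mirroring the set_keys calls added in the tests above.
    loss_fn.set_keys(
        action=("action", "action1"),
        sample_log_prob=[("action", "action1_log_prob")],
    )
    advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])

    advantage(data)            # writes "advantage" and "value_target" into data
    loss_vals = loss_fn(data)  # returns loss_objective / loss_critic / loss_entropy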