8
8
import itertools
9
9
import operator
10
10
import os
11
-
12
11
import sys
13
12
import warnings
14
13
from copy import deepcopy
23
22
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
24
23
from tensordict._C import unravel_keys
25
24
from tensordict.nn import (
25
+ composite_lp_aggregate,
26
26
CompositeDistribution,
27
27
InteractionType,
28
28
NormalParamExtractor,
29
29
ProbabilisticTensorDictModule,
30
30
ProbabilisticTensorDictModule as ProbMod,
31
31
ProbabilisticTensorDictSequential,
32
32
ProbabilisticTensorDictSequential as ProbSeq,
33
+ set_composite_lp_aggregate,
33
34
TensorDictModule,
34
35
TensorDictModule as Mod,
35
36
TensorDictSequential,
@@ -3540,6 +3541,13 @@ def test_td3bc_reduction(self, reduction):
3540
3541
class TestSAC(LossModuleTestBase):
3541
3542
seed = 0
3542
3543
3544
+ @pytest.fixture(scope="class", autouse=True)
3545
+ def _composite_log_prob(self):
3546
+ setter = set_composite_lp_aggregate(False)
3547
+ setter.set()
3548
+ yield
3549
+ setter.unset()
3550
+
3543
3551
def _create_mock_actor(
3544
3552
self,
3545
3553
batch=2,
@@ -3563,7 +3571,6 @@ def _create_mock_actor(
3563
3571
distribution_map={
3564
3572
"action1": TanhNormal,
3565
3573
},
3566
- aggregate_probabilities=True,
3567
3574
)
3568
3575
module_out_keys = [
3569
3576
("params", "action1", "loc"),
@@ -3583,6 +3590,7 @@ def _create_mock_actor(
3583
3590
out_keys=[action_key],
3584
3591
spec=action_spec,
3585
3592
)
3593
+ assert actor.log_prob_keys
3586
3594
return actor.to(device)
3587
3595
3588
3596
def _create_mock_qvalue(
@@ -3688,7 +3696,6 @@ def forward(self, obs, act):
3688
3696
distribution_map={
3689
3697
"action1": TanhNormal,
3690
3698
},
3691
- aggregate_probabilities=True,
3692
3699
)
3693
3700
module_out_keys = [
3694
3701
("params", "action1", "loc"),
@@ -4342,7 +4349,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
4342
4349
"value": "state_value",
4343
4350
"state_action_value": "state_action_value",
4344
4351
"action": "action",
4345
- "log_prob": "sample_log_prob ",
4352
+ "log_prob": "action_log_prob ",
4346
4353
"reward": "reward",
4347
4354
"done": "done",
4348
4355
"terminated": "terminated",
@@ -4616,6 +4623,13 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
4616
4623
class TestDiscreteSAC(LossModuleTestBase):
4617
4624
seed = 0
4618
4625
4626
+ @pytest.fixture(scope="class", autouse=True)
4627
+ def _composite_log_prob(self):
4628
+ setter = set_composite_lp_aggregate(False)
4629
+ setter.set()
4630
+ yield
4631
+ setter.unset()
4632
+
4619
4633
def _create_mock_actor(
4620
4634
self,
4621
4635
batch=2,
@@ -7902,6 +7916,13 @@ def test_dcql_reduction(self, reduction):
7902
7916
class TestPPO(LossModuleTestBase):
7903
7917
seed = 0
7904
7918
7919
+ @pytest.fixture(scope="class", autouse=True)
7920
+ def _composite_log_prob(self):
7921
+ setter = set_composite_lp_aggregate(False)
7922
+ setter.set()
7923
+ yield
7924
+ setter.unset()
7925
+
7905
7926
def _create_mock_actor(
7906
7927
self,
7907
7928
batch=2,
@@ -7910,9 +7931,8 @@ def _create_mock_actor(
7910
7931
device="cpu",
7911
7932
action_key=None,
7912
7933
observation_key="observation",
7913
- sample_log_prob_key="sample_log_prob" ,
7934
+ sample_log_prob_key=None ,
7914
7935
composite_action_dist=False,
7915
- aggregate_probabilities=None,
7916
7936
):
7917
7937
# Actor
7918
7938
action_spec = Bounded(
@@ -7934,7 +7954,6 @@ def _create_mock_actor(
7934
7954
"action1": action_key,
7935
7955
},
7936
7956
log_prob_key=sample_log_prob_key,
7937
- aggregate_probabilities=aggregate_probabilities,
7938
7957
)
7939
7958
module_out_keys = [
7940
7959
("params", "action1", "loc"),
@@ -8006,7 +8025,6 @@ def _create_mock_actor_value(
8006
8025
"action1": ("action", "action1"),
8007
8026
},
8008
8027
log_prob_key=sample_log_prob_key,
8009
- aggregate_probabilities=True,
8010
8028
)
8011
8029
module_out_keys = [
8012
8030
("params", "action1", "loc"),
@@ -8063,7 +8081,6 @@ def _create_mock_actor_value_shared(
8063
8081
"action1": ("action", "action1"),
8064
8082
},
8065
8083
log_prob_key=sample_log_prob_key,
8066
- aggregate_probabilities=True,
8067
8084
)
8068
8085
module_out_keys = [
8069
8086
("params", "action1", "loc"),
@@ -8181,7 +8198,8 @@ def _create_seq_mock_data_ppo(
8181
8198
if composite_action_dist:
8182
8199
sample_log_prob_key = ("action", "action1_log_prob")
8183
8200
else:
8184
- sample_log_prob_key = "sample_log_prob"
8201
+ # conforming to composite_lp_aggregate(False)
8202
+ sample_log_prob_key = "action_log_prob"
8185
8203
8186
8204
if action_key is None:
8187
8205
if composite_action_dist:
@@ -8287,6 +8305,7 @@ def test_ppo(
8287
8305
if advantage is not None:
8288
8306
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
8289
8307
if advantage is not None:
8308
+ assert not composite_lp_aggregate()
8290
8309
advantage(td)
8291
8310
else:
8292
8311
if td_est is not None:
@@ -8346,7 +8365,6 @@ def test_ppo_composite_no_aggregate(
8346
8365
actor = self._create_mock_actor(
8347
8366
device=device,
8348
8367
composite_action_dist=True,
8349
- aggregate_probabilities=False,
8350
8368
)
8351
8369
value = self._create_mock_value(device=device)
8352
8370
if advantage == "gae":
@@ -8766,6 +8784,7 @@ def zero_param(p):
8766
8784
)
8767
8785
@pytest.mark.parametrize("composite_action_dist", [True, False])
8768
8786
def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
8787
+ assert not composite_lp_aggregate()
8769
8788
actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
8770
8789
value = self._create_mock_value()
8771
8790
@@ -8775,8 +8794,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
8775
8794
"advantage": "advantage",
8776
8795
"value_target": "value_target",
8777
8796
"value": "state_value",
8778
- "sample_log_prob": "sample_log_prob",
8779
- "action": "action",
8797
+ "sample_log_prob": "action_log_prob"
8798
+ if not composite_action_dist
8799
+ else ("action", "action1_log_prob"),
8800
+ "action": "action" if not composite_action_dist else ("action", "action1"),
8780
8801
"reward": "reward",
8781
8802
"done": "done",
8782
8803
"terminated": "terminated",
@@ -9162,9 +9183,6 @@ def mixture_constructor(logits, loc, scale):
9162
9183
"Kumaraswamy": ("agent1", "action"),
9163
9184
"mixture": ("agent2", "action"),
9164
9185
},
9165
- aggregate_probabilities=False,
9166
- include_sum=False,
9167
- inplace=True,
9168
9186
)
9169
9187
policy = ProbSeq(
9170
9188
make_params,
@@ -9183,15 +9201,11 @@ def mixture_constructor(logits, loc, scale):
9183
9201
# We want to make sure there is no warning
9184
9202
td = policy(TensorDict(batch_size=[4]))
9185
9203
assert isinstance(
9186
- policy.get_dist(td).log_prob(
9187
- td, aggregate_probabilities=False, inplace=False, include_sum=False
9188
- ),
9204
+ policy.get_dist(td).log_prob(td),
9189
9205
TensorDict,
9190
9206
)
9191
9207
assert isinstance(
9192
- policy.log_prob(
9193
- td, aggregate_probabilities=False, inplace=False, include_sum=False
9194
- ),
9208
+ policy.log_prob(td),
9195
9209
TensorDict,
9196
9210
)
9197
9211
value_operator = Seq(
@@ -9228,6 +9242,13 @@ def mixture_constructor(logits, loc, scale):
9228
9242
class TestA2C(LossModuleTestBase):
9229
9243
seed = 0
9230
9244
9245
+ @pytest.fixture(scope="class", autouse=True)
9246
+ def _composite_log_prob(self):
9247
+ setter = set_composite_lp_aggregate(False)
9248
+ setter.set()
9249
+ yield
9250
+ setter.unset()
9251
+
9231
9252
def _create_mock_actor(
9232
9253
self,
9233
9254
batch=2,
@@ -9236,8 +9257,8 @@ def _create_mock_actor(
9236
9257
device="cpu",
9237
9258
action_key="action",
9238
9259
observation_key="observation",
9239
- sample_log_prob_key="sample_log_prob",
9240
9260
composite_action_dist=False,
9261
+ sample_log_prob_key=None,
9241
9262
):
9242
9263
# Actor
9243
9264
action_spec = Bounded(
@@ -9255,8 +9276,6 @@ def _create_mock_actor(
9255
9276
name_map={
9256
9277
"action1": (action_key, "action1"),
9257
9278
},
9258
- log_prob_key=sample_log_prob_key,
9259
- aggregate_probabilities=True,
9260
9279
)
9261
9280
module_out_keys = [
9262
9281
("params", "action1", "loc"),
@@ -9306,7 +9325,6 @@ def _create_mock_common_layer_setup(
9306
9325
n_hidden=2,
9307
9326
T=10,
9308
9327
composite_action_dist=False,
9309
- sample_log_prob_key="sample_log_prob",
9310
9328
):
9311
9329
common_net = MLP(
9312
9330
num_cells=ncells,
@@ -9332,7 +9350,7 @@ def _create_mock_common_layer_setup(
9332
9350
{
9333
9351
"obs": torch.randn(*batch, n_obs),
9334
9352
"action": {"action1": action} if composite_action_dist else action,
9335
- "sample_log_prob ": torch.randn(*batch),
9353
+ "action_log_prob ": torch.randn(*batch),
9336
9354
"done": torch.zeros(*batch, 1, dtype=torch.bool),
9337
9355
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
9338
9356
"next": {
@@ -9356,8 +9374,6 @@ def _create_mock_common_layer_setup(
9356
9374
name_map={
9357
9375
"action1": ("action", "action1"),
9358
9376
},
9359
- log_prob_key=sample_log_prob_key,
9360
- aggregate_probabilities=True,
9361
9377
)
9362
9378
module_out_keys = [
9363
9379
("params", "action1", "loc"),
@@ -9398,7 +9414,7 @@ def _create_seq_mock_data_a2c(
9398
9414
reward_key="reward",
9399
9415
done_key="done",
9400
9416
terminated_key="terminated",
9401
- sample_log_prob_key="sample_log_prob ",
9417
+ sample_log_prob_key="action_log_prob ",
9402
9418
composite_action_dist=False,
9403
9419
):
9404
9420
# create a tensordict
@@ -9530,6 +9546,11 @@ def set_requires_grad(tensor, requires_grad):
9530
9546
9531
9547
td = td.exclude(loss_fn.tensor_keys.value_target)
9532
9548
if advantage is not None:
9549
+ advantage.set_keys(
9550
+ sample_log_prob=actor.log_prob_keys
9551
+ if composite_action_dist
9552
+ else "action_log_prob"
9553
+ )
9533
9554
advantage(td)
9534
9555
elif td_est is not None:
9535
9556
loss_fn.make_value_estimator(td_est)
@@ -9749,7 +9770,7 @@ def test_a2c_tensordict_keys(self, td_est, composite_action_dist):
9749
9770
"reward": "reward",
9750
9771
"done": "done",
9751
9772
"terminated": "terminated",
9752
- "sample_log_prob": "sample_log_prob ",
9773
+ "sample_log_prob": "action_log_prob ",
9753
9774
}
9754
9775
9755
9776
self.tensordict_keys_test(
0 commit comments