 import torch
 
 from packaging import version, version as pack_version
-
 from tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from tensordict._C import unravel_keys
 from tensordict.nn import (
@@ -37,6 +36,7 @@
     TensorDictSequential as Seq,
     WrapModule,
 )
+from tensordict.nn.distributions.composite import _add_suffix
 from tensordict.nn.utils import Buffer
 from tensordict.utils import unravel_key
 from torch import autograd, nn
@@ -199,6 +199,13 @@ def get_devices():
 
 
 class LossModuleTestBase:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
         assert hasattr(
@@ -3541,13 +3548,6 @@ def test_td3bc_reduction(self, reduction):
 class TestSAC(LossModuleTestBase):
     seed = 0
 
-    @pytest.fixture(scope="class", autouse=True)
-    def _composite_log_prob(self):
-        setter = set_composite_lp_aggregate(False)
-        setter.set()
-        yield
-        setter.unset()
-
     def _create_mock_actor(
         self,
         batch=2,
@@ -4623,13 +4623,6 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
 class TestDiscreteSAC(LossModuleTestBase):
     seed = 0
 
-    @pytest.fixture(scope="class", autouse=True)
-    def _composite_log_prob(self):
-        setter = set_composite_lp_aggregate(False)
-        setter.set()
-        yield
-        setter.unset()
-
     def _create_mock_actor(
         self,
         batch=2,
@@ -6786,7 +6779,7 @@ def test_redq_tensordict_keys(self, td_est):
             "priority": "td_error",
             "action": "action",
             "value": "state_value",
-            "sample_log_prob": "sample_log_prob",
+            "sample_log_prob": "action_log_prob",
             "state_action_value": "state_action_value",
             "reward": "reward",
             "done": "done",
@@ -6849,12 +6842,22 @@ def test_redq_notensordict(
             actor_network=actor,
             qvalue_network=qvalue,
         )
-        loss.set_keys(
-            action=action_key,
-            reward=reward_key,
-            done=done_key,
-            terminated=terminated_key,
-        )
+        if deprec:
+            loss.set_keys(
+                action=action_key,
+                reward=reward_key,
+                done=done_key,
+                terminated=terminated_key,
+                log_prob=_add_suffix(action_key, "_log_prob"),
+            )
+        else:
+            loss.set_keys(
+                action=action_key,
+                reward=reward_key,
+                done=done_key,
+                terminated=terminated_key,
+                sample_log_prob=_add_suffix(action_key, "_log_prob"),
+            )
 
         kwargs = {
             action_key: td.get(action_key),
@@ -7916,13 +7919,6 @@ def test_dcql_reduction(self, reduction):
 class TestPPO(LossModuleTestBase):
     seed = 0
 
-    @pytest.fixture(scope="class", autouse=True)
-    def _composite_log_prob(self):
-        setter = set_composite_lp_aggregate(False)
-        setter.set()
-        yield
-        setter.unset()
-
     def _create_mock_actor(
         self,
         batch=2,
@@ -8003,7 +7999,7 @@ def _create_mock_actor_value(
         action_dim=4,
         device="cpu",
         composite_action_dist=False,
-        sample_log_prob_key="sample_log_prob",
+        sample_log_prob_key="action_log_prob",
     ):
         # Actor
         action_spec = Bounded(
@@ -8058,7 +8054,7 @@ def _create_mock_actor_value_shared(
         action_dim=4,
         device="cpu",
         composite_action_dist=False,
-        sample_log_prob_key="sample_log_prob",
+        sample_log_prob_key="action_log_prob",
     ):
         # Actor
         action_spec = Bounded(
@@ -8123,7 +8119,7 @@ def _create_mock_data_ppo(
         reward_key="reward",
         done_key="done",
         terminated_key="terminated",
-        sample_log_prob_key="sample_log_prob",
+        sample_log_prob_key="action_log_prob",
         composite_action_dist=False,
     ):
         # create a tensordict
@@ -8834,7 +8830,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": "sample_log_prob_test",
+            "sample_log_prob": "action_log_prob_test",
             "action": "action_test",
         }
 
@@ -9242,13 +9238,6 @@ def mixture_constructor(logits, loc, scale):
 class TestA2C(LossModuleTestBase):
     seed = 0
 
-    @pytest.fixture(scope="class", autouse=True)
-    def _composite_log_prob(self):
-        setter = set_composite_lp_aggregate(False)
-        setter.set()
-        yield
-        setter.unset()
-
     def _create_mock_actor(
         self,
         batch=2,
@@ -9814,7 +9803,7 @@ def test_a2c_tensordict_keys_run(
         value_key = "state_value_test"
         action_key = "action_test"
         reward_key = "reward_test"
-        sample_log_prob_key = "sample_log_prob_test"
+        sample_log_prob_key = "action_log_prob_test"
         done_key = ("done", "test")
         terminated_key = ("terminated", "test")
 
@@ -10258,7 +10247,7 @@ def test_reinforce_tensordict_keys(self, td_est):
             "advantage": "advantage",
             "value_target": "value_target",
             "value": "state_value",
-            "sample_log_prob": "sample_log_prob",
+            "sample_log_prob": "action_log_prob",
             "reward": "reward",
             "done": "done",
             "terminated": "terminated",
@@ -10316,7 +10305,7 @@ def _create_mock_common_layer_setup(
             {
                 "obs": torch.randn(*batch, n_obs),
                 "action": torch.randn(*batch, n_act),
-                "sample_log_prob": torch.randn(*batch),
+                "action_log_prob": torch.randn(*batch),
                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
                 "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
                 "next": {
@@ -11788,7 +11777,7 @@ def _create_mock_common_layer_setup(
             {
                 "obs": torch.randn(*batch, n_obs),
                 "action": torch.randn(*batch, n_act),
-                "sample_log_prob": torch.randn(*batch),
+                "action_log_prob": torch.randn(*batch),
                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
                 "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
                 "next": {
@@ -12604,7 +12593,7 @@ def _create_mock_common_layer_setup(
             {
                 "obs": torch.randn(*batch, n_obs),
                 "action": torch.randn(*batch, n_act),
-                "sample_log_prob": torch.randn(*batch),
+                "action_log_prob": torch.randn(*batch),
                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
                 "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
                 "next": {
@@ -15228,6 +15217,7 @@ def test_successive_traj_gae(
         ["half", torch.half, "cpu"],
     ],
 )
+@set_composite_lp_aggregate(False)
 def test_shared_params(dest, expected_dtype, expected_device):
     if torch.cuda.device_count() == 0 and dest == "cuda":
         pytest.skip("no cuda device available")
@@ -15332,6 +15322,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
 
 
 class TestAdv:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     @pytest.mark.parametrize(
         "adv,kwargs",
         [
@@ -15369,7 +15366,7 @@ def test_dispatch(
         )
         kwargs = {
             "obs": torch.randn(1, 10, 3),
-            "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+            "action_log_prob": torch.log(torch.rand(1, 10, 1)),
             "next_reward": torch.randn(1, 10, 1, requires_grad=True),
             "next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
             "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
@@ -15431,7 +15428,7 @@ def test_diff_reward(
         td = TensorDict(
             {
                 "obs": torch.randn(1, 10, 3),
-                "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+                "action_log_prob": torch.log(torch.rand(1, 10, 1)),
                 "next": {
                     "obs": torch.randn(1, 10, 3),
                     "reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15504,7 +15501,7 @@ def test_non_differentiable(self, adv, shifted, kwargs):
         td = TensorDict(
             {
                 "obs": torch.randn(1, 10, 3),
-                "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+                "action_log_prob": torch.log(torch.rand(1, 10, 1)),
                 "next": {
                     "obs": torch.randn(1, 10, 3),
                     "reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15575,7 +15572,7 @@ def test_time_dim(self, adv, kwargs, shifted=True):
         td = TensorDict(
             {
                 "obs": torch.randn(1, 10, 3),
-                "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+                "action_log_prob": torch.log(torch.rand(1, 10, 1)),
                 "next": {
                     "obs": torch.randn(1, 10, 3),
                     "reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15676,7 +15673,7 @@ def test_skip_existing(
         td = TensorDict(
             {
                 "obs": torch.randn(1, 10, 3),
-                "sample_log_prob": torch.log(torch.rand(1, 10, 1)),
+                "action_log_prob": torch.log(torch.rand(1, 10, 1)),
                 "state_value": torch.ones(1, 10, 1),
                 "next": {
                     "obs": torch.randn(1, 10, 3),
@@ -15814,6 +15811,13 @@ def test_set_deprecated_keys(self, adv, kwargs):
 
 
 class TestBase:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def test_decorators(self):
         class MyLoss(LossModule):
             def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -16033,6 +16037,13 @@ class _AcceptedKeys:
 
 
 class TestUtils:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )])  # fmt: skip
     @pytest.mark.parametrize("T", [1, 10])
     @pytest.mark.parametrize("device", get_default_devices())
@@ -16203,6 +16214,7 @@ def fun(a, b, time_dim=-2):
         (SoftUpdate, {"eps": 0.99}),
     ],
 )
+@set_composite_lp_aggregate(False)
 def test_updater_warning(updater, kwarg):
     with warnings.catch_warnings():
         dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot")
@@ -16215,6 +16227,13 @@ def test_updater_warning(updater, kwarg):
 
 
 class TestSingleCall:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def _mock_value_net(self, has_target, value_key):
         model = nn.Linear(3, 1)
         module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key])
@@ -16267,6 +16286,7 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True)
         assert (value != value_).all()
 
 
+@set_composite_lp_aggregate(False)
 def test_instantiate_with_different_keys():
     loss_1 = DQNLoss(
         value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True
@@ -16281,6 +16301,13 @@ def test_instantiate_with_different_keys():
 
 
 class TestBuffer:
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half))
     # def test_param_cast(self, dtype):
     #     param = nn.Parameter(torch.zeros(3))
@@ -16390,6 +16417,7 @@ def __init__(self):
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 @pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
+@set_composite_lp_aggregate(False)
 def test_exploration_compile():
     try:
         torch._dynamo.reset_code_caches()
@@ -16456,6 +16484,7 @@ def func(t):
     assert it == exploration_type()
 
 
+@set_composite_lp_aggregate(False)
 def test_loss_exploration():
     class DummyLoss(LossModule):
         def forward(self, td, mode):
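
Note on the pattern used throughout this diff: `set_composite_lp_aggregate(False)` is applied in two forms — as a class-scoped, autouse pytest fixture (via `.set()`/`.unset()`) for whole test classes, and as a decorator around standalone test functions. Below is a minimal sketch of both forms for reference, assuming the tensordict API relied on in the diff; the class and function names (`TestExample`, `test_standalone`) are illustrative only, not part of the changeset.

import pytest
from tensordict.nn import set_composite_lp_aggregate


class TestExample:
    @pytest.fixture(scope="class", autouse=True)
    def _composite_log_prob(self):
        # Disable composite log-prob aggregation for every test in this class,
        # then restore the previous setting once the class is done.
        setter = set_composite_lp_aggregate(False)
        setter.set()
        yield
        setter.unset()

    def test_something(self):
        ...


@set_composite_lp_aggregate(False)  # decorator form: scoped to a single test function
def test_standalone():
    ...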