Skip to content

Commit 1d88f7a

Browse files
Author: Vincent Moens (committed)
Commit message: Update
[ghstack-poisoned]
2 parents: 09ff2c8 + d5d49da; commit: 1d88f7a

File tree

5 files changed: +145 / -99 lines changed

.github/unittest/linux_sota/scripts/test_sota.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -188,19 +188,6 @@
188188
ppo.collector.frames_per_batch=16 \
189189
logger.mode=offline \
190190
logger.backend=
191-
""",
192-
"dreamer": """python sota-implementations/dreamer/dreamer.py \
193-
collector.total_frames=600 \
194-
collector.init_random_frames=10 \
195-
collector.frames_per_batch=200 \
196-
env.n_parallel_envs=1 \
197-
optimization.optim_steps_per_batch=1 \
198-
logger.video=False \
199-
logger.backend=csv \
200-
replay_buffer.buffer_size=120 \
201-
replay_buffer.batch_size=24 \
202-
replay_buffer.batch_length=12 \
203-
networks.rssm_hidden_dim=17
204191
""",
205192
"ddpg-single": """python sota-implementations/ddpg/ddpg.py \
206193
collector.total_frames=48 \
@@ -289,6 +276,19 @@
289276
logger.backend=
290277
""",
291278
"bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
279+
""",
280+
"dreamer": """python sota-implementations/dreamer/dreamer.py \
281+
collector.total_frames=600 \
282+
collector.init_random_frames=10 \
283+
collector.frames_per_batch=200 \
284+
env.n_parallel_envs=1 \
285+
optimization.optim_steps_per_batch=1 \
286+
logger.video=False \
287+
logger.backend=csv \
288+
replay_buffer.buffer_size=120 \
289+
replay_buffer.batch_size=24 \
290+
replay_buffer.batch_length=12 \
291+
networks.rssm_hidden_dim=17
292292
""",
293293
}
294294

test/test_cost.py

Lines changed: 79 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
import torch
1919

2020
from packaging import version, version as pack_version
21-
2221
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
2322
from tensordict._C import unravel_keys
2423
from tensordict.nn import (
@@ -37,6 +36,7 @@
3736
TensorDictSequential as Seq,
3837
WrapModule,
3938
)
39+
from tensordict.nn.distributions.composite import _add_suffix
4040
from tensordict.nn.utils import Buffer
4141
from tensordict.utils import unravel_key
4242
from torch import autograd, nn
@@ -199,6 +199,13 @@ def get_devices():
199199

200200

201201
class LossModuleTestBase:
202+
@pytest.fixture(scope="class", autouse=True)
203+
def _composite_log_prob(self):
204+
setter = set_composite_lp_aggregate(False)
205+
setter.set()
206+
yield
207+
setter.unset()
208+
202209
def __init_subclass__(cls, **kwargs):
203210
super().__init_subclass__(**kwargs)
204211
assert hasattr(
@@ -3541,13 +3548,6 @@ def test_td3bc_reduction(self, reduction):
35413548
class TestSAC(LossModuleTestBase):
35423549
seed = 0
35433550

3544-
@pytest.fixture(scope="class", autouse=True)
3545-
def _composite_log_prob(self):
3546-
setter = set_composite_lp_aggregate(False)
3547-
setter.set()
3548-
yield
3549-
setter.unset()
3550-
35513551
def _create_mock_actor(
35523552
self,
35533553
batch=2,
@@ -4623,13 +4623,6 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
46234623
class TestDiscreteSAC(LossModuleTestBase):
46244624
seed = 0
46254625

4626-
@pytest.fixture(scope="class", autouse=True)
4627-
def _composite_log_prob(self):
4628-
setter = set_composite_lp_aggregate(False)
4629-
setter.set()
4630-
yield
4631-
setter.unset()
4632-
46334626
def _create_mock_actor(
46344627
self,
46354628
batch=2,
@@ -6786,7 +6779,7 @@ def test_redq_tensordict_keys(self, td_est):
67866779
"priority": "td_error",
67876780
"action": "action",
67886781
"value": "state_value",
6789-
"sample_log_prob": "sample_log_prob",
6782+
"sample_log_prob": "action_log_prob",
67906783
"state_action_value": "state_action_value",
67916784
"reward": "reward",
67926785
"done": "done",
@@ -6849,12 +6842,22 @@ def test_redq_notensordict(
68496842
actor_network=actor,
68506843
qvalue_network=qvalue,
68516844
)
6852-
loss.set_keys(
6853-
action=action_key,
6854-
reward=reward_key,
6855-
done=done_key,
6856-
terminated=terminated_key,
6857-
)
6845+
if deprec:
6846+
loss.set_keys(
6847+
action=action_key,
6848+
reward=reward_key,
6849+
done=done_key,
6850+
terminated=terminated_key,
6851+
log_prob=_add_suffix(action_key, "_log_prob"),
6852+
)
6853+
else:
6854+
loss.set_keys(
6855+
action=action_key,
6856+
reward=reward_key,
6857+
done=done_key,
6858+
terminated=terminated_key,
6859+
sample_log_prob=_add_suffix(action_key, "_log_prob"),
6860+
)
68586861

68596862
kwargs = {
68606863
action_key: td.get(action_key),
@@ -7916,13 +7919,6 @@ def test_dcql_reduction(self, reduction):
79167919
class TestPPO(LossModuleTestBase):
79177920
seed = 0
79187921

7919-
@pytest.fixture(scope="class", autouse=True)
7920-
def _composite_log_prob(self):
7921-
setter = set_composite_lp_aggregate(False)
7922-
setter.set()
7923-
yield
7924-
setter.unset()
7925-
79267922
def _create_mock_actor(
79277923
self,
79287924
batch=2,
@@ -8003,7 +7999,7 @@ def _create_mock_actor_value(
80037999
action_dim=4,
80048000
device="cpu",
80058001
composite_action_dist=False,
8006-
sample_log_prob_key="sample_log_prob",
8002+
sample_log_prob_key="action_log_prob",
80078003
):
80088004
# Actor
80098005
action_spec = Bounded(
@@ -8058,7 +8054,7 @@ def _create_mock_actor_value_shared(
80588054
action_dim=4,
80598055
device="cpu",
80608056
composite_action_dist=False,
8061-
sample_log_prob_key="sample_log_prob",
8057+
sample_log_prob_key="action_log_prob",
80628058
):
80638059
# Actor
80648060
action_spec = Bounded(
@@ -8123,7 +8119,7 @@ def _create_mock_data_ppo(
81238119
reward_key="reward",
81248120
done_key="done",
81258121
terminated_key="terminated",
8126-
sample_log_prob_key="sample_log_prob",
8122+
sample_log_prob_key="action_log_prob",
81278123
composite_action_dist=False,
81288124
):
81298125
# create a tensordict
@@ -8834,7 +8830,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
88348830
"advantage": "advantage_test",
88358831
"value_target": "value_target_test",
88368832
"value": "state_value_test",
8837-
"sample_log_prob": "sample_log_prob_test",
8833+
"sample_log_prob": "action_log_prob_test",
88388834
"action": "action_test",
88398835
}
88408836

@@ -9242,13 +9238,6 @@ def mixture_constructor(logits, loc, scale):
92429238
class TestA2C(LossModuleTestBase):
92439239
seed = 0
92449240

9245-
@pytest.fixture(scope="class", autouse=True)
9246-
def _composite_log_prob(self):
9247-
setter = set_composite_lp_aggregate(False)
9248-
setter.set()
9249-
yield
9250-
setter.unset()
9251-
92529241
def _create_mock_actor(
92539242
self,
92549243
batch=2,
@@ -9814,7 +9803,7 @@ def test_a2c_tensordict_keys_run(
98149803
value_key = "state_value_test"
98159804
action_key = "action_test"
98169805
reward_key = "reward_test"
9817-
sample_log_prob_key = "sample_log_prob_test"
9806+
sample_log_prob_key = "action_log_prob_test"
98189807
done_key = ("done", "test")
98199808
terminated_key = ("terminated", "test")
98209809

@@ -10258,7 +10247,7 @@ def test_reinforce_tensordict_keys(self, td_est):
1025810247
"advantage": "advantage",
1025910248
"value_target": "value_target",
1026010249
"value": "state_value",
10261-
"sample_log_prob": "sample_log_prob",
10250+
"sample_log_prob": "action_log_prob",
1026210251
"reward": "reward",
1026310252
"done": "done",
1026410253
"terminated": "terminated",
@@ -10316,7 +10305,7 @@ def _create_mock_common_layer_setup(
1031610305
{
1031710306
"obs": torch.randn(*batch, n_obs),
1031810307
"action": torch.randn(*batch, n_act),
10319-
"sample_log_prob": torch.randn(*batch),
10308+
"action_log_prob": torch.randn(*batch),
1032010309
"done": torch.zeros(*batch, 1, dtype=torch.bool),
1032110310
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
1032210311
"next": {
@@ -11788,7 +11777,7 @@ def _create_mock_common_layer_setup(
1178811777
{
1178911778
"obs": torch.randn(*batch, n_obs),
1179011779
"action": torch.randn(*batch, n_act),
11791-
"sample_log_prob": torch.randn(*batch),
11780+
"action_log_prob": torch.randn(*batch),
1179211781
"done": torch.zeros(*batch, 1, dtype=torch.bool),
1179311782
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
1179411783
"next": {
@@ -12604,7 +12593,7 @@ def _create_mock_common_layer_setup(
1260412593
{
1260512594
"obs": torch.randn(*batch, n_obs),
1260612595
"action": torch.randn(*batch, n_act),
12607-
"sample_log_prob": torch.randn(*batch),
12596+
"action_log_prob": torch.randn(*batch),
1260812597
"done": torch.zeros(*batch, 1, dtype=torch.bool),
1260912598
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
1261012599
"next": {
@@ -15228,6 +15217,7 @@ def test_successive_traj_gae(
1522815217
["half", torch.half, "cpu"],
1522915218
],
1523015219
)
15220+
@set_composite_lp_aggregate(False)
1523115221
def test_shared_params(dest, expected_dtype, expected_device):
1523215222
if torch.cuda.device_count() == 0 and dest == "cuda":
1523315223
pytest.skip("no cuda device available")
@@ -15332,6 +15322,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
1533215322

1533315323

1533415324
class TestAdv:
15325+
@pytest.fixture(scope="class", autouse=True)
15326+
def _composite_log_prob(self):
15327+
setter = set_composite_lp_aggregate(False)
15328+
setter.set()
15329+
yield
15330+
setter.unset()
15331+
1533515332
@pytest.mark.parametrize(
1533615333
"adv,kwargs",
1533715334
[
@@ -15369,7 +15366,7 @@ def test_dispatch(
1536915366
)
1537015367
kwargs = {
1537115368
"obs": torch.randn(1, 10, 3),
15372-
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
15369+
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
1537315370
"next_reward": torch.randn(1, 10, 1, requires_grad=True),
1537415371
"next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
1537515372
"next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
@@ -15431,7 +15428,7 @@ def test_diff_reward(
1543115428
td = TensorDict(
1543215429
{
1543315430
"obs": torch.randn(1, 10, 3),
15434-
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
15431+
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
1543515432
"next": {
1543615433
"obs": torch.randn(1, 10, 3),
1543715434
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15504,7 +15501,7 @@ def test_non_differentiable(self, adv, shifted, kwargs):
1550415501
td = TensorDict(
1550515502
{
1550615503
"obs": torch.randn(1, 10, 3),
15507-
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
15504+
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
1550815505
"next": {
1550915506
"obs": torch.randn(1, 10, 3),
1551015507
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15575,7 +15572,7 @@ def test_time_dim(self, adv, kwargs, shifted=True):
1557515572
td = TensorDict(
1557615573
{
1557715574
"obs": torch.randn(1, 10, 3),
15578-
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
15575+
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
1557915576
"next": {
1558015577
"obs": torch.randn(1, 10, 3),
1558115578
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15676,7 +15673,7 @@ def test_skip_existing(
1567615673
td = TensorDict(
1567715674
{
1567815675
"obs": torch.randn(1, 10, 3),
15679-
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
15676+
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
1568015677
"state_value": torch.ones(1, 10, 1),
1568115678
"next": {
1568215679
"obs": torch.randn(1, 10, 3),
@@ -15814,6 +15811,13 @@ def test_set_deprecated_keys(self, adv, kwargs):
1581415811

1581515812

1581615813
class TestBase:
15814+
@pytest.fixture(scope="class", autouse=True)
15815+
def _composite_log_prob(self):
15816+
setter = set_composite_lp_aggregate(False)
15817+
setter.set()
15818+
yield
15819+
setter.unset()
15820+
1581715821
def test_decorators(self):
1581815822
class MyLoss(LossModule):
1581915823
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -16033,6 +16037,13 @@ class _AcceptedKeys:
1603316037

1603416038

1603516039
class TestUtils:
16040+
@pytest.fixture(scope="class", autouse=True)
16041+
def _composite_log_prob(self):
16042+
setter = set_composite_lp_aggregate(False)
16043+
setter.set()
16044+
yield
16045+
setter.unset()
16046+
1603616047
@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
1603716048
@pytest.mark.parametrize("T", [1, 10])
1603816049
@pytest.mark.parametrize("device", get_default_devices())
@@ -16203,6 +16214,7 @@ def fun(a, b, time_dim=-2):
1620316214
(SoftUpdate, {"eps": 0.99}),
1620416215
],
1620516216
)
16217+
@set_composite_lp_aggregate(False)
1620616218
def test_updater_warning(updater, kwarg):
1620716219
with warnings.catch_warnings():
1620816220
dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot")
@@ -16215,6 +16227,13 @@ def test_updater_warning(updater, kwarg):
1621516227

1621616228

1621716229
class TestSingleCall:
16230+
@pytest.fixture(scope="class", autouse=True)
16231+
def _composite_log_prob(self):
16232+
setter = set_composite_lp_aggregate(False)
16233+
setter.set()
16234+
yield
16235+
setter.unset()
16236+
1621816237
def _mock_value_net(self, has_target, value_key):
1621916238
model = nn.Linear(3, 1)
1622016239
module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key])
@@ -16267,6 +16286,7 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True)
1626716286
assert (value != value_).all()
1626816287

1626916288

16289+
@set_composite_lp_aggregate(False)
1627016290
def test_instantiate_with_different_keys():
1627116291
loss_1 = DQNLoss(
1627216292
value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True
@@ -16281,6 +16301,13 @@ def test_instantiate_with_different_keys():
1628116301

1628216302

1628316303
class TestBuffer:
16304+
@pytest.fixture(scope="class", autouse=True)
16305+
def _composite_log_prob(self):
16306+
setter = set_composite_lp_aggregate(False)
16307+
setter.set()
16308+
yield
16309+
setter.unset()
16310+
1628416311
# @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half))
1628516312
# def test_param_cast(self, dtype):
1628616313
# param = nn.Parameter(torch.zeros(3))
@@ -16390,6 +16417,7 @@ def __init__(self):
1639016417
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
1639116418
)
1639216419
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
16420+
@set_composite_lp_aggregate(False)
1639316421
def test_exploration_compile():
1639416422
try:
1639516423
torch._dynamo.reset_code_caches()
@@ -16456,6 +16484,7 @@ def func(t):
1645616484
assert it == exploration_type()
1645716485

1645816486

16487+
@set_composite_lp_aggregate(False)
1645916488
def test_loss_exploration():
1646016489
class DummyLoss(LossModule):
1646116490
def forward(self, td, mode):

Comments (0): 0 commit comments