
Commit 10b0f4e

Author: Vincent Moens
Update
[ghstack-poisoned]
2 parents 30166a0 + f81c9e3

File tree

31 files changed: +731 -213 lines

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions

@@ -829,6 +829,7 @@ to be able to create this other composition:
     GrayScale
     InitTracker
     KLRewardTransform
+    LineariseReward
     NoopResetEnv
     ObservationNorm
     ObservationTransform

test/test_cost.py

Lines changed: 52 additions & 31 deletions
@@ -8,7 +8,6 @@
 import itertools
 import operator
 import os
-
 import sys
 import warnings
 from copy import deepcopy
@@ -23,13 +22,15 @@
 from tensordict import assert_allclose_td, TensorDict, TensorDictBase
 from tensordict._C import unravel_keys
 from tensordict.nn import (
+    composite_lp_aggregate,
     CompositeDistribution,
     InteractionType,
     NormalParamExtractor,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictModule as ProbMod,
     ProbabilisticTensorDictSequential,
     ProbabilisticTensorDictSequential as ProbSeq,
+    set_composite_lp_aggregate,
     TensorDictModule,
     TensorDictModule as Mod,
     TensorDictSequential,
@@ -3540,6 +3541,13 @@ def test_td3bc_reduction(self, reduction):
 class TestSAC(LossModuleTestBase):
     seed = 0
 
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def _create_mock_actor(
         self,
         batch=2,
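Aside: this autouse fixture is the pattern that drives the rest of the diff. A minimal sketch of the same setter mechanics outside pytest — `set_composite_lp_aggregate` and `composite_lp_aggregate` are the tensordict.nn names imported at the top of this file; the try/finally framing is illustrative:

# Sketch: flip the tensordict.nn global that controls whether composite
# log-probs are aggregated into a single "sample_log_prob" tensor.
from tensordict.nn import composite_lp_aggregate, set_composite_lp_aggregate

setter = set_composite_lp_aggregate(False)
setter.set()
try:
    # While the switch is off, per-leaf "<key>_log_prob" entries are kept
    # instead of one aggregated "sample_log_prob" tensor.
    assert not composite_lp_aggregate()
finally:
    setter.unset()  # restore the previous global state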
@@ -3563,7 +3571,6 @@ def _create_mock_actor(
                 distribution_map={
                     "action1": TanhNormal,
                 },
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -3583,6 +3590,7 @@ def _create_mock_actor(
             out_keys=[action_key],
             spec=action_spec,
         )
+        assert actor.log_prob_keys
         return actor.to(device)
 
     def _create_mock_qvalue(
@@ -3688,7 +3696,6 @@ def forward(self, obs, act):
                 distribution_map={
                     "action1": TanhNormal,
                 },
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -4342,7 +4349,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
             "value": "state_value",
             "state_action_value": "state_action_value",
             "action": "action",
-            "log_prob": "sample_log_prob",
+            "log_prob": "action_log_prob",
             "reward": "reward",
             "done": "done",
             "terminated": "terminated",
@@ -4616,6 +4623,13 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
 class TestDiscreteSAC(LossModuleTestBase):
     seed = 0
 
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def _create_mock_actor(
         self,
         batch=2,
@@ -7902,6 +7916,13 @@ def test_dcql_reduction(self, reduction):
 class TestPPO(LossModuleTestBase):
     seed = 0
 
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def _create_mock_actor(
         self,
         batch=2,
@@ -7910,9 +7931,8 @@ def _create_mock_actor(
         device="cpu",
         action_key=None,
         observation_key="observation",
-        sample_log_prob_key="sample_log_prob",
+        sample_log_prob_key=None,
         composite_action_dist=False,
-        aggregate_probabilities=None,
     ):
         # Actor
         action_spec = Bounded(
@@ -7934,7 +7954,6 @@ def _create_mock_actor(
                     "action1": action_key,
                 },
                 log_prob_key=sample_log_prob_key,
-                aggregate_probabilities=aggregate_probabilities,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -8006,7 +8025,6 @@ def _create_mock_actor_value(
                     "action1": ("action", "action1"),
                 },
                 log_prob_key=sample_log_prob_key,
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -8063,7 +8081,6 @@ def _create_mock_actor_value_shared(
                     "action1": ("action", "action1"),
                 },
                 log_prob_key=sample_log_prob_key,
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -8181,7 +8198,8 @@ def _create_seq_mock_data_ppo(
         if composite_action_dist:
             sample_log_prob_key = ("action", "action1_log_prob")
         else:
-            sample_log_prob_key = "sample_log_prob"
+            # conforming to composite_lp_aggregate(False)
+            sample_log_prob_key = "action_log_prob"
 
         if action_key is None:
             if composite_action_dist:
@@ -8287,6 +8305,7 @@ def test_ppo(
             if advantage is not None:
                 advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
         if advantage is not None:
+            assert not composite_lp_aggregate()
             advantage(td)
         else:
             if td_est is not None:
@@ -8346,7 +8365,6 @@ def test_ppo_composite_no_aggregate(
         actor = self._create_mock_actor(
             device=device,
             composite_action_dist=True,
-            aggregate_probabilities=False,
         )
         value = self._create_mock_value(device=device)
         if advantage == "gae":
@@ -8766,6 +8784,7 @@ def zero_param(p):
     )
     @pytest.mark.parametrize("composite_action_dist", [True, False])
     def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
+        assert not composite_lp_aggregate()
         actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
         value = self._create_mock_value()
 
@@ -8775,8 +8794,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
             "advantage": "advantage",
             "value_target": "value_target",
             "value": "state_value",
-            "sample_log_prob": "sample_log_prob",
-            "action": "action",
+            "sample_log_prob": "action_log_prob"
+            if not composite_action_dist
+            else ("action", "action1_log_prob"),
+            "action": "action" if not composite_action_dist else ("action", "action1"),
             "reward": "reward",
             "done": "done",
             "terminated": "terminated",
@@ -9162,9 +9183,6 @@ def mixture_constructor(logits, loc, scale):
                 "Kumaraswamy": ("agent1", "action"),
                 "mixture": ("agent2", "action"),
             },
-            aggregate_probabilities=False,
-            include_sum=False,
-            inplace=True,
         )
         policy = ProbSeq(
             make_params,
@@ -9183,15 +9201,11 @@ def mixture_constructor(logits, loc, scale):
         # We want to make sure there is no warning
         td = policy(TensorDict(batch_size=[4]))
         assert isinstance(
-            policy.get_dist(td).log_prob(
-                td, aggregate_probabilities=False, inplace=False, include_sum=False
-            ),
+            policy.get_dist(td).log_prob(td),
             TensorDict,
         )
         assert isinstance(
-            policy.log_prob(
-                td, aggregate_probabilities=False, inplace=False, include_sum=False
-            ),
+            policy.log_prob(td),
             TensorDict,
         )
         value_operator = Seq(
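The assertions above rely on `log_prob` returning a TensorDict of per-leaf log-probs once the per-call keyword arguments are gone and the class-level fixture has turned aggregation off globally. A minimal sketch of that contract with a single-leaf CompositeDistribution — parameter names and shapes are illustrative:

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch.distributions import Normal

setter = set_composite_lp_aggregate(False)
setter.set()
try:
    params = TensorDict(
        {"action1": {"loc": torch.zeros(4), "scale": torch.ones(4)}}, [4]
    )
    dist = CompositeDistribution(params, distribution_map={"action1": Normal})
    sample = dist.sample()  # a TensorDict holding the "action1" draw
    lp = dist.log_prob(sample)
    # With aggregation off, log_prob keeps one "<leaf>_log_prob" entry per leaf
    # instead of summing everything into a single tensor.
    assert isinstance(lp, TensorDict)
finally:
    setter.unset()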
@@ -9228,6 +9242,13 @@ def mixture_constructor(logits, loc, scale):
 class TestA2C(LossModuleTestBase):
     seed = 0
 
+    @pytest.fixture(scope="class", autouse=True)
+    def _composite_log_prob(self):
+        setter = set_composite_lp_aggregate(False)
+        setter.set()
+        yield
+        setter.unset()
+
     def _create_mock_actor(
         self,
         batch=2,
@@ -9236,8 +9257,8 @@ def _create_mock_actor(
         device="cpu",
         action_key="action",
         observation_key="observation",
-        sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
+        sample_log_prob_key=None,
     ):
         # Actor
         action_spec = Bounded(
@@ -9255,8 +9276,6 @@ def _create_mock_actor(
                 name_map={
                     "action1": (action_key, "action1"),
                 },
-                log_prob_key=sample_log_prob_key,
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -9306,7 +9325,6 @@ def _create_mock_common_layer_setup(
         n_hidden=2,
         T=10,
         composite_action_dist=False,
-        sample_log_prob_key="sample_log_prob",
     ):
         common_net = MLP(
             num_cells=ncells,
@@ -9332,7 +9350,7 @@ def _create_mock_common_layer_setup(
             {
                 "obs": torch.randn(*batch, n_obs),
                 "action": {"action1": action} if composite_action_dist else action,
-                "sample_log_prob": torch.randn(*batch),
+                "action_log_prob": torch.randn(*batch),
                 "done": torch.zeros(*batch, 1, dtype=torch.bool),
                 "terminated": torch.zeros(*batch, 1, dtype=torch.bool),
                 "next": {
@@ -9356,8 +9374,6 @@ def _create_mock_common_layer_setup(
                 name_map={
                     "action1": ("action", "action1"),
                 },
-                log_prob_key=sample_log_prob_key,
-                aggregate_probabilities=True,
             )
             module_out_keys = [
                 ("params", "action1", "loc"),
@@ -9398,7 +9414,7 @@ def _create_seq_mock_data_a2c(
         reward_key="reward",
         done_key="done",
         terminated_key="terminated",
-        sample_log_prob_key="sample_log_prob",
+        sample_log_prob_key="action_log_prob",
         composite_action_dist=False,
     ):
         # create a tensordict
@@ -9530,6 +9546,11 @@ def set_requires_grad(tensor, requires_grad):
 
         td = td.exclude(loss_fn.tensor_keys.value_target)
         if advantage is not None:
+            advantage.set_keys(
+                sample_log_prob=actor.log_prob_keys
+                if composite_action_dist
+                else "action_log_prob"
+            )
             advantage(td)
         elif td_est is not None:
             loss_fn.make_value_estimator(td_est)
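One follow-up on the `set_keys` call added above: the value estimator can be pointed either at a single log-prob key or at the per-leaf key list exposed by `actor.log_prob_keys`. A hedged sketch with torchrl's GAE, assuming a torchrl build matching this commit; the tiny value network is illustrative:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

# Illustrative value network reading "observation" into "state_value".
value_net = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"]
)
advantage = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)

# Plain action head: one log-prob tensor under "action_log_prob".
advantage.set_keys(sample_log_prob="action_log_prob")
# Composite head: a list of per-leaf keys, as in the tests above.
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])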
@@ -9749,7 +9770,7 @@ def test_a2c_tensordict_keys(self, td_est, composite_action_dist):
             "reward": "reward",
             "done": "done",
             "terminated": "terminated",
-            "sample_log_prob": "sample_log_prob",
+            "sample_log_prob": "action_log_prob",
         }
 
         self.tensordict_keys_test(
