Skip to content

Commit 434fe58

Browse files
[Tests] DDPG extra critic input tests (#1568)
Signed-off-by: Matteo Bettini <matbet@meta.com>
1 parent 8503378 commit 434fe58

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

test/test_cost.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1206,20 +1206,20 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
12061206
return actor.to(device)
12071207

12081208
def _create_mock_value(
1209-
self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None
1209+
self, batch=2, obs_dim=3, action_dim=4, state_dim=8, device="cpu", out_keys=None
12101210
):
12111211
# Actor
12121212
class ValueClass(nn.Module):
12131213
def __init__(self):
12141214
super().__init__()
1215-
self.linear = nn.Linear(obs_dim + action_dim, 1)
1215+
self.linear = nn.Linear(obs_dim + action_dim + state_dim, 1)
12161216

1217-
def forward(self, obs, act):
1218-
return self.linear(torch.cat([obs, act], -1))
1217+
def forward(self, obs, state, act):
1218+
return self.linear(torch.cat([obs, state, act], -1))
12191219

12201220
module = ValueClass()
12211221
value = ValueOperator(
1222-
module=module, in_keys=["observation", "action"], out_keys=out_keys
1222+
module=module, in_keys=["observation", "state", "action"], out_keys=out_keys
12231223
)
12241224
return value.to(device)
12251225

@@ -1278,6 +1278,7 @@ def _create_mock_data_ddpg(
12781278
batch=8,
12791279
obs_dim=3,
12801280
action_dim=4,
1281+
state_dim=8,
12811282
atoms=None,
12821283
device="cpu",
12831284
reward_key="reward",
@@ -1291,13 +1292,16 @@ def _create_mock_data_ddpg(
12911292
else:
12921293
action = torch.randn(batch, action_dim, device=device).clamp(-1, 1)
12931294
reward = torch.randn(batch, 1, device=device)
1295+
state = torch.randn(batch, state_dim, device=device)
12941296
done = torch.zeros(batch, 1, dtype=torch.bool, device=device)
12951297
td = TensorDict(
12961298
batch_size=(batch,),
12971299
source={
12981300
"observation": obs,
1301+
"state": state,
12991302
"next": {
13001303
"observation": next_obs,
1304+
"state": state,
13011305
done_key: done,
13021306
reward_key: reward,
13031307
},
@@ -1313,30 +1317,37 @@ def _create_seq_mock_data_ddpg(
13131317
T=4,
13141318
obs_dim=3,
13151319
action_dim=4,
1320+
state_dim=8,
13161321
atoms=None,
13171322
device="cpu",
13181323
reward_key="reward",
13191324
done_key="done",
13201325
):
13211326
# create a tensordict
13221327
total_obs = torch.randn(batch, T + 1, obs_dim, device=device)
1328+
total_state = torch.randn(batch, T + 1, state_dim, device=device)
13231329
obs = total_obs[:, :T]
13241330
next_obs = total_obs[:, 1:]
1331+
state = total_state[:, :T]
1332+
next_state = total_state[:, 1:]
13251333
if atoms:
13261334
action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
13271335
-1, 1
13281336
)
13291337
else:
13301338
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
13311339
reward = torch.randn(batch, T, 1, device=device)
1340+
13321341
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
13331342
mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device)
13341343
td = TensorDict(
13351344
batch_size=(batch, T),
13361345
source={
13371346
"observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
1347+
"state": state.masked_fill_(~mask.unsqueeze(-1), 0.0),
13381348
"next": {
13391349
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0),
1350+
"state": next_state.masked_fill_(~mask.unsqueeze(-1), 0.0),
13401351
done_key: done,
13411352
reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
13421353
},
@@ -1715,6 +1726,8 @@ def test_ddpg_notensordict(self):
17151726
"next_done": td.get(("next", "done")),
17161727
"next_observation": td.get(("next", "observation")),
17171728
"action": td.get("action"),
1729+
"state": td.get("state"),
1730+
"next_state": td.get(("next", "state")),
17181731
}
17191732
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")
17201733

0 commit comments

Comments
 (0)