@@ -1206,20 +1206,20 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
1206
1206
return actor .to (device )
1207
1207
1208
1208
def _create_mock_value (
1209
- self , batch = 2 , obs_dim = 3 , action_dim = 4 , device = "cpu" , out_keys = None
1209
+ self , batch = 2 , obs_dim = 3 , action_dim = 4 , state_dim = 8 , device = "cpu" , out_keys = None
1210
1210
):
1211
1211
# Actor
1212
1212
class ValueClass (nn .Module ):
1213
1213
def __init__ (self ):
1214
1214
super ().__init__ ()
1215
- self .linear = nn .Linear (obs_dim + action_dim , 1 )
1215
+ self .linear = nn .Linear (obs_dim + action_dim + state_dim , 1 )
1216
1216
1217
- def forward (self , obs , act ):
1218
- return self .linear (torch .cat ([obs , act ], - 1 ))
1217
+ def forward (self , obs , state , act ):
1218
+ return self .linear (torch .cat ([obs , state , act ], - 1 ))
1219
1219
1220
1220
module = ValueClass ()
1221
1221
value = ValueOperator (
1222
- module = module , in_keys = ["observation" , "action" ], out_keys = out_keys
1222
+ module = module , in_keys = ["observation" , "state" , " action" ], out_keys = out_keys
1223
1223
)
1224
1224
return value .to (device )
1225
1225
@@ -1278,6 +1278,7 @@ def _create_mock_data_ddpg(
1278
1278
batch = 8 ,
1279
1279
obs_dim = 3 ,
1280
1280
action_dim = 4 ,
1281
+ state_dim = 8 ,
1281
1282
atoms = None ,
1282
1283
device = "cpu" ,
1283
1284
reward_key = "reward" ,
@@ -1291,13 +1292,16 @@ def _create_mock_data_ddpg(
1291
1292
else :
1292
1293
action = torch .randn (batch , action_dim , device = device ).clamp (- 1 , 1 )
1293
1294
reward = torch .randn (batch , 1 , device = device )
1295
+ state = torch .randn (batch , state_dim , device = device )
1294
1296
done = torch .zeros (batch , 1 , dtype = torch .bool , device = device )
1295
1297
td = TensorDict (
1296
1298
batch_size = (batch ,),
1297
1299
source = {
1298
1300
"observation" : obs ,
1301
+ "state" : state ,
1299
1302
"next" : {
1300
1303
"observation" : next_obs ,
1304
+ "state" : state ,
1301
1305
done_key : done ,
1302
1306
reward_key : reward ,
1303
1307
},
@@ -1313,30 +1317,37 @@ def _create_seq_mock_data_ddpg(
1313
1317
T = 4 ,
1314
1318
obs_dim = 3 ,
1315
1319
action_dim = 4 ,
1320
+ state_dim = 8 ,
1316
1321
atoms = None ,
1317
1322
device = "cpu" ,
1318
1323
reward_key = "reward" ,
1319
1324
done_key = "done" ,
1320
1325
):
1321
1326
# create a tensordict
1322
1327
total_obs = torch .randn (batch , T + 1 , obs_dim , device = device )
1328
+ total_state = torch .randn (batch , T + 1 , state_dim , device = device )
1323
1329
obs = total_obs [:, :T ]
1324
1330
next_obs = total_obs [:, 1 :]
1331
+ state = total_state [:, :T ]
1332
+ next_state = total_state [:, 1 :]
1325
1333
if atoms :
1326
1334
action = torch .randn (batch , T , atoms , action_dim , device = device ).clamp (
1327
1335
- 1 , 1
1328
1336
)
1329
1337
else :
1330
1338
action = torch .randn (batch , T , action_dim , device = device ).clamp (- 1 , 1 )
1331
1339
reward = torch .randn (batch , T , 1 , device = device )
1340
+
1332
1341
done = torch .zeros (batch , T , 1 , dtype = torch .bool , device = device )
1333
1342
mask = ~ torch .zeros (batch , T , dtype = torch .bool , device = device )
1334
1343
td = TensorDict (
1335
1344
batch_size = (batch , T ),
1336
1345
source = {
1337
1346
"observation" : obs .masked_fill_ (~ mask .unsqueeze (- 1 ), 0.0 ),
1347
+ "state" : state .masked_fill_ (~ mask .unsqueeze (- 1 ), 0.0 ),
1338
1348
"next" : {
1339
1349
"observation" : next_obs .masked_fill_ (~ mask .unsqueeze (- 1 ), 0.0 ),
1350
+ "state" : next_state .masked_fill_ (~ mask .unsqueeze (- 1 ), 0.0 ),
1340
1351
done_key : done ,
1341
1352
reward_key : reward .masked_fill_ (~ mask .unsqueeze (- 1 ), 0.0 ),
1342
1353
},
@@ -1715,6 +1726,8 @@ def test_ddpg_notensordict(self):
1715
1726
"next_done" : td .get (("next" , "done" )),
1716
1727
"next_observation" : td .get (("next" , "observation" )),
1717
1728
"action" : td .get ("action" ),
1729
+ "state" : td .get ("state" ),
1730
+ "next_state" : td .get (("next" , "state" )),
1718
1731
}
1719
1732
td = TensorDict (kwargs , td .batch_size ).unflatten_keys ("_" )
1720
1733
0 commit comments