@@ -1216,6 +1216,57 @@ def make_dreamer(
        out_features=1, depth=2, num_cells=cfg.mlp_num_units, activation_class=nn.ELU
    )

+    world_model = _dreamer_make_world_model(
+        obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+    ).to(device)
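+    # init nets: run the world model once on a short random rollout from the proof env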
+    with torch.no_grad(), set_exploration_mode("random"):
+        tensordict = proof_environment.rollout(4)
+        tensordict = tensordict.to_tensordict().to(device)
+        world_model(tensordict)
+
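+    # model-based environment: generates imagined trajectories from the learned dynamics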
+    model_based_env = _dreamer_make_mbenv(
+        reward_module,
+        rssm_prior,
+        obs_decoder,
+        proof_environment,
+        use_decoder_in_env,
+        cfg.state_dim,
+        cfg.rssm_hidden_dim,
+    )
+    model_based_env = model_based_env.to(device)
+
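+    # one policy network drives both actors: one acts in the dreamed env, one in the real env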
+    actor_simulator, actor_realworld = _dreamer_make_actors(
+        obs_encoder,
+        rssm_prior,
+        rssm_posterior,
+        cfg.mlp_num_units,
+        action_key,
+        proof_environment,
+    )
+    actor_simulator = actor_simulator.to(device)
+
+    value_model = _dreamer_make_value_model(cfg.mlp_num_units, value_key)
+    value_model = value_model.to(device)
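+    # init nets: run the actor and value model once on a short imagined rollout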
+    with torch.no_grad(), set_exploration_mode("random"):
+        tensordict = model_based_env.rollout(4)
+        tensordict = tensordict.to(device)
+        tensordict = actor_simulator(tensordict)
+        value_model(tensordict)
+
+    actor_realworld = actor_realworld.to(device)
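+    # the proof env was created locally when none was passed in; close it and free memory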
+    if proof_env_is_none:
+        proof_environment.close()
+        torch.cuda.empty_cache()
+        del proof_environment
+
+    del tensordict
+    return world_model, model_based_env, actor_simulator, value_model, actor_realworld
+
+
+def _dreamer_make_world_model(
+    obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+):
    # World Model and reward model
    rssm_rollout = RSSMRollout(
        TensorDictModule(
@@ -1261,14 +1312,38 @@ def make_dreamer(
        transition_model,
        reward_model,
    )
+    return world_model

-    # actor for simulator: interacts with states ~ prior
+
+def _dreamer_make_actors(
+    obs_encoder,
+    rssm_prior,
+    rssm_posterior,
+    mlp_num_units,
+    action_key,
+    proof_environment,
+):
    actor_module = DreamerActor(
        out_features=proof_environment.action_spec.shape[0],
        depth=3,
-        num_cells=cfg.mlp_num_units,
+        num_cells=mlp_num_units,
        activation_class=nn.ELU,
    )
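+    # the same DreamerActor backs both actors; they differ only in how states are inferred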
+    actor_simulator = _dreamer_make_actor_sim(
+        action_key, proof_environment, actor_module
+    )
+    actor_realworld = _dreamer_make_actor_real(
+        obs_encoder,
+        rssm_prior,
+        rssm_posterior,
+        actor_module,
+        action_key,
+        proof_environment,
+    )
+    return actor_simulator, actor_realworld
+
+
+def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
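+    # actor for simulator: interacts with states ~ prior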
    actor_simulator = ProbabilisticTensorDictModule(
        TensorDictModule(
            actor_module,
@@ -1293,6 +1368,12 @@ def make_dreamer(
            }
        ),
    )
+    return actor_simulator
+
+
+def _dreamer_make_actor_real(
+    obs_encoder, rssm_prior, rssm_posterior, actor_module, action_key, proof_environment
+):
    # actor for real world: interacts with states ~ posterior
    # Our actor differs from the original paper, where the prior and posterior are computed
    # first and the action follows; we found that our approach worked better.
@@ -1344,17 +1425,33 @@ def make_dreamer(
            ],
        ),
    )
+    return actor_realworld
+
+
+def _dreamer_make_value_model(mlp_num_units, value_key):
+    # value model: maps a (state, belief) pair to a scalar value estimate
    value_model = TensorDictModule(
        MLP(
            out_features=1,
            depth=3,
-            num_cells=cfg.mlp_num_units,
+            num_cells=mlp_num_units,
            activation_class=nn.ELU,
        ),
        in_keys=["state", "belief"],
        out_keys=[value_key],
    )
-
+    return value_model
+
+
+def _dreamer_make_mbenv(
+    reward_module,
+    rssm_prior,
+    obs_decoder,
+    proof_environment,
+    use_decoder_in_env,
+    state_dim,
+    rssm_hidden_dim,
+):
    # MB environment
    if use_decoder_in_env:
        mb_env_obs_decoder = TensorDictModule(
@@ -1387,49 +1484,22 @@ def make_dreamer(
            transition_model,
            reward_model,
        ),
-        prior_shape=torch.Size([cfg.state_dim]),
-        belief_shape=torch.Size([cfg.rssm_hidden_dim]),
+        prior_shape=torch.Size([state_dim]),
+        belief_shape=torch.Size([rssm_hidden_dim]),
        obs_decoder=mb_env_obs_decoder,
    )

    model_based_env.set_specs_from_env(proof_environment)
    model_based_env = TransformedEnv(model_based_env)
    default_dict = {
-        "next_state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
-        "next_belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
+        "next_state": NdUnboundedContinuousTensorSpec(state_dim),
+        "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim),
        # "action": proof_environment.action_spec,
    }
    model_based_env.append_transform(
        TensorDictPrimer(random=False, default_value=0, **default_dict)
    )
-
-    world_model = world_model.to(device)
-
-    # init nets
-    with torch.no_grad(), set_exploration_mode("random"):
-        tensordict = proof_environment.rollout(4)
-        tensordict = tensordict.to_tensordict().to(device)
-        tensordict = tensordict.to(device)
-        world_model(tensordict)
-    model_based_env = model_based_env.to(device)
-
-    actor_simulator = actor_simulator.to(device)
-    value_model = value_model.to(device)
-
-    with torch.no_grad(), set_exploration_mode("random"):
-        tensordict = model_based_env.rollout(4)
-        tensordict = tensordict.to(device)
-        tensordict = actor_simulator(tensordict)
-        value_model(tensordict)
-
-    actor_realworld = actor_realworld.to(device)
-    if proof_env_is_none:
-        proof_environment.close()
-        torch.cuda.empty_cache()
-        del proof_environment
-
-    del tensordict
-    return world_model, model_based_env, actor_simulator, value_model, actor_realworld
+    return model_based_env
@dataclass