
Commit 213ae5b

[Refactoring] Refactor dreamer helper in smaller pieces (#662)
1 parent 6196a95 commit 213ae5b

1 file changed

torchrl/trainers/helpers/models.py

Lines changed: 106 additions & 36 deletions
@@ -1216,6 +1216,57 @@ def make_dreamer(
         out_features=1, depth=2, num_cells=cfg.mlp_num_units, activation_class=nn.ELU
     )

+    world_model = _dreamer_make_world_model(
+        obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+    ).to(device)
+    with torch.no_grad(), set_exploration_mode("random"):
+        tensordict = proof_environment.rollout(4)
+        tensordict = tensordict.to_tensordict().to(device)
+        tensordict = tensordict.to(device)
+        world_model(tensordict)
+
+    model_based_env = _dreamer_make_mbenv(
+        reward_module,
+        rssm_prior,
+        obs_decoder,
+        proof_environment,
+        use_decoder_in_env,
+        cfg.state_dim,
+        cfg.rssm_hidden_dim,
+    )
+    model_based_env = model_based_env.to(device)
+
+    actor_simulator, actor_realworld = _dreamer_make_actors(
+        obs_encoder,
+        rssm_prior,
+        rssm_posterior,
+        cfg.mlp_num_units,
+        action_key,
+        proof_environment,
+    )
+    actor_simulator = actor_simulator.to(device)
+
+    value_model = _dreamer_make_value_model(cfg.mlp_num_units, value_key)
+    value_model = value_model.to(device)
+    with torch.no_grad(), set_exploration_mode("random"):
+        tensordict = model_based_env.rollout(4)
+        tensordict = tensordict.to(device)
+        tensordict = actor_simulator(tensordict)
+        value_model(tensordict)
+
+    actor_realworld = actor_realworld.to(device)
+    if proof_env_is_none:
+        proof_environment.close()
+        torch.cuda.empty_cache()
+        del proof_environment
+
+    del tensordict
+    return world_model, model_based_env, actor_simulator, value_model, actor_realworld
+
+
+def _dreamer_make_world_model(
+    obs_encoder, obs_decoder, rssm_prior, rssm_posterior, reward_module
+):
     # World Model and reward model
     rssm_rollout = RSSMRollout(
         TensorDictModule(
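
Note: the torch.no_grad() rollouts added above are dry runs whose only purpose (per the old "# init nets" comment removed below) is to materialize lazily-initialized parameters before training starts. A minimal sketch of that pattern in plain PyTorch (no torchrl; shapes are made up for illustration):

import torch
from torch import nn

# Modules built from lazy layers have uninitialized parameters until a
# first forward pass fixes the input shapes.
net = nn.Sequential(nn.LazyLinear(32), nn.ELU(), nn.LazyLinear(1))

with torch.no_grad():
    dummy_batch = torch.randn(4, 17)  # stand-in for a short rollout
    net(dummy_batch)                  # materializes the lazy parameters

print([tuple(p.shape) for p in net.parameters()])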
@@ -1261,14 +1312,38 @@ def make_dreamer(
         transition_model,
         reward_model,
     )
+    return world_model

-    # actor for simulator: interacts with states ~ prior
+
+def _dreamer_make_actors(
+    obs_encoder,
+    rssm_prior,
+    rssm_posterior,
+    mlp_num_units,
+    action_key,
+    proof_environment,
+):
     actor_module = DreamerActor(
         out_features=proof_environment.action_spec.shape[0],
         depth=3,
-        num_cells=cfg.mlp_num_units,
+        num_cells=mlp_num_units,
         activation_class=nn.ELU,
     )
+    actor_simulator = _dreamer_make_actor_sim(
+        action_key, proof_environment, actor_module
+    )
+    actor_realworld = _dreamer_make_actor_real(
+        obs_encoder,
+        rssm_prior,
+        rssm_posterior,
+        actor_module,
+        action_key,
+        proof_environment,
+    )
+    return actor_simulator, actor_realworld
+
+
+def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
     actor_simulator = ProbabilisticTensorDictModule(
         TensorDictModule(
             actor_module,
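
Note: _dreamer_make_actors builds one DreamerActor and passes the same instance to both factories, so the simulator policy and the real-world policy share their weights. A toy sketch of that sharing in plain PyTorch (nn.Linear standing in for DreamerActor):

import torch
from torch import nn

actor_module = nn.Linear(8, 2)  # stand-in for the shared DreamerActor
actor_simulator = nn.Sequential(actor_module)
actor_realworld = nn.Sequential(nn.Identity(), actor_module)

# One set of parameters backs both policies: updating the actor while
# training in imagination also updates the policy used in the real env.
assert actor_simulator[0].weight is actor_realworld[1].weight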
@@ -1293,6 +1368,12 @@ def make_dreamer(
             }
         ),
     )
+    return actor_simulator
+
+
+def _dreamer_make_actor_real(
+    obs_encoder, rssm_prior, rssm_posterior, actor_module, action_key, proof_environment
+):
     # actor for real world: interacts with states ~ posterior
     # Out actor differs from the original paper where first they compute prior and posterior and then act on it
     # but we found that this approach worked better.
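
Note: both actors are probabilistic modules: the network emits distribution parameters and the wrapper samples an action from them. A simplified, self-contained sketch of that pattern (assuming a loc/scale parameterization; the actual helper uses torchrl's distribution classes rather than a raw Normal):

import torch
from torch import nn
from torch.distributions import Normal

backbone = nn.Linear(10, 2 * 4)          # stand-in for actor_module
feats = backbone(torch.randn(32, 10))
loc, log_scale = feats.chunk(2, dim=-1)  # split into distribution params
dist = Normal(loc, log_scale.exp())      # exp keeps the scale positive
action = torch.tanh(dist.rsample())      # reparameterized, squashed sample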
@@ -1344,17 +1425,33 @@ def make_dreamer(
             ],
         ),
     )
+    return actor_realworld
+
+
+def _dreamer_make_value_model(mlp_num_units, value_key):
+    # actor for simulator: interacts with states ~ prior
     value_model = TensorDictModule(
         MLP(
             out_features=1,
             depth=3,
-            num_cells=cfg.mlp_num_units,
+            num_cells=mlp_num_units,
             activation_class=nn.ELU,
         ),
         in_keys=["state", "belief"],
         out_keys=[value_key],
     )
-
+    return value_model
+
+
+def _dreamer_make_mbenv(
+    reward_module,
+    rssm_prior,
+    obs_decoder,
+    proof_environment,
+    use_decoder_in_env,
+    state_dim,
+    rssm_hidden_dim,
+):
     # MB environment
     if use_decoder_in_env:
         mb_env_obs_decoder = TensorDictModule(
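
Note: _dreamer_make_value_model wraps a plain MLP in a TensorDictModule that reads "state" and "belief" and writes its output under value_key. Ignoring specs and batching, the wrapper behaves roughly like this dict-based sketch (DictModule is a hypothetical stand-in, not a torchrl class):

import torch
from torch import nn

class DictModule(nn.Module):
    """Toy stand-in for TensorDictModule: read in_keys, write out_key."""

    def __init__(self, net, in_keys, out_key):
        super().__init__()
        self.net, self.in_keys, self.out_key = net, in_keys, out_key

    def forward(self, td):
        x = torch.cat([td[k] for k in self.in_keys], dim=-1)
        td[self.out_key] = self.net(x)
        return td

value_model = DictModule(nn.Linear(5, 1), ["state", "belief"], "state_value")
td = {"state": torch.randn(4, 3), "belief": torch.randn(4, 2)}
td = value_model(td)  # td now also contains td["state_value"]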
@@ -1387,49 +1484,22 @@ def make_dreamer(
             transition_model,
             reward_model,
         ),
-        prior_shape=torch.Size([cfg.state_dim]),
-        belief_shape=torch.Size([cfg.rssm_hidden_dim]),
+        prior_shape=torch.Size([state_dim]),
+        belief_shape=torch.Size([rssm_hidden_dim]),
         obs_decoder=mb_env_obs_decoder,
     )

     model_based_env.set_specs_from_env(proof_environment)
     model_based_env = TransformedEnv(model_based_env)
     default_dict = {
-        "next_state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
-        "next_belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
+        "next_state": NdUnboundedContinuousTensorSpec(state_dim),
+        "next_belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim),
         # "action": proof_environment.action_spec,
     }
     model_based_env.append_transform(
         TensorDictPrimer(random=False, default_value=0, **default_dict)
     )
-
-    world_model = world_model.to(device)
-
-    # init nets
-    with torch.no_grad(), set_exploration_mode("random"):
-        tensordict = proof_environment.rollout(4)
-        tensordict = tensordict.to_tensordict().to(device)
-        tensordict = tensordict.to(device)
-        world_model(tensordict)
-    model_based_env = model_based_env.to(device)
-
-    actor_simulator = actor_simulator.to(device)
-    value_model = value_model.to(device)
-
-    with torch.no_grad(), set_exploration_mode("random"):
-        tensordict = model_based_env.rollout(4)
-        tensordict = tensordict.to(device)
-        tensordict = actor_simulator(tensordict)
-        value_model(tensordict)
-
-    actor_realworld = actor_realworld.to(device)
-    if proof_env_is_none:
-        proof_environment.close()
-        torch.cuda.empty_cache()
-        del proof_environment
-
-    del tensordict
-    return world_model, model_based_env, actor_simulator, value_model, actor_realworld
+    return model_based_env


 @dataclass
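
Note: the TensorDictPrimer transform kept above seeds rollout data with zero-filled "next_state" and "next_belief" entries so the recurrent modules find their inputs on the first step. A plain-Python approximation of that priming (dimensions are made-up examples, not the config defaults):

import torch

def prime(td, state_dim, rssm_hidden_dim, batch_size):
    # Mirrors TensorDictPrimer(random=False, default_value=0, ...):
    # add the keys only if absent, filled with zeros of the right shape.
    td.setdefault("next_state", torch.zeros(batch_size, state_dim))
    td.setdefault("next_belief", torch.zeros(batch_size, rssm_hidden_dim))
    return td

td = prime({}, state_dim=30, rssm_hidden_dim=200, batch_size=4)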
