
Commit f62785b

[Algorithm] Update DT (#1560)

Authored by BY571 and vmoens
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent 001cf33, commit f62785b
File tree

5 files changed: +94 -48 lines

examples/decision_transformer/dt.py

Lines changed: 37 additions & 19 deletions
@@ -6,15 +6,19 @@
 This is a self-contained example of an offline Decision Transformer training script.
 The helper functions are coded in the utils.py associated with this script.
 """
+import time
 
 import hydra
+import numpy as np
 import torch
 import tqdm
+from torchrl.envs.libs.gym import set_gym_backend
 
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 
 from utils import (
+    log_metrics,
     make_dt_loss,
     make_dt_model,
     make_dt_optimizer,
@@ -24,29 +28,44 @@
 )
 
 
+@set_gym_backend("gym")  # D4RL uses gym so we make sure gymnasium is hidden
 @hydra.main(config_path=".", config_name="dt_config")
 def main(cfg: "DictConfig"):  # noqa: F821
     model_device = cfg.optim.device
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
     logger = make_logger(cfg)
+
+    # Create offline replay buffer
     offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
         cfg.replay_buffer, cfg.env.reward_scaling
     )
+
+    # Create test environment
     test_env = make_env(cfg.env, obs_loc, obs_std)
+
+    # Create policy model
     actor = make_dt_model(cfg)
     policy = actor.to(model_device)
 
+    # Create loss
     loss_module = make_dt_loss(cfg.loss, actor)
+
+    # Create optimizer
     transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
+
+    # Create inference policy
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
     ).to(model_device)
 
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
-    r0 = None
-    l0 = None
-
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
     clip_grad = cfg.optim.clip_grad
     eval_steps = cfg.logger.eval_steps
@@ -55,12 +74,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     print(" ***Pretraining*** ")
     # Pretraining
+    start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(i)
+
+        # Sample data
         data = offline_buffer.sample()
-        # loss
+        # Compute loss
         loss_vals = loss_module(data.to(model_device))
-        # backprop
         transformer_loss = loss_vals["loss"]
 
         transformer_optim.zero_grad()
@@ -70,28 +91,25 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         scheduler.step()
 
-        # evaluation
-        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
+        # Log metrics
+        to_log = {"train/loss": loss_vals["loss"]}
+
+        # Evaluation
+        with set_exploration_type(ExplorationType.MODE), torch.no_grad():
             if i % pretrain_log_interval == 0:
                 eval_td = test_env.rollout(
                     max_steps=eval_steps,
                     policy=inference_policy,
                     auto_cast_to_device=True,
                 )
-            if r0 is None:
-                r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if l0 is None:
-                l0 = transformer_loss.item()
-
-            eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+                to_log["eval/reward"] = (
+                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
                )
         if logger is not None:
-            for key, value in loss_vals.items():
-                logger.log_scalar(key, value.item(), i)
-            logger.log_scalar("evaluation reward", eval_reward, i)
+            log_metrics(logger, to_log, i)
 
-        pbar.set_description(
-            f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-        )
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
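Note on the logging refactor above: instead of calling logger.log_scalar once per loss key and summarizing via pbar.set_description, the training loop now collects everything into a to_log dict and flushes it once per step through the new log_metrics helper. The evaluation reward only enters that dict every pretrain_log_interval steps, so most steps log the training loss alone. A minimal stand-alone sketch of that flow (no torchrl needed; the loop bounds, fake_loss and fake_eval_reward values, and the history list standing in for log_metrics are all placeholders for illustration):

# Sketch of the metric-collection pattern introduced in dt.py (placeholders only).
pretrain_gradient_steps = 50   # stands in for cfg.optim.pretrain_gradient_steps
pretrain_log_interval = 10     # stands in for cfg.logger.pretrain_log_interval
history = []                   # stands in for log_metrics(logger, to_log, i)

for i in range(pretrain_gradient_steps):
    fake_loss = 1.0 / (i + 1)
    # Every step records the training loss ...
    to_log = {"train/loss": fake_loss}
    # ... but the evaluation reward is only computed (and logged) on eval steps.
    if i % pretrain_log_interval == 0:
        fake_eval_reward = 100.0 + i
        to_log["eval/reward"] = fake_eval_reward
    history.append((i, to_log))

print(history[0])  # step 0 carries both train/loss and eval/reward
print(history[1])  # step 1 carries train/loss only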

examples/decision_transformer/dt_config.yaml

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-# Task and env
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
@@ -25,7 +25,7 @@ logger:
   fintune_log_interval: 1
   eval_steps: 1000
 
-# Buffer
+# replay buffer
 replay_buffer:
   dataset: halfcheetah-medium-v2
   batch_size: 64
@@ -37,13 +37,12 @@ replay_buffer:
   device: cpu
   prefetch: 3
 
-# Optimization
+# optimization
 optim:
   device: cuda:0
   lr: 1.0e-4
   weight_decay: 5.0e-4
   batch_size: 64
-  lr_scheduler: ""
   pretrain_gradient_steps: 55000
   updates_per_episode: 300
   warmup_steps: 10000
@@ -52,7 +51,8 @@ optim:
 # loss
 loss:
   loss_function: "l2"
-
+
+# transformer model
 transformer:
   n_embd: 128
   n_layer: 3
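For context, these grouped fields are consumed through the @hydra.main decorator shown in dt.py above, so any of them can also be overridden from the command line (for example optim.lr=3e-4 env.seed=0). A minimal sketch of how such a config is read, assuming hydra-core is installed and dt_config.yaml sits next to the script; the prints are only illustrative:

import hydra


@hydra.main(config_path=".", config_name="dt_config")
def main(cfg: "DictConfig"):  # noqa: F821
    # Nested YAML groups become attribute access on the loaded config object.
    print(cfg.env.name)               # HalfCheetah-v3
    print(cfg.optim.lr)               # 0.0001
    print(cfg.replay_buffer.dataset)  # halfcheetah-medium-v2
    print(cfg.transformer.n_embd)     # 128


if __name__ == "__main__":
    main()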

examples/decision_transformer/odt_config.yaml

Lines changed: 4 additions & 5 deletions
@@ -1,4 +1,4 @@
-# Task and env
+# environment and task
 env:
   name: HalfCheetah-v3
   task: ""
@@ -10,7 +10,6 @@ env:
   num_train_envs: 1
   num_eval_envs: 10
   reward_scaling: 0.001  # for r2g
-  noop: 1
   seed: 42
   target_return_mode: reduce
   eval_target_return: 6000
@@ -26,7 +25,7 @@ logger:
   fintune_log_interval: 1
   eval_steps: 1000
 
-# Buffer
+# replay buffer
 replay_buffer:
   dataset: halfcheetah-medium-v2
   batch_size: 256
@@ -38,13 +37,12 @@ replay_buffer:
   device: cuda:0
   prefetch: 3
 
-# Optimization
+# optimizer
 optim:
   device: cuda:0
   lr: 1.0e-4
   weight_decay: 5.0e-4
   batch_size: 256
-  lr_scheduler: ""
   pretrain_gradient_steps: 10000
   updates_per_episode: 300
   warmup_steps: 10000
@@ -55,6 +53,7 @@ loss:
   alpha_init: 0.1
   target_entropy: auto
 
+# transformer model
 transformer:
   n_embd: 512
   n_layer: 4

examples/decision_transformer/online_dt.py

Lines changed: 38 additions & 16 deletions
@@ -7,16 +7,19 @@
 The helper functions are coded in the utils.py associated with this script.
 """
 
+import time
+
 import hydra
+import numpy as np
 import torch
 import tqdm
-
 from torchrl.envs.libs.gym import set_gym_backend
 
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
 
 from utils import (
+    log_metrics,
     make_env,
     make_logger,
     make_odt_loss,
@@ -31,28 +34,41 @@
 def main(cfg: "DictConfig"):  # noqa: F821
     model_device = cfg.optim.device
 
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create logger
     logger = make_logger(cfg)
+
+    # Create offline replay buffer
     offline_buffer, obs_loc, obs_std = make_offline_replay_buffer(
         cfg.replay_buffer, cfg.env.reward_scaling
     )
+
+    # Create test environment
     test_env = make_env(cfg.env, obs_loc, obs_std)
 
+    # Create policy model
     actor = make_odt_model(cfg)
     policy = actor.to(model_device)
 
+    # Create loss
     loss_module = make_odt_loss(cfg.loss, policy)
+
+    # Create optimizer
     transformer_optim, temperature_optim, scheduler = make_odt_optimizer(
         cfg.optim, loss_module
     )
+
+    # Create inference policy
     inference_policy = DecisionTransformerInferenceWrapper(
         policy=policy,
         inference_context=cfg.env.inference_context,
     ).to(model_device)
 
     pbar = tqdm.tqdm(total=cfg.optim.pretrain_gradient_steps)
 
-    r0 = None
-    l0 = None
     pretrain_gradient_steps = cfg.optim.pretrain_gradient_steps
     clip_grad = cfg.optim.clip_grad
     eval_steps = cfg.logger.eval_steps
@@ -61,10 +77,12 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
     print(" ***Pretraining*** ")
     # Pretraining
+    start_time = time.time()
     for i in range(pretrain_gradient_steps):
         pbar.update(i)
+        # Sample data
         data = offline_buffer.sample()
-        # loss
+        # Compute loss
         loss_vals = loss_module(data.to(model_device))
         transformer_loss = loss_vals["loss_log_likelihood"] + loss_vals["loss_entropy"]
         temperature_loss = loss_vals["loss_alpha"]
@@ -80,7 +98,16 @@ def main(cfg: "DictConfig"):  # noqa: F821
 
         scheduler.step()
 
-        # evaluation
+        # Log metrics
+        to_log = {
+            "train/loss_log_likelihood": loss_vals["loss_log_likelihood"].item(),
+            "train/loss_entropy": loss_vals["loss_entropy"].item(),
+            "train/loss_alpha": loss_vals["loss_alpha"].item(),
+            "train/alpha": loss_vals["alpha"].item(),
+            "train/entropy": loss_vals["entropy"].item(),
+        }
+
+        # Evaluation
         with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
             inference_policy.eval()
             if i % pretrain_log_interval == 0:
@@ -91,20 +118,15 @@ def main(cfg: "DictConfig"):  # noqa: F821
                     break_when_any_done=False,
                 )
                 inference_policy.train()
-            if r0 is None:
-                r0 = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
-            if l0 is None:
-                l0 = transformer_loss.item()
+                to_log["eval/reward"] = (
+                    eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
+                )
 
-            eval_reward = eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
         if logger is not None:
-            for key, value in loss_vals.items():
-                logger.log_scalar(key, value.item(), i)
-            logger.log_scalar("evaluation reward", eval_reward, i)
+            log_metrics(logger, to_log, i)
 
-        pbar.set_description(
-            f"[Pre-Training] loss: {transformer_loss.item(): 4.4f} (init: {l0: 4.4f}), evaluation reward: {eval_reward: 4.4f} (init={r0: 4.4f})"
-        )
+    pbar.close()
+    print(f"Training time: {time.time() - start_time}")
 
 
 if __name__ == "__main__":
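Both scripts now evaluate deterministically (ExplorationType.MODE, i.e. the mode of the action distribution) under torch.no_grad(), and divide the summed rollout reward by the same reward_scaling used for the return-to-go targets. A rough sketch of that evaluation step, mirroring the online_dt.py block above and assuming test_env, inference_policy and the config values have already been built (the evaluate wrapper function itself is hypothetical, not part of the scripts):

import torch
from torchrl.envs.utils import ExplorationType, set_exploration_type


def evaluate(test_env, inference_policy, eval_steps, reward_scaling):
    # Deterministic evaluation: take the distribution mode, no gradients.
    with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
        inference_policy.eval()
        eval_td = test_env.rollout(
            max_steps=eval_steps,
            policy=inference_policy,
            auto_cast_to_device=True,
            break_when_any_done=False,
        )
        inference_policy.train()
    # Undo the reward scaling applied when the return-to-go targets were built.
    return eval_td["next", "reward"].sum(1).mean().item() / reward_scaling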

examples/decision_transformer/utils.py

Lines changed: 10 additions & 3 deletions
@@ -19,7 +19,6 @@
     DoubleToFloat,
     EnvCreator,
     ExcludeTransform,
-    NoopResetEnv,
     ObservationNorm,
     RandomCropTensorDict,
     Reward2GoTransform,
@@ -65,8 +64,6 @@ def make_base_env(env_cfg):
     env_task = env_cfg.task
     env_kwargs.update({"task_name": env_task})
     env = env_library(**env_kwargs)
-    if env_cfg.noop > 1:
-        env = TransformedEnv(env, NoopResetEnv(env_cfg.noop))
     return env
 
 
@@ -472,3 +469,13 @@ def make_logger(cfg):
         wandb_kwargs={"config": cfg},
     )
     return logger
+
+
+# ====================================================================
+# General utils
+# ---------
+
+
+def log_metrics(logger, metrics, step):
+    for metric_name, metric_value in metrics.items():
+        logger.log_scalar(metric_name, metric_value, step)
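The new helper is backend-agnostic: it only needs a logger exposing log_scalar(name, value, step), which is exactly the call signature the old per-key logging used. Note that online_dt.py converts its tensors to Python floats with .item() before filling to_log, while dt.py stores the loss tensor directly. A small usage sketch with a hypothetical stand-in logger (PrintLogger is made up for illustration and is not part of the example):

class PrintLogger:
    # Hypothetical minimal logger with the log_scalar signature the scripts rely on.
    def log_scalar(self, name, value, step):
        print(f"step {step:>6d} | {name} = {value}")


def log_metrics(logger, metrics, step):
    # Same body as the helper added to utils.py above.
    for metric_name, metric_value in metrics.items():
        logger.log_scalar(metric_name, metric_value, step)


log_metrics(PrintLogger(), {"train/loss": 0.37, "eval/reward": 3150.0}, step=1000)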
