
Commit 106368f

Authored by Vincent Moens, skandermoalla, and matteobettini
[Feature] Make advantages compatible with Terminated, Truncated, Done (#1581)
Co-authored-by: Skander Moalla <37197319+skandermoalla@users.noreply.github.com>
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent 3785609 commit 106368f

22 files changed (+1203, -337 lines)

examples/multiagent/iql.py

Lines changed: 4 additions & 7 deletions

@@ -21,6 +21,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -111,6 +112,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         storing_device=cfg.train.device,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
     )

     replay_buffer = TensorDictReplayBuffer(
@@ -125,6 +127,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
         action=env.action_key,
         value=("agents", "chosen_action_value"),
         reward=env.reward_key,
+        done=("agents", "done"),
+        terminated=("agents", "terminated"),
     )
     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
     target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
@@ -144,13 +148,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)
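
The remaining example scripts apply the same three-part change: drop the manual done-expansion from the training loop, attach a DoneTransform post-processing step to the collector, and register the per-agent done/terminated keys on the loss. The following is a minimal sketch of that wiring in one place; it is not runnable as-is, since env and qnet are placeholders for the VMAS environment and Q-value policy the script builds from its Hydra config, and the hyperparameter values are illustrative.

# Sketch only: `env` and `qnet` are placeholders; numbers are illustrative.
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import DQNLoss, ValueEstimators

from utils.utils import DoneTransform

collector = SyncDataCollector(
    env,
    qnet,
    frames_per_batch=6_000,
    total_frames=3_000_000,
    # Expand ("next", "done") / ("next", "terminated") to the per-agent reward
    # shape once, at collection time, instead of inside the training loop.
    postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
)

loss_module = DQNLoss(qnet, delay_value=True)
loss_module.set_keys(
    action=env.action_key,
    value=("agents", "chosen_action_value"),
    reward=env.reward_key,
    # New: point the loss (and its value estimator) at the expanded entries.
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)
loss_module.make_value_estimator(ValueEstimators.TD0, gamma=0.99)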

examples/multiagent/maddpg_iddpg.py

Lines changed: 4 additions & 7 deletions

@@ -26,6 +26,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DDPGLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -133,6 +134,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         storing_device=cfg.train.device,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
     )

     replay_buffer = TensorDictReplayBuffer(
@@ -147,6 +149,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
     loss_module.set_keys(
         state_action_value=("agents", "state_action_value"),
         reward=env.reward_key,
+        done=("agents", "done"),
+        terminated=("agents", "terminated"),
     )
     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
     target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)
@@ -170,13 +174,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)

examples/multiagent/mappo_ippo.py

Lines changed: 8 additions & 8 deletions

@@ -22,6 +22,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import ClipPPOLoss, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -126,6 +127,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         storing_device=cfg.train.device,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
     )

     replay_buffer = TensorDictReplayBuffer(
@@ -142,7 +144,12 @@ def train(cfg: "DictConfig"):  # noqa: F821
         entropy_coef=cfg.loss.entropy_eps,
         normalize_advantage=False,
     )
-    loss_module.set_keys(reward=env.reward_key, action=env.action_key)
+    loss_module.set_keys(
+        reward=env.reward_key,
+        action=env.action_key,
+        done=("agents", "done"),
+        terminated=("agents", "terminated"),
+    )
     loss_module.make_value_estimator(
         ValueEstimators.GAE, gamma=cfg.loss.gamma, lmbda=cfg.loss.lmbda
     )
@@ -165,13 +172,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         with torch.no_grad():
             loss_module.value_estimator(
                 tensordict_data,
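
mappo_ippo.py is the script the commit title targets most directly: with done and terminated registered on ClipPPOLoss, the GAE estimator it builds can tell a truncated trajectory (bootstrap from the next state value) apart from a terminated one (no bootstrapping). Below is a hedged, standalone illustration using torchrl's functional GAE; the keyword names and the terminated argument are assumed to match the value-function API as of this commit.

import torch
# Assumption: generalized_advantage_estimate accepts a `terminated` tensor
# alongside `done` as of this commit.
from torchrl.objectives.value.functional import generalized_advantage_estimate

# One trajectory of T steps, tensors shaped [batch, time, 1].
T = 5
state_value = torch.randn(1, T, 1)
next_state_value = torch.randn(1, T, 1)
reward = torch.randn(1, T, 1)
done = torch.zeros(1, T, 1, dtype=torch.bool)
terminated = torch.zeros(1, T, 1, dtype=torch.bool)
# The trajectory is cut short (done) but not terminated, i.e. truncated:
# GAE should still bootstrap from next_state_value at the last step.
done[0, -1] = True

advantage, value_target = generalized_advantage_estimate(
    gamma=0.99,
    lmbda=0.95,
    state_value=state_value,
    next_state_value=next_state_value,
    reward=reward,
    done=done,
    terminated=terminated,
)
print(advantage.shape)  # torch.Size([1, 5, 1])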

examples/multiagent/sac.py

Lines changed: 6 additions & 7 deletions

@@ -23,6 +23,7 @@
 from torchrl.modules.models.multiagent import MultiAgentMLP
 from torchrl.objectives import DiscreteSACLoss, SACLoss, SoftUpdate, ValueEstimators
 from utils.logging import init_logging, log_evaluation, log_training
+from utils.utils import DoneTransform


 def rendering_callback(env, td):
@@ -179,6 +180,7 @@ def train(cfg: "DictConfig"):  # noqa: F821
         storing_device=cfg.train.device,
         frames_per_batch=cfg.collector.frames_per_batch,
         total_frames=cfg.collector.total_frames,
+        postproc=DoneTransform(reward_key=env.reward_key, done_keys=env.done_keys),
     )

     replay_buffer = TensorDictReplayBuffer(
@@ -198,6 +200,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
             state_action_value=("agents", "state_action_value"),
             action=env.action_key,
             reward=env.reward_key,
+            done=("agents", "done"),
+            terminated=("agents", "terminated"),
         )
     else:
         loss_module = DiscreteSACLoss(
@@ -211,6 +215,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
             action_value=("agents", "action_value"),
             action=env.action_key,
             reward=env.reward_key,
+            done=("agents", "done"),
+            terminated=("agents", "terminated"),
         )

     loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
@@ -235,13 +241,6 @@ def train(cfg: "DictConfig"):  # noqa: F821

         sampling_time = time.time() - sampling_start

-        tensordict_data.set(
-            ("next", "done"),
-            tensordict_data.get(("next", "done"))
-            .unsqueeze(-1)
-            .expand(tensordict_data.get(("next", env.reward_key)).shape),
-        )  # We need to expand the done to match the reward shape
-
         current_frames = tensordict_data.numel()
         total_frames += current_frames
         data_view = tensordict_data.reshape(-1)

examples/multiagent/utils/utils.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from tensordict import unravel_key
+from torchrl.envs import Transform
+
+
+def swap_last(source, dest):
+    source = unravel_key(source)
+    dest = unravel_key(dest)
+    if isinstance(source, str):
+        if isinstance(dest, str):
+            return dest
+        return dest[-1]
+    if isinstance(dest, str):
+        return source[:-1] + (dest,)
+    return source[:-1] + (dest[-1],)
+
+
+class DoneTransform(Transform):
+    """Expands the 'done' entries (incl. terminated) to match the reward shape.
+
+    Can be appended to a replay buffer or a collector.
+    """
+
+    def __init__(self, reward_key, done_keys):
+        super().__init__()
+        self.reward_key = reward_key
+        self.done_keys = done_keys
+
+    def forward(self, tensordict):
+        for done_key in self.done_keys:
+            new_name = swap_last(self.reward_key, done_key)
+            tensordict.set(
+                ("next", new_name),
+                tensordict.get(("next", done_key))
+                .unsqueeze(-1)
+                .expand(tensordict.get(("next", self.reward_key)).shape),
+            )
+        return tensordict
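
To make the key handling above concrete, here is a small usage sketch. It assumes the swap_last and DoneTransform definitions above are in scope, and it uses a made-up multi-agent layout in which rewards live under ("agents", "reward") while a single done/terminated flag per step sits at the root of "next".

import torch
from tensordict import TensorDict

# swap_last keeps the prefix of the reward key and the leaf of the done key:
assert swap_last(("agents", "reward"), "done") == ("agents", "done")
assert swap_last("reward", ("agents", "done")) == "done"

# A toy rollout: 8 steps, 3 agents, per-agent scalar rewards,
# but a single root-level done/terminated flag per step.
td = TensorDict(
    {
        "next": TensorDict(
            {
                "done": torch.zeros(8, 1, dtype=torch.bool),
                "terminated": torch.zeros(8, 1, dtype=torch.bool),
                "agents": TensorDict(
                    {"reward": torch.zeros(8, 3, 1)}, batch_size=[8, 3]
                ),
            },
            batch_size=[8],
        )
    },
    batch_size=[8],
)

t = DoneTransform(reward_key=("agents", "reward"), done_keys=["done", "terminated"])
t.forward(td)
# ("next", "agents", "done") and ("next", "agents", "terminated") now exist
# with shape [8, 3, 1], matching the per-agent reward shape.
assert td["next", "agents", "done"].shape == torch.Size([8, 3, 1])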
