
Commit 22fd5ba

[Docs] Fix multi-agent tutorial (#1599)
Signed-off-by: Matteo Bettini <matbet@meta.com>
1 parent 9ccae47 commit 22fd5ba

2 files changed (+22, -29 lines)

examples/multiagent/sac.py

Lines changed: 0 additions & 1 deletion
@@ -258,7 +258,6 @@ def train(cfg: "DictConfig"): # noqa: F821
                 loss_vals["loss_actor"]
                 + loss_vals["loss_alpha"]
                 + loss_vals["loss_qvalue"]
-                + loss_vals["loss_alpha"]
             )
 
             loss_value.backward()
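
The deleted line had summed loss_vals["loss_alpha"] a second time. A small, self-contained toy (not the repository's code) illustrating why that matters: adding the same loss term twice doubles its gradient contribution.

import torch

# Toy illustration: duplicating a loss term double-counts its gradient.
alpha = torch.tensor(1.0, requires_grad=True)
loss_alpha = alpha**2

(loss_alpha + loss_alpha).backward(retain_graph=True)
print(alpha.grad)  # tensor(4.) -- the duplicated term doubles the gradient

alpha.grad = None
loss_alpha.backward()
print(alpha.grad)  # tensor(2.) -- counted once, as after this fix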

tutorials/sphinx-tutorials/multiagent_ppo.py

Lines changed: 22 additions & 28 deletions
@@ -253,12 +253,11 @@
 #
 #
 
-print("action_spec:", env.action_spec)
-print("reward_spec:", env.reward_spec)
-print("done_spec:", env.done_spec)
+print("action_spec:", env.full_action_spec)
+print("reward_spec:", env.full_reward_spec)
+print("done_spec:", env.full_done_spec)
 print("observation_spec:", env.observation_spec)
 
-
 ######################################################################
 # Using the commands just shown we can access the domain of each value.
 # Doing this we can see that all specs apart from done have a leading shape ``(num_vmas_envs, n_agents)``.
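
For context, a hedged sketch of the spec inspection the patched tutorial performs. The VmasEnv construction below is illustrative (scenario name and sizes are assumptions, not part of this commit); the full_* properties return the composite specs, including the nested ("agents", ...) entries.

# Sketch only: assumes vmas is installed; scenario and sizes are illustrative.
from torchrl.envs.libs.vmas import VmasEnv

env = VmasEnv(scenario="navigation", num_envs=4, n_agents=3)

print("action_spec:", env.full_action_spec)      # nested under ("agents", "action")
print("reward_spec:", env.full_reward_spec)      # nested under ("agents", "reward")
print("done_spec:", env.full_done_spec)          # shared, no "agents" level
print("observation_spec:", env.observation_spec)
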
@@ -270,35 +269,20 @@
 # In fact, specs that have the additional agent dimension
 # (i.e., they vary for each agent) will be contained in a inner "agents" key.
 #
-# To access the full structure of the specs we can use
-#
-
-print("full_action_spec:", env.input_spec["full_action_spec"])
-print("full_reward_spec:", env.output_spec["full_reward_spec"])
-print("full_done_spec:", env.output_spec["full_done_spec"])
-
-######################################################################
 # As you can see the reward and action spec present the "agent" key,
 # meaning that entries in tensordicts belonging to those specs will be nested in an "agents" tensordict,
 # grouping all per-agent values.
 #
-# To quickly access the key for each of these values in tensordicts, we can simply ask the environment for the
-# respective key, and
+# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the
+# respective keys, and
 # we will immediately understand which are per-agent and which shared.
 # This info will be useful in order to tell all other TorchRL components where to find each value
 #
 
-print("action_key:", env.action_key)
-print("reward_key:", env.reward_key)
-print("done_key:", env.done_key)
-
-######################################################################
-# To tie it all together, we can see that passing these keys to the full specs gives us the leaf domains
-#
+print("action_keys:", env.action_keys)
+print("reward_keys:", env.reward_keys)
+print("done_keys:", env.done_keys)
 
-assert env.action_spec == env.input_spec["full_action_spec"][env.action_key]
-assert env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]
-assert env.done_spec == env.output_spec["full_done_spec"][env.done_key]
 
 ######################################################################
 # Transforms
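
A minimal follow-up sketch, reusing the assumed env from the previous sketch, of how these key lists reveal which entries are per-agent (nested under "agents") and which are shared:

# Sketch only: `env` is the assumed VmasEnv from the previous snippet.
for key in env.action_keys + env.reward_keys + env.done_keys:
    per_agent = isinstance(key, tuple) and key[0] == "agents"
    print(f"{key}: {'per-agent' if per_agent else 'shared'}")
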
@@ -615,6 +599,9 @@
     action=env.action_key,
     sample_log_prob=("agents", "sample_log_prob"),
     value=("agents", "state_value"),
+    # These last 2 keys will be expanded to match the reward shape
+    done=("agents", "done"),
+    terminated=("agents", "terminated"),
 )
 
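
A hedged sketch of the complete set_keys call this hunk patches; loss_module (the tutorial's ClipPPOLoss) and the reward argument are inferred from the surrounding tutorial code and are not shown in this hunk:

# Sketch only: `loss_module` and `env` come from the tutorial; `reward=` is assumed.
loss_module.set_keys(
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
    # These last 2 keys will be expanded to match the reward shape
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)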

@@ -649,11 +636,18 @@
 episode_reward_mean_list = []
 for tensordict_data in collector:
     tensordict_data.set(
-        ("next", "done"),
+        ("next", "agents", "done"),
         tensordict_data.get(("next", "done"))
         .unsqueeze(-1)
-        .expand(tensordict_data.get(("next", env.reward_key)).shape),
-    )  # We need to expand the done to match the reward shape (this is expected by the value estimator)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    tensordict_data.set(
+        ("next", "agents", "terminated"),
+        tensordict_data.get(("next", "terminated"))
+        .unsqueeze(-1)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    # We need to expand the done and terminated to match the reward shape (this is expected by the value estimator)
 
     with torch.no_grad():
         GAE(
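
A self-contained toy of the expansion performed inside the loop: a shared done flag of shape (n_envs, 1) is broadcast to the per-agent reward shape (n_envs, n_agents, 1) that the value estimator expects.

import torch

# Toy shapes standing in for the tutorial's (num_vmas_envs, n_agents, 1) reward.
n_envs, n_agents = 4, 3
done = torch.zeros(n_envs, 1, dtype=torch.bool)        # shared done, shape (4, 1)
reward_shape = (n_envs, n_agents, 1)                   # per-agent reward shape

agents_done = done.unsqueeze(-1).expand(reward_shape)  # (4, 1, 1) -> (4, 3, 1)
print(agents_done.shape)  # torch.Size([4, 3, 1])
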
@@ -688,7 +682,7 @@
     collector.update_policy_weights_()
 
     # Logging
-    done = tensordict_data.get(("next", "done"))
+    done = tensordict_data.get(("next", "agents", "done"))
     episode_reward_mean = (
         tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item()
     )
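
A self-contained toy of the logging computation with the new per-agent done mask: boolean indexing keeps only the entries of finished episodes before averaging.

import torch

# Toy (n_envs=2, n_agents=3, 1) tensors standing in for the collected data.
episode_reward = torch.tensor([[[1.0], [2.0], [3.0]],
                               [[4.0], [5.0], [6.0]]])
done = torch.tensor([[[True], [True], [True]],
                     [[False], [False], [False]]])  # same shape as the reward

episode_reward_mean = episode_reward[done].mean().item()
print(episode_reward_mean)  # 2.0 -- only the finished first env contributes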

0 commit comments