#
#

-print("action_spec:", env.action_spec)
-print("reward_spec:", env.reward_spec)
-print("done_spec:", env.done_spec)
+print("action_spec:", env.full_action_spec)
+print("reward_spec:", env.full_reward_spec)
+print("done_spec:", env.full_done_spec)
print("observation_spec:", env.observation_spec)

-
######################################################################
# Using the commands just shown we can access the domain of each value.
# Doing this we can see that all specs apart from done have a leading shape ``(num_vmas_envs, n_agents)``.

# In fact, specs that have the additional agent dimension
# (i.e., they vary for each agent) will be contained in an inner "agents" key.
#
-# To access the full structure of the specs we can use
-#
-
-print("full_action_spec:", env.input_spec["full_action_spec"])
-print("full_reward_spec:", env.output_spec["full_reward_spec"])
-print("full_done_spec:", env.output_spec["full_done_spec"])
-
-######################################################################
# As you can see, the reward and action specs present the "agents" key,
# meaning that entries in tensordicts belonging to those specs will be nested in an "agents" tensordict,
# grouping all per-agent values.
#
-# To quickly access the key for each of these values in tensordicts, we can simply ask the environment for the
-# respective key, and
+# To quickly access the keys for each of these values in tensordicts, we can simply ask the environment for the
+# respective keys, and
# we will immediately understand which are per-agent and which are shared.
# This info will be useful in order to tell all other TorchRL components where to find each value.
#

-print("action_key:", env.action_key)
-print("reward_key:", env.reward_key)
-print("done_key:", env.done_key)
-
-######################################################################
-# To tie it all together, we can see that passing these keys to the full specs gives us the leaf domains
-#
+print("action_keys:", env.action_keys)
+print("reward_keys:", env.reward_keys)
+print("done_keys:", env.done_keys)

-assert env.action_spec == env.input_spec["full_action_spec"][env.action_key]
-assert env.reward_spec == env.output_spec["full_reward_spec"][env.reward_key]
-assert env.done_spec == env.output_spec["full_done_spec"][env.done_key]

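A quick illustration of how the plural accessors fit together (not part of the patch; it assumes the VMAS environment built earlier in the tutorial, with per-agent entries grouped under the "agents" key):

for key in env.action_keys:
    # nested keys such as ("agents", "action") index straight into the composite spec
    print("action key:", key, "-> leaf spec shape:", env.full_action_spec[key].shape)
for key in env.reward_keys:
    print("reward key:", key, "-> leaf spec shape:", env.full_reward_spec[key].shape)
for key in env.done_keys:
    print("done key:", key, "-> leaf spec shape:", env.full_done_spec[key].shape)
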
######################################################################
# Transforms

    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
+    # These last 2 keys will be expanded to match the reward shape
+    done=("agents", "done"),
+    terminated=("agents", "terminated"),
)

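For context, the three added keys live inside the loss module's key mapping; a plausible reconstruction of the full call is sketched below (the loss_module name and the reward entry come from the surrounding tutorial code, which this diff does not show):

loss_module.set_keys(  # tell the loss where to find each entry in the tensordict
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
    # These last 2 keys will be expanded to match the reward shape
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)
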
episode_reward_mean_list = []
for tensordict_data in collector:
    tensordict_data.set(
-        ("next", "done"),
+        ("next", "agents", "done"),
        tensordict_data.get(("next", "done"))
        .unsqueeze(-1)
-        .expand(tensordict_data.get(("next", env.reward_key)).shape),
-    )  # We need to expand the done to match the reward shape (this is expected by the value estimator)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    tensordict_data.set(
+        ("next", "agents", "terminated"),
+        tensordict_data.get(("next", "terminated"))
+        .unsqueeze(-1)
+        .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
+    )
+    # We need to expand the done and terminated to match the reward shape (this is expected by the value estimator)

    with torch.no_grad():
        GAE(

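The reason for the expansion: the environment reports a single done/terminated flag per vectorized environment, while rewards carry an extra agent dimension, and the value estimator expects both to share the same shape. A self-contained sketch with made-up shapes (60 envs, 100 frames, 3 agents):

import torch

# Hypothetical shapes for illustration only.
done = torch.zeros(60, 100, 1, dtype=torch.bool)  # (num_envs, frames, 1)
reward = torch.zeros(60, 100, 3, 1)               # (num_envs, frames, n_agents, 1)

# Same recipe as the diff: append a trailing singleton axis, then broadcast the
# per-env flag across the agent dimension so it matches the reward shape.
agents_done = done.unsqueeze(-1).expand(reward.shape)
print(agents_done.shape)  # torch.Size([60, 100, 3, 1])
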
    collector.update_policy_weights_()

    # Logging
-    done = tensordict_data.get(("next", "done"))
+    done = tensordict_data.get(("next", "agents", "done"))
    episode_reward_mean = (
        tensordict_data.get(("next", "agents", "episode_reward"))[done].mean().item()
    )
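The logging line works because the expanded per-agent done mask now shares the episode_reward shape, so boolean indexing keeps only the values recorded when an episode finished. A tiny sketch with the same made-up shapes as above:

import torch

episode_reward = torch.rand(60, 100, 3, 1)            # (num_envs, frames, n_agents, 1), hypothetical
done = torch.zeros(60, 100, 3, 1, dtype=torch.bool)
done[:, -1] = True                                    # pretend every episode ends on the last frame

# Boolean indexing returns a flat tensor of the selected rewards; .mean() is the logged scalar.
episode_reward_mean = episode_reward[done].mean().item()
print(episode_reward_mean)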
|