|
8 | 8 | import numpy as np
|
9 | 9 | import pytest
|
10 | 10 | import torch
|
11 |
| - |
12 | 11 | from _utils_internal import (
|
13 | 12 | get_available_devices,
|
14 | 13 | HALFCHEETAH_VERSIONED,
|
15 | 14 | PENDULUM_VERSIONED,
|
16 | 15 | PONG_VERSIONED,
|
17 | 16 | )
|
18 | 17 | from packaging import version
|
19 |
| -from tensordict.tensordict import assert_allclose_td |
| 18 | +from tensordict.tensordict import assert_allclose_td, TensorDict |
20 | 19 | from torchrl._utils import implement_for
|
21 | 20 | from torchrl.collectors import MultiaSyncDataCollector
|
22 | 21 | from torchrl.collectors.collectors import RandomPolicy
|
@@ -685,6 +684,67 @@ def make_vmas():
|
685 | 684 | [n_workers, list(env.n_agents)[0], list(env.num_envs)[0], n_rollout_samples]
|
686 | 685 | )
|
687 | 686 |
|
| 687 | + @pytest.mark.parametrize("num_envs", [1, 10]) |
| 688 | + @pytest.mark.parametrize("n_workers", [1, 3]) |
| 689 | + def test_vmas_reset( |
| 690 | + self, |
| 691 | + scenario_name, |
| 692 | + num_envs, |
| 693 | + n_workers, |
| 694 | + n_agents=5, |
| 695 | + n_rollout_samples=3, |
| 696 | + max_steps=3, |
| 697 | + ): |
| 698 | + def make_vmas(): |
| 699 | + env = VmasEnv( |
| 700 | + scenario_name=scenario_name, |
| 701 | + num_envs=num_envs, |
| 702 | + n_agents=n_agents, |
| 703 | + max_steps=max_steps, |
| 704 | + ) |
| 705 | + env.set_seed(0) |
| 706 | + return env |
| 707 | + |
| 708 | + env = ParallelEnv(n_workers, make_vmas) |
| 709 | + tensordict = env.rollout(max_steps=n_rollout_samples) |
| 710 | + |
| 711 | + assert tensordict["done"].squeeze(-1)[..., -1].all() |
| 712 | + |
| 713 | + _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool) |
| 714 | + while not _reset.any(): |
| 715 | + _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool) |
| 716 | + |
| 717 | + tensordict = env.reset( |
| 718 | + TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device) |
| 719 | + ) |
| 720 | + assert tensordict["done"][_reset].all().item() is False |
| 721 | + # vmas resets all the agent dimension if only one of the agents needs resetting |
| 722 | + # thus, here we check that where we did not reset any agent, all agents are still done |
| 723 | + assert tensordict["done"].all(dim=1)[~_reset.any(dim=1)].all().item() is True |
| 724 | + |
| 725 | + @pytest.mark.skipif(len(get_available_devices()) < 2, reason="not enough devices") |
| 726 | + @pytest.mark.parametrize("first", [0, 1]) |
| 727 | + def test_to_device(self, scenario_name: str, first: int): |
| 728 | + devices = get_available_devices() |
| 729 | + |
| 730 | + def make_vmas(): |
| 731 | + env = VmasEnv( |
| 732 | + scenario_name=scenario_name, |
| 733 | + num_envs=7, |
| 734 | + n_agents=3, |
| 735 | + seed=0, |
| 736 | + device=devices[first], |
| 737 | + ) |
| 738 | + return env |
| 739 | + |
| 740 | + env = ParallelEnv(3, make_vmas) |
| 741 | + |
| 742 | + assert env.rollout(max_steps=3).device == devices[first] |
| 743 | + |
| 744 | + env.to(devices[1 - first]) |
| 745 | + |
| 746 | + assert env.rollout(max_steps=3).device == devices[1 - first] |
| 747 | + |
688 | 748 |
|
689 | 749 | if __name__ == "__main__":
|
690 | 750 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
|
0 commit comments