
Commit 6bebea7

[Feature] Vmas to device (#850)
1 parent 20b6fc9

2 files changed: +77 -6 lines

test/test_libs.py

Lines changed: 62 additions & 2 deletions
@@ -8,15 +8,14 @@
 import numpy as np
 import pytest
 import torch
-
 from _utils_internal import (
     get_available_devices,
     HALFCHEETAH_VERSIONED,
     PENDULUM_VERSIONED,
     PONG_VERSIONED,
 )
 from packaging import version
-from tensordict.tensordict import assert_allclose_td
+from tensordict.tensordict import assert_allclose_td, TensorDict
 from torchrl._utils import implement_for
 from torchrl.collectors import MultiaSyncDataCollector
 from torchrl.collectors.collectors import RandomPolicy
@@ -685,6 +684,67 @@ def make_vmas():
             [n_workers, list(env.n_agents)[0], list(env.num_envs)[0], n_rollout_samples]
         )
 
+    @pytest.mark.parametrize("num_envs", [1, 10])
+    @pytest.mark.parametrize("n_workers", [1, 3])
+    def test_vmas_reset(
+        self,
+        scenario_name,
+        num_envs,
+        n_workers,
+        n_agents=5,
+        n_rollout_samples=3,
+        max_steps=3,
+    ):
+        def make_vmas():
+            env = VmasEnv(
+                scenario_name=scenario_name,
+                num_envs=num_envs,
+                n_agents=n_agents,
+                max_steps=max_steps,
+            )
+            env.set_seed(0)
+            return env
+
+        env = ParallelEnv(n_workers, make_vmas)
+        tensordict = env.rollout(max_steps=n_rollout_samples)
+
+        assert tensordict["done"].squeeze(-1)[..., -1].all()
+
+        _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
+        while not _reset.any():
+            _reset = torch.randint(low=0, high=2, size=env.batch_size, dtype=torch.bool)
+
+        tensordict = env.reset(
+            TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
+        )
+        assert tensordict["done"][_reset].all().item() is False
+        # vmas resets the whole agent dimension if even one agent needs resetting,
+        # so here we check that where no agent was reset, all agents are still done
+        assert tensordict["done"].all(dim=1)[~_reset.any(dim=1)].all().item() is True
+
+    @pytest.mark.skipif(len(get_available_devices()) < 2, reason="not enough devices")
+    @pytest.mark.parametrize("first", [0, 1])
+    def test_to_device(self, scenario_name: str, first: int):
+        devices = get_available_devices()
+
+        def make_vmas():
+            env = VmasEnv(
+                scenario_name=scenario_name,
+                num_envs=7,
+                n_agents=3,
+                seed=0,
+                device=devices[first],
+            )
+            return env
+
+        env = ParallelEnv(3, make_vmas)
+
+        assert env.rollout(max_steps=3).device == devices[first]
+
+        env.to(devices[1 - first])
+
+        assert env.rollout(max_steps=3).device == devices[1 - first]
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
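For reference, a minimal sketch of the partial-reset path that test_vmas_reset exercises, run against a bare VmasEnv rather than a ParallelEnv. The scenario name and sizes are illustrative assumptions, not part of this commit:

import torch

from tensordict.tensordict import TensorDict
from torchrl.envs.libs.vmas import VmasEnv

# Illustrative scenario name and sizes, chosen only for this sketch.
env = VmasEnv(scenario_name="balance", num_envs=4, n_agents=3, max_steps=5)
env.set_seed(0)
env.rollout(max_steps=5)

# Build a boolean "_reset" mask over the env batch and reset only the masked entries.
_reset = torch.zeros(env.batch_size, dtype=torch.bool, device=env.device)
_reset[..., 0] = True  # reset the first vectorized sub-environment for every agent
td = env.reset(
    TensorDict({"_reset": _reset}, batch_size=env.batch_size, device=env.device)
)
assert not td["done"][_reset].any()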

torchrl/envs/libs/vmas.py

Lines changed: 15 additions & 4 deletions
@@ -2,9 +2,8 @@
 
 import torch
 from tensordict.tensordict import TensorDict, TensorDictBase
-
-from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
-from torchrl.envs.common import _EnvWrapper
+from torchrl.data import CompositeSpec, DEVICE_TYPING, UnboundedContinuousTensorSpec
+from torchrl.envs.common import _EnvWrapper, EnvBase
 from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
 from torchrl.envs.utils import _selective_unsqueeze
 
@@ -210,17 +209,23 @@ def _reset(
         self, tensordict: Optional[TensorDictBase] = None, **kwargs
     ) -> TensorDictBase:
         if tensordict is not None and "_reset" in tensordict.keys():
-            envs_to_reset = tensordict.get("_reset").any(dim=0)
+            _reset = tensordict.get("_reset")
+            envs_to_reset = _reset.any(dim=0)
             for env_index, to_reset in enumerate(envs_to_reset):
                 if to_reset:
                     self._env.reset_at(env_index)
+            done = _selective_unsqueeze(self._env.done(), batch_size=(self.num_envs,))
             obs = []
             infos = []
+            dones = []
             for agent in self.agents:
                 obs.append(self.scenario.observation(agent))
                 infos.append(self.scenario.info(agent))
+                dones.append(done.clone())
+
         else:
             obs, infos = self._env.reset(return_info=True)
+            dones = None
 
         agent_tds = []
         for i in range(self.n_agents):
@@ -237,6 +242,8 @@ def _reset(
 
             if infos is not None:
                 agent_td.set("info", agent_info)
+            if dones is not None:
+                agent_td.set("done", dones[i])
             agent_tds.append(agent_td)
 
         tensordict_out = torch.stack(agent_tds, dim=0)
@@ -324,6 +331,10 @@ def __repr__(self) -> str:
             f" batch_size={self.batch_size}, device={self.device})"
         )
 
+    def to(self, device: DEVICE_TYPING) -> EnvBase:
+        self._env.to(device)
+        return super().to(device)
+
 
 class VmasEnv(VmasWrapper):
     """Vmas environment wrapper.
