Skip to content

Commit 594462d

Browse files
authored
[Feature] Add Stack transform (#2567)
1 parent 1cffffe commit 594462d

File tree

7 files changed

+764
-5
lines changed

7 files changed

+764
-5
lines changed

.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ conda deactivate && conda activate ./env
2323
python -c "import mlagents_envs"
2424

2525
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
26+
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity]
2627

2728
coverage combine
2829
coverage xml -i

test/mocking_classes.py

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import Optional
5+
from typing import Dict, List, Optional
66

77
import torch
88
import torch.nn as nn
@@ -24,7 +24,12 @@
2424
from torchrl.data.utils import consolidate_spec
2525
from torchrl.envs.common import EnvBase
2626
from torchrl.envs.model_based.common import ModelBasedEnvBase
27-
from torchrl.envs.utils import _terminated_or_truncated
27+
from torchrl.envs.utils import (
28+
_terminated_or_truncated,
29+
check_marl_grouping,
30+
MarlGroupMapType,
31+
)
32+
2833

2934
spec_dict = {
3035
"bounded": Bounded,
@@ -1059,6 +1064,154 @@ def _step(
10591064
return tensordict
10601065

10611066

1067+
class MultiAgentCountingEnv(EnvBase):
1068+
"""A multi-agent env that is done after a given number of steps.
1069+
1070+
All agents have identical specs.
1071+
1072+
The count is incremented by 1 on each step.
1073+
1074+
"""
1075+
1076+
def __init__(
1077+
self,
1078+
n_agents: int,
1079+
group_map: MarlGroupMapType
1080+
| Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
1081+
max_steps: int = 5,
1082+
start_val: int = 0,
1083+
**kwargs,
1084+
):
1085+
super().__init__(**kwargs)
1086+
self.max_steps = max_steps
1087+
self.start_val = start_val
1088+
self.n_agents = n_agents
1089+
self.agent_names = [f"agent_{idx}" for idx in range(n_agents)]
1090+
1091+
if isinstance(group_map, MarlGroupMapType):
1092+
group_map = group_map.get_group_map(self.agent_names)
1093+
check_marl_grouping(group_map, self.agent_names)
1094+
1095+
self.group_map = group_map
1096+
1097+
observation_specs = {}
1098+
reward_specs = {}
1099+
done_specs = {}
1100+
action_specs = {}
1101+
1102+
for group_name, agents in group_map.items():
1103+
observation_specs[group_name] = {}
1104+
reward_specs[group_name] = {}
1105+
done_specs[group_name] = {}
1106+
action_specs[group_name] = {}
1107+
1108+
for agent_name in agents:
1109+
observation_specs[group_name][agent_name] = Composite(
1110+
observation=Unbounded(
1111+
(
1112+
*self.batch_size,
1113+
3,
1114+
4,
1115+
),
1116+
dtype=torch.float32,
1117+
device=self.device,
1118+
),
1119+
shape=self.batch_size,
1120+
device=self.device,
1121+
)
1122+
reward_specs[group_name][agent_name] = Composite(
1123+
reward=Unbounded(
1124+
(
1125+
*self.batch_size,
1126+
1,
1127+
),
1128+
device=self.device,
1129+
),
1130+
shape=self.batch_size,
1131+
device=self.device,
1132+
)
1133+
done_specs[group_name][agent_name] = Composite(
1134+
done=Categorical(
1135+
2,
1136+
dtype=torch.bool,
1137+
shape=(
1138+
*self.batch_size,
1139+
1,
1140+
),
1141+
device=self.device,
1142+
),
1143+
shape=self.batch_size,
1144+
device=self.device,
1145+
)
1146+
action_specs[group_name][agent_name] = Composite(
1147+
action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device),
1148+
shape=self.batch_size,
1149+
device=self.device,
1150+
)
1151+
1152+
self.observation_spec = Composite(observation_specs)
1153+
self.reward_spec = Composite(reward_specs)
1154+
self.done_spec = Composite(done_specs)
1155+
self.action_spec = Composite(action_specs)
1156+
self.register_buffer(
1157+
"count",
1158+
torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
1159+
)
1160+
1161+
def _set_seed(self, seed: Optional[int]):
1162+
torch.manual_seed(seed)
1163+
1164+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1165+
if tensordict is not None and "_reset" in tensordict.keys():
1166+
_reset = tensordict.get("_reset")
1167+
self.count[_reset] = self.start_val
1168+
else:
1169+
self.count[:] = self.start_val
1170+
1171+
source = {}
1172+
for group_name, agents in self.group_map.items():
1173+
source[group_name] = {}
1174+
for agent_name in agents:
1175+
source[group_name][agent_name] = TensorDict(
1176+
source={
1177+
"observation": torch.rand(
1178+
(*self.batch_size, 3, 4), device=self.device
1179+
),
1180+
"done": self.count > self.max_steps,
1181+
"terminated": self.count > self.max_steps,
1182+
},
1183+
batch_size=self.batch_size,
1184+
device=self.device,
1185+
)
1186+
1187+
tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
1188+
return tensordict
1189+
1190+
def _step(
1191+
self,
1192+
tensordict: TensorDictBase,
1193+
) -> TensorDictBase:
1194+
self.count += 1
1195+
source = {}
1196+
for group_name, agents in self.group_map.items():
1197+
source[group_name] = {}
1198+
for agent_name in agents:
1199+
source[group_name][agent_name] = TensorDict(
1200+
source={
1201+
"observation": torch.rand(
1202+
(*self.batch_size, 3, 4), device=self.device
1203+
),
1204+
"done": self.count > self.max_steps,
1205+
"terminated": self.count > self.max_steps,
1206+
"reward": torch.zeros_like(self.count, dtype=torch.float),
1207+
},
1208+
batch_size=self.batch_size,
1209+
device=self.device,
1210+
)
1211+
tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
1212+
return tensordict
1213+
1214+
10621215
class IncrementingEnv(CountingEnv):
10631216
# Same as CountingEnv but always increments the count by 1 regardless of the action.
10641217
def _step(

0 commit comments

Comments
 (0)