@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Optional
+from typing import Dict, List, Optional

 import torch
 import torch.nn as nn
@@ -24,7 +24,12 @@
 from torchrl.data.utils import consolidate_spec
 from torchrl.envs.common import EnvBase
 from torchrl.envs.model_based.common import ModelBasedEnvBase
-from torchrl.envs.utils import _terminated_or_truncated
+from torchrl.envs.utils import (
+    _terminated_or_truncated,
+    check_marl_grouping,
+    MarlGroupMapType,
+)
+

 spec_dict = {
     "bounded": Bounded,
@@ -1059,6 +1064,154 @@ def _step(
         return tensordict


+class MultiAgentCountingEnv(EnvBase):
+    """A multi-agent env that is done after a given number of steps.
+
+    All agents have identical specs.
+
+    The count is incremented by 1 on each step.
+
+    """
+
+    def __init__(
+        self,
+        n_agents: int,
+        group_map: MarlGroupMapType
+        | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
+        max_steps: int = 5,
+        start_val: int = 0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_steps = max_steps
+        self.start_val = start_val
+        self.n_agents = n_agents
+        self.agent_names = [f"agent_{idx}" for idx in range(n_agents)]
+
+        if isinstance(group_map, MarlGroupMapType):
+            group_map = group_map.get_group_map(self.agent_names)
+        check_marl_grouping(group_map, self.agent_names)
+
+        self.group_map = group_map
+
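+        # Build nested {group: {agent: spec}} dictionaries; every agent gets identical specs.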
+        observation_specs = {}
+        reward_specs = {}
+        done_specs = {}
+        action_specs = {}
+
+        for group_name, agents in group_map.items():
+            observation_specs[group_name] = {}
+            reward_specs[group_name] = {}
+            done_specs[group_name] = {}
+            action_specs[group_name] = {}
+
+            for agent_name in agents:
+                observation_specs[group_name][agent_name] = Composite(
+                    observation=Unbounded(
+                        (
+                            *self.batch_size,
+                            3,
+                            4,
+                        ),
+                        dtype=torch.float32,
+                        device=self.device,
+                    ),
+                    shape=self.batch_size,
+                    device=self.device,
+                )
+                reward_specs[group_name][agent_name] = Composite(
+                    reward=Unbounded(
+                        (
+                            *self.batch_size,
+                            1,
+                        ),
+                        device=self.device,
+                    ),
+                    shape=self.batch_size,
+                    device=self.device,
+                )
+                done_specs[group_name][agent_name] = Composite(
+                    done=Categorical(
+                        2,
+                        dtype=torch.bool,
+                        shape=(
+                            *self.batch_size,
+                            1,
+                        ),
+                        device=self.device,
+                    ),
+                    shape=self.batch_size,
+                    device=self.device,
+                )
+                action_specs[group_name][agent_name] = Composite(
+                    action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device),
+                    shape=self.batch_size,
+                    device=self.device,
+                )
+
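+        # Aggregate the per-group specs into top-level Composites and register the step counter.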
+        self.observation_spec = Composite(observation_specs)
+        self.reward_spec = Composite(reward_specs)
+        self.done_spec = Composite(done_specs)
+        self.action_spec = Composite(action_specs)
+        self.register_buffer(
+            "count",
+            torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
+        )
+
+    def _set_seed(self, seed: Optional[int]):
+        torch.manual_seed(seed)
+
+    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
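+        # A "_reset" mask requests a partial reset: only the flagged batch entries restart.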
+        if tensordict is not None and "_reset" in tensordict.keys():
+            _reset = tensordict.get("_reset")
+            self.count[_reset] = self.start_val
+        else:
+            self.count[:] = self.start_val
+
+        source = {}
+        for group_name, agents in self.group_map.items():
+            source[group_name] = {}
+            for agent_name in agents:
+                source[group_name][agent_name] = TensorDict(
+                    source={
+                        "observation": torch.rand(
+                            (*self.batch_size, 3, 4), device=self.device
+                        ),
+                        "done": self.count > self.max_steps,
+                        "terminated": self.count > self.max_steps,
+                    },
+                    batch_size=self.batch_size,
+                    device=self.device,
+                )
+
+        tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
+        return tensordict
+
+    def _step(
+        self,
+        tensordict: TensorDictBase,
+    ) -> TensorDictBase:
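+        # Every step increments the shared count; agents report done once it exceeds max_steps.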
+        self.count += 1
+        source = {}
+        for group_name, agents in self.group_map.items():
+            source[group_name] = {}
+            for agent_name in agents:
+                source[group_name][agent_name] = TensorDict(
+                    source={
+                        "observation": torch.rand(
+                            (*self.batch_size, 3, 4), device=self.device
+                        ),
+                        "done": self.count > self.max_steps,
+                        "terminated": self.count > self.max_steps,
+                        "reward": torch.zeros_like(self.count, dtype=torch.float),
+                    },
+                    batch_size=self.batch_size,
+                    device=self.device,
+                )
+        tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
+        return tensordict
+
+
 class IncrementingEnv(CountingEnv):
     # Same as CountingEnv but always increments the count by 1 regardless of the action.
     def _step(
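Review note: a minimal smoke test for the new env might look like the sketch below. It assumes MultiAgentCountingEnv is importable from torchrl's test mocking classes; check_env_specs is torchrl's spec/rollout consistency checker, and the custom group_map dict at the end is illustrative.

    from torchrl.envs.utils import check_env_specs

    # Default grouping: MarlGroupMapType.ALL_IN_ONE_GROUP puts every agent in a
    # single "agents" group, so keys look like ("agents", "agent_0", "observation").
    env = MultiAgentCountingEnv(n_agents=3)
    check_env_specs(env)  # asserts that reset/step outputs match the declared specs

    # Random-action rollout; done flips once the count exceeds max_steps (default 5).
    rollout = env.rollout(max_steps=10)

    # A custom grouping can be passed as a {group_name: [agent_names]} dict,
    # validated by check_marl_grouping in __init__.
    env = MultiAgentCountingEnv(
        n_agents=3, group_map={"team_a": ["agent_0", "agent_1"], "team_b": ["agent_2"]}
    )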