Skip to content

Commit a43612a

Browse files
MarkHaoxiang and vmoens authored
[Feature] CNN version of MultiAgentMLP (#1479)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 3d2c161 commit a43612a

File tree

5 files changed

+283
-6
lines changed

5 files changed

+283
-6
lines changed

docs/source/reference/modules.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ multi-agent contexts.
350350
:template: rl_template_noinherit.rst
351351

352352
MultiAgentMLP
353+
MultiAgentConvNet
353354
QMixer
354355
VDNMixer
355356

test/test_modules.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
CEMPlanner,
1919
DTActor,
2020
LSTMNet,
21+
MultiAgentConvNet,
2122
MultiAgentMLP,
2223
OnlineDTActor,
2324
QMixer,
@@ -916,6 +917,58 @@ def test_mlp(
916917
# same input different output
917918
assert not torch.allclose(out[..., i, :], out[..., j, :])
918919

920+
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
def test_cnn(
    self, n_agents, centralised, share_params, batch, x=50, y=50, channels=3
):
    """Check output shape and agent coupling of ``MultiAgentConvNet``.

    Three properties are verified for every parametrization:

    1. The output keeps the ``(*batch, n_agents)`` leading dimensions, and
       agents only produce identical outputs when the network is both
       centralised and parameter-shared.
    2. Perturbing agent 0's observation changes every agent's output when
       centralised, and leaves the other agents untouched otherwise.
    3. When all agents receive the same observation, outputs match if and
       only if parameters are shared.
    """
    # Seed first: the order of the random draws below must stay fixed.
    torch.manual_seed(0)
    cnn = MultiAgentConvNet(
        n_agents=n_agents, centralised=centralised, share_params=share_params
    )
    agents_td = TensorDict(
        {"observation": torch.randn(*batch, n_agents, channels, x, y)},
        [*batch, n_agents],
    )
    td = TensorDict({"agents": agents_td}, batch_size=batch)
    obs = td[("agents", "observation")]

    # 1. distinct observations per agent
    out = cnn(obs)
    assert out.shape[:-1] == (*batch, n_agents)
    identical_expected = centralised and share_params
    for i in range(n_agents):
        if identical_expected:
            assert torch.allclose(out[..., i, :], out[..., 0, :])
            continue
        for j in range(i + 1, n_agents):
            assert not torch.allclose(out[..., i, :], out[..., j, :])

    # 2. perturb one pixel of agent 0's observation (in place)
    obs[..., 0, 0, 0, 0] += 1
    out2 = cnn(obs)
    for i in range(n_agents):
        before, after = out[..., i, :], out2[..., i, :]
        if centralised:
            # a modification to the input of agent 0 will impact all agents
            assert not torch.allclose(before, after)
        elif i > 0:
            assert torch.allclose(before, after)

    # 3. every agent sees the very same observation
    shared_obs = torch.randn(*batch, 1, channels, x, y)
    obs = shared_obs.expand(*batch, n_agents, channels, x, y)
    out = cnn(obs)
    for i in range(n_agents):
        if share_params:
            # same input same output
            assert torch.allclose(out[..., i, :], out[..., 0, :])
            continue
        for j in range(i + 1, n_agents):
            # same input different output
            assert not torch.allclose(out[..., i, :], out[..., j, :])
971+
919972
@pytest.mark.parametrize("n_agents", [1, 3])
920973
@pytest.mark.parametrize(
921974
"batch",

torchrl/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DuelingCnnDQNet,
3030
LSTMNet,
3131
MLP,
32+
MultiAgentConvNet,
3233
MultiAgentMLP,
3334
NoisyLazyLinear,
3435
NoisyLinear,

torchrl/modules/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@
2222
MLP,
2323
OnlineDTActor,
2424
)
25-
from .multiagent import MultiAgentMLP, QMixer, VDNMixer
25+
from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer
2626
from .utils import Squeeze2dLayer, SqueezeLayer

torchrl/modules/models/multiagent.py

Lines changed: 227 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ...data import DEVICE_TYPING
1414

15-
from .models import MLP
15+
from .models import ConvNet, MLP
1616

1717

1818
class MultiAgentMLP(nn.Module):
@@ -215,10 +215,10 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
215215
if self.centralised:
216216
# If the parameters are shared, and it is centralised, all agents will have the same output
217217
# We expand it to maintain the agent dimension, but values will be the same for all agents
218-
output = (
219-
output.view(*output.shape[:-1], self.n_agent_outputs)
220-
.unsqueeze(-2)
221-
.expand(*output.shape[:-1], self.n_agents, self.n_agent_outputs)
218+
output = output.view(*output.shape[:-1], self.n_agent_outputs)
219+
output = output.unsqueeze(-2)
220+
output = output.expand(
221+
*output.shape[:-2], self.n_agents, self.n_agent_outputs
222222
)
223223

224224
if output.shape[-2:] != (self.n_agents, self.n_agent_outputs):
@@ -230,6 +230,228 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
230230
return output
231231

232232

233+
class MultiAgentConvNet(nn.Module):
    """Multi-agent CNN.

    In MARL settings, agents may or may not share the same policy for their
    actions: we say that the parameters can be shared or not. Similarly, a
    network may take the entire observation space (across agents) or work on a
    per-agent basis to compute its output, which we refer to as "centralised"
    and "non-centralised", respectively.

    It expects inputs with shape ``(*B, n_agents, channels, x, y)`` and returns
    a tensor whose leading dimensions are ``(*B, n_agents)``.

    Args:
        n_agents (int): number of agents.
        centralised (bool): If ``True``, each agent will use the inputs of all
            agents to compute its output, resulting in input of shape
            ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will
            only use its data as input.
        share_params (bool): If ``True``, the same
            :class:`~torchrl.modules.ConvNet` will be used to make the forward
            pass for all agents (homogeneous policies). Otherwise, each agent
            will use a different :class:`~torchrl.modules.ConvNet` to process
            its input (heterogeneous policies).
        device (str or torch.device, optional): device to create the module on.
        num_cells (int or Sequence[int], optional): number of cells of every
            layer in between the input and output. If an integer is provided,
            every layer will have the same number of cells. If an iterable is
            provided, the layers' output sizes will match the content of
            ``num_cells``. Forwarded as-is to the underlying ConvNet.
        kernel_sizes (int, Sequence[Union[int, Sequence[int]]]): Kernel size(s)
            of the convolutional network.
            Defaults to ``5``.
        strides (int or Sequence[int]): Stride(s) of the convolutional network.
            If iterable, the length must match the depth, defined by the
            num_cells or depth arguments.
            Defaults to ``2``.
        paddings (int or Sequence[int]): Padding size(s) of the convolutional
            network. Forwarded as-is to the underlying ConvNet.
            Defaults to ``0``.
        activation_class (Type[nn.Module]): activation class to be used.
            Defaults to :class:`torch.nn.ELU`.
        **kwargs: additional keyword arguments forwarded to
            :class:`~torchrl.modules.models.ConvNet` to customize it.

    Examples:
        >>> import torch
        >>> from torchrl.modules import MultiAgentConvNet
        >>> batch = (3,2)
        >>> n_agents = 7
        >>> channels, x, y = 3, 100, 100
        >>> obs = torch.randn(*batch, n_agents, channels, x, y)
        >>> # First let's consider a centralised network with shared parameters.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralised = True,
        ...     share_params = True
        ... )
        >>> # Before the first forward pass the input layer is still lazy.
        >>> print(cnn)
        MultiAgentConvNet(
          (agent_networks): ModuleList(
            (0): ConvNet(
              (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
              (1): ELU(alpha=1.0)
              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (3): ELU(alpha=1.0)
              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (5): ELU(alpha=1.0)
              (6): SquashDims()
            )
          )
        )
        >>> result = cnn(obs)
        >>> # The final dimension of the resulting tensor is determined by the layer definition arguments and the shape of input 'obs'.
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (e.g. for a value function)
        >>> print(all(result[0,0,0] == result[0,0,1]))
        True

        >>> # Alternatively, a local network with parameter sharing (e.g. decentralised weight sharing policy)
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralised = False,
        ...     share_params = True
        ... )
        >>> # Running the network once materializes the lazy input layer.
        >>> result = cnn(obs)
        >>> print(cnn)
        MultiAgentConvNet(
          (agent_networks): ModuleList(
            (0): ConvNet(
              (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(2, 2))
              (1): ELU(alpha=1.0)
              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (3): ELU(alpha=1.0)
              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (5): ELU(alpha=1.0)
              (6): SquashDims()
            )
          )
        )
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> # Parameters are shared but not observations, hence each agent has a different output.
        >>> print(all(result[0,0,0] == result[0,0,1]))
        False

        >>> # Or multiple local networks identical in structure but with differing weights.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralised = False,
        ...     share_params = False
        ... )
        >>> result = cnn(obs)
        >>> print(cnn)
        MultiAgentConvNet(
          (agent_networks): ModuleList(
            (0-6): 7 x ConvNet(
              (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(2, 2))
              (1): ELU(alpha=1.0)
              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (3): ELU(alpha=1.0)
              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (5): ELU(alpha=1.0)
              (6): SquashDims()
            )
          )
        )
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> print(all(result[0,0,0] == result[0,0,1]))
        False

        >>> # Or where inputs are shared but not parameters.
        >>> cnn = MultiAgentConvNet(
        ...     n_agents,
        ...     centralised = True,
        ...     share_params = False
        ... )
        >>> result = cnn(obs)
        >>> print(cnn)
        MultiAgentConvNet(
          (agent_networks): ModuleList(
            (0-6): 7 x ConvNet(
              (0): Conv2d(21, 32, kernel_size=(5, 5), stride=(2, 2))
              (1): ELU(alpha=1.0)
              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (3): ELU(alpha=1.0)
              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
              (5): ELU(alpha=1.0)
              (6): SquashDims()
            )
          )
        )
        >>> print(result.shape)
        torch.Size([3, 2, 7, 2592])
        >>> print(all(result[0,0,0] == result[0,0,1]))
        False
    """

    def __init__(
        self,
        n_agents: int,
        centralised: bool,
        share_params: bool,
        device: Optional[DEVICE_TYPING] = None,
        num_cells: Optional[Sequence[int]] = None,
        kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5,
        strides: Union[Sequence, int] = 2,
        paddings: Union[Sequence, int] = 0,
        activation_class: Type[nn.Module] = nn.ELU,
        **kwargs,
    ):
        super().__init__()

        self.n_agents = n_agents
        self.centralised = centralised
        self.share_params = share_params

        # One network in total when parameters are shared, one per agent otherwise.
        n_networks = 1 if self.share_params else self.n_agents
        self.agent_networks = nn.ModuleList(
            [
                ConvNet(
                    num_cells=num_cells,
                    kernel_sizes=kernel_sizes,
                    strides=strides,
                    paddings=paddings,
                    activation_class=activation_class,
                    device=device,
                    **kwargs,
                )
                for _ in range(n_networks)
            ]
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the per-agent (or shared) ConvNet(s) on ``inputs``.

        Args:
            inputs (torch.Tensor): tensor of shape
                ``(*B, n_agents, channels, x, y)``.

        Returns:
            torch.Tensor: tensor with the agent dimension second to last,
            i.e. ``(*B, n_agents, features)`` assuming each ConvNet returns
            a flat feature vector per sample.

        Raises:
            ValueError: if ``inputs`` has fewer than 4 dimensions or if its
                agent dimension (``-4``) does not equal ``n_agents``.
        """
        if inputs.dim() < 4:
            raise ValueError(
                "Multi-agent network expects (*batch_size, agent_index, channels, x, y)"
            )
        if inputs.shape[-4] != self.n_agents:
            raise ValueError(
                f"Multi-agent network expects {self.n_agents} agents but got {inputs.shape[-4]}"
            )

        # If the model is centralised, agents have full observability: the
        # agent and channel dimensions are merged into a single channel
        # dimension of size n_agents * channels.
        if self.centralised:
            inputs = inputs.flatten(-4, -3)

        if self.share_params:
            # Single network. In the non-centralised case the agent dimension
            # is simply handed to the ConvNet as one more leading dimension.
            output = self.agent_networks[0](inputs)
            if self.centralised:
                # If the parameters are shared, and it is centralised, all agents will have the same output.
                # We expand it to restore the agent dimension, but values will be the same for all agents.
                output = output.unsqueeze(-2)
                output = output.expand(
                    *output.shape[:-2], self.n_agents, output.shape[-1]
                )
            return output

        # If the parameters are not shared, each agent has its own network.
        if self.centralised:
            # Every network sees the same merged observation.
            per_agent = [net(inputs) for net in self.agent_networks]
        else:
            # Each network only sees the slice belonging to its agent.
            per_agent = [
                net(agent_input)
                for net, agent_input in zip(self.agent_networks, inputs.unbind(-4))
            ]
        return torch.stack(per_agent, dim=-2)
453+
454+
233455
class Mixer(nn.Module):
234456
"""A multi-agent value mixer.
235457

0 commit comments

Comments
 (0)