from ...data import DEVICE_TYPING

-from .models import MLP
+from .models import ConvNet, MLP


class MultiAgentMLP(nn.Module):
@@ -215,10 +215,10 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
            if self.centralised:
                # If the parameters are shared, and it is centralised, all agents will have the same output
                # We expand it to maintain the agent dimension, but values will be the same for all agents
-                output = (
-                    output.view(*output.shape[:-1], self.n_agent_outputs)
-                    .unsqueeze(-2)
-                    .expand(*output.shape[:-1], self.n_agents, self.n_agent_outputs)
+                output = output.view(*output.shape[:-1], self.n_agent_outputs)
+                output = output.unsqueeze(-2)
+                output = output.expand(
+                    *output.shape[:-2], self.n_agents, self.n_agent_outputs
                )

        if output.shape[-2:] != (self.n_agents, self.n_agent_outputs):
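The hunk above is behaviour-preserving: once ``output`` is rebound after ``unsqueeze(-2)``, the batch dimensions are everything except the last two, hence ``shape[:-2]`` where the chained version read ``shape[:-1]`` off the not-yet-rebound tensor. A minimal sketch of this shape bookkeeping, with made-up sizes that are not part of the PR:

import torch

B = (3, 2)                       # arbitrary batch dims, made up for this sketch
n_agents, n_agent_outputs = 4, 5

output = torch.randn(*B, n_agent_outputs)                   # (*B, n_agent_outputs)
output = output.view(*output.shape[:-1], n_agent_outputs)   # shape unchanged
output = output.unsqueeze(-2)                                # (*B, 1, n_agent_outputs)
# After the unsqueeze, the batch dims are output.shape[:-2]
output = output.expand(*output.shape[:-2], n_agents, n_agent_outputs)
assert output.shape == (*B, n_agents, n_agent_outputs)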
@@ -230,6 +230,228 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
        return output


+class MultiAgentConvNet(nn.Module):
+    """Multi-agent CNN.
+
+    In MARL settings, agents may or may not share the same policy for their actions: we say that the parameters can be shared or not. Similarly, a network may use the entire observation space (across agents) or only its own per-agent observation to compute its output, which we refer to as "centralised" and "non-centralised", respectively.
+
+    It expects inputs with shape ``(*B, n_agents, channels, x, y)``.
+
+    Args:
+        n_agents (int): number of agents.
+        centralised (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in an input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its own data as input.
+        share_params (bool): If ``True``, the same :class:`~torchrl.modules.ConvNet` will be used to make the forward pass
+            for all agents (homogeneous policies). Otherwise, each agent will use a different :class:`~torchrl.modules.ConvNet` to process
+            its input (heterogeneous policies).
+        device (str or torch.device, optional): device to create the module on.
+        num_cells (int or Sequence[int], optional): number of cells of every layer in between the input and output. If
+            an integer is provided, every layer will have the same number of cells. If an iterable is provided,
+            the linear layers ``out_features`` will match the content of ``num_cells``.
+        kernel_sizes (int or Sequence[Union[int, Sequence[int]]]): Kernel size(s) of the convolutional network.
+            Defaults to ``5``.
+        strides (int or Sequence[int]): Stride(s) of the convolutional network. If iterable, the length must match the
+            depth, defined by the ``num_cells`` or ``depth`` arguments.
+            Defaults to ``2``.
+        paddings (int or Sequence[int]): Padding(s) of the convolutional network. If iterable, the length must match the depth.
+            Defaults to ``0``.
+        activation_class (Type[nn.Module]): activation class to be used.
+            Defaults to :class:`torch.nn.ELU`.
+        **kwargs: additional keyword arguments for :class:`~torchrl.modules.models.ConvNet`, used to customize the ConvNet.
+
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.modules import MultiAgentConvNet
+        >>> batch = (3, 2)
+        >>> n_agents = 7
+        >>> channels, x, y = 3, 100, 100
+        >>> obs = torch.randn(*batch, n_agents, channels, x, y)
+        >>> # First, let's consider a centralised network with shared parameters.
+        >>> cnn = MultiAgentConvNet(
+        ...     n_agents,
+        ...     centralised=True,
+        ...     share_params=True,
+        ... )
+        >>> print(cnn)
+        MultiAgentConvNet(
+          (agent_networks): ModuleList(
+            (0): ConvNet(
+              (0): LazyConv2d(0, 32, kernel_size=(5, 5), stride=(2, 2))
+              (1): ELU(alpha=1.0)
+              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (3): ELU(alpha=1.0)
+              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (5): ELU(alpha=1.0)
+              (6): SquashDims()
+            )
+          )
+        )
+        >>> result = cnn(obs)
+        >>> # The final dimension of the resulting tensor is determined by the layer arguments and the shape of the input 'obs'.
+        >>> print(result.shape)
+        torch.Size([3, 2, 7, 2592])
+        >>> # Since both observations and parameters are shared, we expect all agents to have identical outputs (e.g. for a value function).
+        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
+        True
+
+        >>> # Alternatively, a local network with parameter sharing (e.g. a decentralised weight-sharing policy).
+        >>> cnn = MultiAgentConvNet(
+        ...     n_agents,
+        ...     centralised=False,
+        ...     share_params=True,
+        ... )
+        >>> print(cnn)
+        MultiAgentConvNet(
+          (agent_networks): ModuleList(
+            (0): ConvNet(
+              (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
+              (1): ELU(alpha=1.0)
+              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (3): ELU(alpha=1.0)
+              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (5): ELU(alpha=1.0)
+              (6): SquashDims()
+            )
+          )
+        )
+        >>> result = cnn(obs)
+        >>> print(result.shape)
+        torch.Size([3, 2, 7, 2592])
+        >>> # Parameters are shared but observations are not, hence each agent has a different output.
+        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
+        False
+
+        >>> # Or multiple local networks identical in structure but with differing weights.
+        >>> cnn = MultiAgentConvNet(
+        ...     n_agents,
+        ...     centralised=False,
+        ...     share_params=False,
+        ... )
+        >>> print(cnn)
+        MultiAgentConvNet(
+          (agent_networks): ModuleList(
+            (0-6): 7 x ConvNet(
+              (0): Conv2d(4, 32, kernel_size=(5, 5), stride=(2, 2))
+              (1): ELU(alpha=1.0)
+              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (3): ELU(alpha=1.0)
+              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (5): ELU(alpha=1.0)
+              (6): SquashDims()
+            )
+          )
+        )
+        >>> result = cnn(obs)
+        >>> print(result.shape)
+        torch.Size([3, 2, 7, 2592])
+        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
+        False
+
+        >>> # Or where inputs are shared but not parameters.
+        >>> cnn = MultiAgentConvNet(
+        ...     n_agents,
+        ...     centralised=True,
+        ...     share_params=False,
+        ... )
+        >>> print(cnn)
+        MultiAgentConvNet(
+          (agent_networks): ModuleList(
+            (0-6): 7 x ConvNet(
+              (0): Conv2d(28, 32, kernel_size=(5, 5), stride=(2, 2))
+              (1): ELU(alpha=1.0)
+              (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (3): ELU(alpha=1.0)
+              (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(2, 2))
+              (5): ELU(alpha=1.0)
+              (6): SquashDims()
+            )
+          )
+        )
+        >>> result = cnn(obs)
+        >>> print(result.shape)
+        torch.Size([3, 2, 7, 2592])
+        >>> print(all(result[0, 0, 0] == result[0, 0, 1]))
+        False
+    """
+
+    def __init__(
+        self,
+        n_agents: int,
+        centralised: bool,
+        share_params: bool,
+        device: Optional[DEVICE_TYPING] = None,
+        num_cells: Optional[Sequence[int]] = None,
+        kernel_sizes: Union[Sequence[Union[int, Sequence[int]]], int] = 5,
+        strides: Union[Sequence, int] = 2,
+        paddings: Union[Sequence, int] = 0,
+        activation_class: Type[nn.Module] = nn.ELU,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.n_agents = n_agents
+        self.centralised = centralised
+        self.share_params = share_params
+
+        self.agent_networks = nn.ModuleList(
+            [
+                ConvNet(
+                    num_cells=num_cells,
+                    kernel_sizes=kernel_sizes,
+                    strides=strides,
+                    paddings=paddings,
+                    activation_class=activation_class,
+                    device=device,
+                    **kwargs,
+                )
+                for _ in range(self.n_agents if not self.share_params else 1)
+            ]
+        )
+
+    def forward(self, inputs: torch.Tensor):
+        if len(inputs.shape) < 4:
+            raise ValueError(
+                """Multi-agent network expects (*batch_size, agent_index, channels, x, y)"""
+            )
+        if inputs.shape[-4] != self.n_agents:
+            raise ValueError(
+                f"""Multi-agent network expects {self.n_agents} agents but got {inputs.shape[-4]}"""
+            )
+        # If the model is centralised, agents have full observability
+        if self.centralised:
+            shape = (
+                *inputs.shape[:-4],
+                self.n_agents * inputs.shape[-3],
+                inputs.shape[-2],
+                inputs.shape[-1],
+            )
+            inputs = torch.reshape(inputs, shape)
+
+        # If the parameters are not shared, each agent has its own network
+        if not self.share_params:
+            if self.centralised:
+                output = torch.stack(
+                    [net(inputs) for net in self.agent_networks], dim=-2
+                )
+            else:
+                output = torch.stack(
+                    [
+                        net(inp)
+                        for net, inp in zip(self.agent_networks, inputs.unbind(-4))
+                    ],
+                    dim=-2,
+                )
+        else:
+            output = self.agent_networks[0](inputs)
+            if self.centralised:
+                # If the parameters are shared and the model is centralised, all agents will have the same output.
+                # We expand it to maintain the agent dimension, but values will be the same for all agents.
+                n_agent_outputs = output.shape[-1]
+                output = output.view(*output.shape[:-1], n_agent_outputs)
+                output = output.unsqueeze(-2)
+                output = output.expand(
+                    *output.shape[:-2], self.n_agents, n_agent_outputs
+                )
+        return output
+
+
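As a side note on how the new ``forward`` routes observations: in the centralised case the agent dimension is folded into the channel dimension so a single ConvNet sees every agent's frame, while in the non-centralised case the input is unbound along the agent dimension. A small sketch with made-up sizes, using only plain tensor ops (nothing here is part of the PR):

import torch

B = (3, 2)                              # arbitrary batch dims, made up
n_agents, C, x, y = 7, 3, 100, 100
inputs = torch.randn(*B, n_agents, C, x, y)

# centralised: fold agents into channels, one network sees all observations
centralised_in = inputs.reshape(*inputs.shape[:-4], n_agents * C, x, y)
assert centralised_in.shape == (*B, n_agents * C, x, y)

# non-centralised: each agent's network only sees its own (C, x, y) slice
per_agent = inputs.unbind(-4)
assert len(per_agent) == n_agents and per_agent[0].shape == (*B, C, x, y)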
class Mixer(nn.Module):
    """A multi-agent value mixer.