From 366f9c6bff1ecc3675362a90810712f84ac6d54e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 27 Jul 2023 14:52:25 +0100 Subject: [PATCH 1/3] qmix Signed-off-by: Matteo Bettini --- examples/multiagent/qmix_vdn.py | 26 ++++++- examples/multiagent/qmix_vdn.yaml | 2 +- torchrl/modules/models/__init__.py | 4 +- torchrl/modules/models/models.py | 54 ++++++++++++++ torchrl/modules/models/multiagent.py | 101 ++++++++++++++++++++++++++- 5 files changed, 183 insertions(+), 4 deletions(-) diff --git a/examples/multiagent/qmix_vdn.py b/examples/multiagent/qmix_vdn.py index 55c5ef012ba..14b6634e95a 100644 --- a/examples/multiagent/qmix_vdn.py +++ b/examples/multiagent/qmix_vdn.py @@ -18,7 +18,7 @@ from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential -from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer +from torchrl.modules.models.multiagent import MultiAgentMLP, QGNNMixer, QMixer, VDNMixer from torchrl.objectives import SoftUpdate, ValueEstimators from torchrl.objectives.multiagent.qmixer import QMixerLoss from utils.logging import init_logging, log_evaluation, log_training @@ -127,6 +127,30 @@ def train(cfg: "DictConfig"): # noqa: F821 in_keys=[("agents", "chosen_action_value")], out_keys=["chosen_action_value"], ) + elif cfg.loss.mixer_type == "qgnn": + mixer = TensorDictModule( + module=QGNNMixer( + use_state=False, + n_agents=env.n_agents, + device=cfg.train.device, + ), + in_keys=[("agents", "chosen_action_value")], + out_keys=["chosen_action_value"], + ) + elif cfg.loss.mixer_type == "qgnn-state": + mixer = TensorDictModule( + module=QGNNMixer( + state_shape=env.unbatched_observation_spec[ + "agents", "observation" + ].shape, + mixing_embed_dim=8, + use_state=True, + n_agents=env.n_agents, + device=cfg.train.device, + ), + in_keys=[("agents", "chosen_action_value"), ("agents", "observation")], + out_keys=["chosen_action_value"], + ) else: raise ValueError("Mixer type not in the example") diff --git a/examples/multiagent/qmix_vdn.yaml b/examples/multiagent/qmix_vdn.yaml index a78b3987ffb..ebbf1e6828a 100644 --- a/examples/multiagent/qmix_vdn.yaml +++ b/examples/multiagent/qmix_vdn.yaml @@ -20,7 +20,7 @@ buffer: memory_size: ??? 
loss: - mixer_type: "qmix" # or "vdn" + mixer_type: "qgnn-state" # choose from "qmix", "vdn", "qgnn", "qgnn-state" gamma: 0.9 tau: 0.005 # For target net diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 8e5d0c2f9c9..afa02cf965b 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -7,6 +7,7 @@ from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior from .models import ( + AbsLinear, ConvNet, DdpgCnnActor, DdpgCnnQNet, @@ -14,8 +15,9 @@ DdpgMlpQNet, DistributionalDQNnet, DuelingCnnDQNet, + HyperLinear, LSTMNet, MLP, ) -from .multiagent import MultiAgentMLP, QMixer, VDNMixer +from .multiagent import MultiAgentMLP, QGNNMixer, QMixer, VDNMixer from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index a72af43aa13..898ba66d3b2 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -1138,3 +1138,57 @@ def forward( input = self.mlp(input) return self._lstm(input, hidden0_in, hidden1_in) + + +class HyperLinear(nn.Module): + """Missing.""" + + def __init__(self, in_dim, out_dim, pos=True, **kwargs): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.pos = pos + self.w = None + self.b = None + + def num_params(self): + return self.in_dim * self.out_dim + self.out_dim + + def update_params(self, params): + # params: b x (in_dim * out_dim + out_dim) + assert params.shape[1] == self.in_dim * self.out_dim + self.out_dim + batch = params.shape[0] + self.w = params[:, : self.in_dim * self.out_dim].view( + batch, self.in_dim, self.out_dim + ) + self.b = params[:, self.in_dim * self.out_dim :].view(batch, self.out_dim) + if self.pos: + self.w = torch.abs(self.w) + + def forward(self, x): + # x: b x in_dim OR b x n x in_dim + w = self.w + b = self.b + assert x.shape[0] == w.shape[0] + assert x.shape[-1] == w.shape[1] + squeeze_output = False + if x.dim() == 2: + squeeze_output = True + x = x.unsqueeze(1) + if b.dim() == 3: + b = b.squeeze(1) + xw = torch.bmm(x, w) + out = xw + b[:, None] + if squeeze_output: + out = out.squeeze(1) + return out + + +class AbsLinear(nn.Linear): + """Missing.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + return nn.functional.linear(input, torch.abs(self.weight), self.bias) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index de565b336d2..0ffae9d126a 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -12,7 +12,7 @@ from ...data import DEVICE_TYPING -from .models import MLP +from .models import AbsLinear, HyperLinear, MLP class MultiAgentMLP(nn.Module): @@ -598,3 +598,102 @@ def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): # Reshape and return q_tot = y.view(*bs, 1) return q_tot + + +class QGNNMixer(Mixer): + """QGNN Mixer. 
+ + From https://arxiv.org/abs/2205.13005 + + """ + + def __init__( + self, + n_agents: int, + device, + mixing_embed_dim=8, + state_shape=None, + use_state=False, + ): + super().__init__( + needs_state=use_state, + state_shape=state_shape if use_state else torch.Size([]), + n_agents=n_agents, + device=device, + ) + + self.use_state = use_state + self.embed_dim = mixing_embed_dim + self.state_dim = int(np.prod(state_shape)) if self.use_state else None + + self.psi_hyper = MLP( + in_features=1, + out_features=self.embed_dim, + depth=3, + num_cells=self.embed_dim, + activation_class=nn.ReLU, + activate_last_layer=False, + layer_class=HyperLinear if self.use_state else AbsLinear, + layer_kwargs={"pos": True} if self.use_state else {}, + device=device, + ) + + self.phi_hyper = MLP( + in_features=self.embed_dim, + out_features=1, + depth=3, + num_cells=self.embed_dim, + activation_class=nn.ReLU, + activate_last_layer=False, + layer_class=HyperLinear if self.use_state else AbsLinear, + layer_kwargs={"pos": True} if self.use_state else {}, + device=device, + ) + + if self.use_state: + self.psi_param_net = MLP( + in_features=self.state_dim, + out_features=self.num_params(self.psi_hyper), + depth=2, + num_cells=self.state_dim, + activation_class=nn.Mish, + activate_last_layer=False, + device=device, + ) + self.phi_param_net = MLP( + in_features=self.state_dim, + out_features=self.num_params(self.phi_hyper), + depth=2, + num_cells=self.state_dim, + activation_class=nn.Mish, + activate_last_layer=False, + device=device, + ) + + def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor): + if self.use_state: + state = state.view(-1, self.state_dim) + psi_params = self.psi_param_net(state) + phi_params = self.phi_param_net(state) + self.update_params(self.psi_hyper, psi_params) + self.update_params(self.phi_hyper, phi_params) + psi_out = self.psi_hyper(chosen_action_value) + summed = psi_out.sum(dim=-2) + phi_out = self.phi_hyper(summed) + return phi_out + + def num_params(self, net): + num_params = 0 + for layer in net: + if isinstance(layer, HyperLinear): + num_params += layer.num_params() + return num_params + + def update_params(self, net, params): + i = 0 + for layer in net: + if isinstance(layer, HyperLinear): + layer_num_params = layer.num_params() + layer_params = params[:, i : i + layer_num_params] + i += layer_num_params + layer.update_params(layer_params) From 7f1f1ffa54c67a8a17988ba088efd575d848d5a9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 27 Jul 2023 14:52:47 +0100 Subject: [PATCH 2/3] amend Signed-off-by: Matteo Bettini --- examples/multiagent/qmix_vdn.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/multiagent/qmix_vdn.yaml b/examples/multiagent/qmix_vdn.yaml index ebbf1e6828a..7fe197cc1fb 100644 --- a/examples/multiagent/qmix_vdn.yaml +++ b/examples/multiagent/qmix_vdn.yaml @@ -20,7 +20,7 @@ buffer: memory_size: ??? 
loss: - mixer_type: "qgnn-state" # choose from "qmix", "vdn", "qgnn", "qgnn-state" + mixer_type: "qmix" # choose from "qmix", "vdn", "qgnn", "qgnn-state" gamma: 0.9 tau: 0.005 # For target net From 314c069a8bc8f7fcb09651b49188c7e9fe2cc59a Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 6 Feb 2024 11:10:40 +0000 Subject: [PATCH 3/3] amend --- torchrl/modules/models/multiagent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index fe6f796db02..565ea5827d1 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -12,8 +12,7 @@ from torchrl.data.utils import DEVICE_TYPING -from .models import AbsLinear, HyperLinear, MLP -from torchrl.modules.models import ConvNet, MLP +from torchrl.modules.models import AbsLinear, ConvNet, HyperLinear, MLP class MultiAgentMLP(nn.Module):