|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +from collections import OrderedDict |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +import torch.nn.functional as F |
| 8 | +import torch_xla.distributed.spmd as xs |
| 9 | +from torch.nn.parameter import Parameter |
| 10 | + |
| 11 | +from vllm.logger import init_logger |
| 12 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 13 | + QKVParallelLinear, |
| 14 | + RowParallelLinear) |
| 15 | + |
| 16 | +logger = init_logger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +class XlaQKVParallelLinear(nn.Module): |
| 20 | + |
| 21 | + def __init__(self, |
| 22 | + qkv_linear: nn.Module, |
| 23 | + mesh: Optional["xs.Mesh"] = None): |
| 24 | + super().__init__() |
| 25 | + assert isinstance(qkv_linear, QKVParallelLinear) |
| 26 | + self.skip_bias_add = qkv_linear.skip_bias_add |
| 27 | + self.return_bias = qkv_linear.return_bias |
| 28 | + assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD." |
| 29 | + |
| 30 | + self.q_weight: Parameter |
| 31 | + self.k_weight: Parameter |
| 32 | + self.v_weight: Parameter |
| 33 | + self.q_bias: Optional[Parameter] |
| 34 | + self.k_bias: Optional[Parameter] |
| 35 | + self.v_bias: Optional[Parameter] |
| 36 | + self._load_weights_from_qkv_linear(qkv_linear) |
| 37 | + if mesh is not None: |
| 38 | + self._shard_weight(mesh) |
| 39 | + |
| 40 | + def _shard_weight(self, mesh: "xs.Mesh"): |
| 41 | + self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False) |
| 42 | + self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False) |
| 43 | + self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False) |
| 44 | + xs.mark_sharding(self.q_weight, mesh, ('x', None)) |
| 45 | + xs.mark_sharding(self.k_weight, mesh, ('x', None)) |
| 46 | + xs.mark_sharding(self.v_weight, mesh, ('x', None)) |
| 47 | + if self.q_bias is not None: |
| 48 | + assert self.k_bias is not None and self.v_bias is not None, \ |
| 49 | + "QKVParallelLinear should have q, k, and v biases together." |
| 50 | + self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False) |
| 51 | + xs.mark_sharding(self.q_bias, mesh, ('x', )) |
| 52 | + self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False) |
| 53 | + xs.mark_sharding(self.k_bias, mesh, ('x', )) |
| 54 | + self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False) |
| 55 | + xs.mark_sharding(self.v_bias, mesh, ('x', )) |
| 56 | + |
| 57 | + def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): |
| 58 | + q_proj_size, k_proj_size, _ = qkv_linear.output_sizes |
| 59 | + # The weight of qkv linear is a concatenation of q, k, and v weights |
| 60 | + # along the output dimension. |
| 61 | + qkv_weight = qkv_linear.weight.data.cpu() |
| 62 | + q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) |
| 63 | + k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size], |
| 64 | + requires_grad=False) |
| 65 | + v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:], |
| 66 | + requires_grad=False) |
| 67 | + self.register_parameter("q_weight", q_weight) |
| 68 | + self.register_parameter("k_weight", k_weight) |
| 69 | + self.register_parameter("v_weight", v_weight) |
| 70 | + |
| 71 | + if qkv_linear.bias is not None: |
| 72 | + q_bias = Parameter(qkv_linear.bias[:q_proj_size], |
| 73 | + requires_grad=False) |
| 74 | + k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size + |
| 75 | + k_proj_size], |
| 76 | + requires_grad=False) |
| 77 | + v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:], |
| 78 | + requires_grad=False) |
| 79 | + self.register_parameter("q_bias", q_bias) |
| 80 | + self.register_parameter("k_bias", k_bias) |
| 81 | + self.register_parameter("v_bias", v_bias) |
| 82 | + else: |
| 83 | + self.register_parameter("q_bias", None) |
| 84 | + self.register_parameter("k_bias", None) |
| 85 | + self.register_parameter("v_bias", None) |
| 86 | + |
| 87 | + def forward(self, input): |
| 88 | + # Same forward functionality as QKVParallelLinear, but doing qkv porj |
| 89 | + # separately. |
| 90 | + q_bias = self.q_bias if not self.skip_bias_add else None |
| 91 | + k_bias = self.k_bias if not self.skip_bias_add else None |
| 92 | + v_bias = self.v_bias if not self.skip_bias_add else None |
| 93 | + q_proj = F.linear(input, self.q_weight, q_bias) |
| 94 | + k_proj = F.linear(input, self.k_weight, k_bias) |
| 95 | + v_proj = F.linear(input, self.v_weight, v_bias) |
| 96 | + # The q/k/v projections will be split outside of the QKVParallelLinear. |
| 97 | + # Because we are replacing XlaQKVParallelLinear with the |
| 98 | + # QKVParallelLinear, we need to concatenate q, k, and v projections to |
| 99 | + # match the output shape of the QKVParallelLinear implementation even if |
| 100 | + # it seems to be redundant. |
| 101 | + # The concat and the following split will be noop, and should be |
| 102 | + # optimized away by the compiler. |
| 103 | + qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) |
| 104 | + output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \ |
| 105 | + self.skip_bias_add else None |
| 106 | + if not self.return_bias: |
| 107 | + return qkv_proj |
| 108 | + return qkv_proj, output_bias |
| 109 | + |
| 110 | + |
| 111 | +def partition_column_parallel_linear(layer: torch.nn.Module, |
| 112 | + mesh: xs.Mesh) -> torch.nn.Module: |
| 113 | + assert isinstance(layer, ColumnParallelLinear) |
| 114 | + xs.mark_sharding(layer.weight, mesh, ('x', None)) |
| 115 | + logger.debug("Applied column-parallel sharding to %s", layer) |
| 116 | + return layer |
| 117 | + |
| 118 | + |
| 119 | +def partition_row_parallel_linear(layer: torch.nn.Module, |
| 120 | + mesh: xs.Mesh) -> torch.nn.Module: |
| 121 | + assert isinstance(layer, RowParallelLinear) |
| 122 | + xs.mark_sharding(layer.weight, mesh, (None, 'x')) |
| 123 | + logger.debug("Applied row-parallel sharding to %s", layer) |
| 124 | + return layer |
| 125 | + |
| 126 | + |
| 127 | +def partition_qkv_parallel_linear(layer: torch.nn.Module, |
| 128 | + mesh: xs.Mesh) -> torch.nn.Module: |
| 129 | + assert isinstance(layer, QKVParallelLinear) |
| 130 | + xla_layer = XlaQKVParallelLinear(layer, mesh) |
| 131 | + logger.debug("Applied qkv parallel sharding to %s", layer) |
| 132 | + return xla_layer |
| 133 | + |
| 134 | + |
| 135 | +MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([ |
| 136 | + ("QKVParallelLinear", partition_qkv_parallel_linear), |
| 137 | + ("ColumnParallelLinear", partition_column_parallel_linear), |
| 138 | + ("RowParallelLinear", partition_row_parallel_linear), |
| 139 | +]) |
| 140 | + |
| 141 | + |
| 142 | +def get_fqn(module): |
| 143 | + # Get the fully qualified name of the module |
| 144 | + return module.__class__.__qualname__ |
| 145 | + |
| 146 | + |
| 147 | +def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: |
| 148 | + """ |
| 149 | + Recursively check a PyTorch model and apply appropriate sharding based on |
| 150 | + the MODULE_TYPE_TO_WRAPPING_FUNC mapping. |
| 151 | + |
| 152 | + Args: |
| 153 | + model: torch.nn.Module to process |
| 154 | + mesh: An XLA SPMD mesh object used for sharding |
| 155 | + """ |
| 156 | + |
| 157 | + def _process_module(module, name=None, parent=None): |
| 158 | + for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items(): |
| 159 | + if get_fqn(module) == module_type: |
| 160 | + wrapped_module = wrapping_func(module, mesh) |
| 161 | + |
| 162 | + assert parent is not None and name is not None, ( |
| 163 | + "Top Level module is not expected to be wrapped.") |
| 164 | + if wrapped_module is not module: |
| 165 | + # Wrapped module and module are different py object. |
| 166 | + # The original module should be replaced by the |
| 167 | + # wrapped_module. |
| 168 | + logger.debug("replace %s with %s", module, wrapped_module) |
| 169 | + setattr(parent, name, wrapped_module) |
| 170 | + |
| 171 | + module = wrapped_module |
| 172 | + break |
| 173 | + |
| 174 | + for child_name, child_module in list(module.named_children()): |
| 175 | + _process_module(child_module, child_name, module) |
| 176 | + |
| 177 | + _process_module(model) |
0 commit comments