|
| 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Optional |
| 16 | + |
| 17 | +import torch |
| 18 | + |
| 19 | +from ..utils import logging |
| 20 | +from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU |
| 21 | +from .attention import FeedForward |
| 22 | + |
| 23 | + |
| 24 | +logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 25 | + |
| 26 | + |
| 27 | +class _MemoryOptimizedFeedForward(torch.nn.Module): |
| 28 | + r""" |
| 29 | + See [`~models.attention.FeedForward`] parameter documentation. This class is a copy of the FeedForward class. The |
| 30 | + only difference is that this module is optimized for memory. |
| 31 | +
|
| 32 | + This method achieves memory savings by applying the ideas of tensor-parallelism sequentially. Input projection |
| 33 | + layers are split column-wise and output projection layers are split row-wise. This allows for the computation of |
| 34 | + the feedforward pass to occur without ever materializing the full intermediate tensor. Typically, the intermediate |
| 35 | + tensor takes 4x-8x more memory than the input tensor. This method reduces that with a small performance tradeoff. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + dim: int, |
| 41 | + dim_out: Optional[int] = None, |
| 42 | + mult: int = 4, |
| 43 | + dropout: float = 0.0, |
| 44 | + activation_fn: str = "geglu", |
| 45 | + final_dropout: bool = False, |
| 46 | + inner_dim: Optional[int] = None, |
| 47 | + bias: bool = True, |
| 48 | + num_splits: int = 4, |
| 49 | + ) -> None: |
| 50 | + super().__init__() |
| 51 | + |
| 52 | + if inner_dim is None: |
| 53 | + inner_dim = int(dim * mult) |
| 54 | + |
| 55 | + dim_out = dim_out if dim_out is not None else dim |
| 56 | + |
| 57 | + dim_split = inner_dim // num_splits |
| 58 | + if inner_dim % dim_split != 0: |
| 59 | + raise ValueError(f"inner_dim must be divisible by {mult=}, or {num_splits=} if provided.") |
| 60 | + |
| 61 | + self._dim = dim |
| 62 | + self._dim_out = dim_out |
| 63 | + self._mult = mult |
| 64 | + self._dropout = dropout |
| 65 | + self._activation_fn = activation_fn |
| 66 | + self._final_dropout = final_dropout |
| 67 | + self._inner_dim = inner_dim |
| 68 | + self._bias = bias |
| 69 | + self._num_splits = num_splits |
| 70 | + |
| 71 | + def get_activation_fn(dim_: int, inner_dim_: int): |
| 72 | + if activation_fn == "gelu": |
| 73 | + act_fn = GELU(dim_, inner_dim_, bias=bias) |
| 74 | + if activation_fn == "gelu-approximate": |
| 75 | + act_fn = GELU(dim_, inner_dim_, approximate="tanh", bias=bias) |
| 76 | + elif activation_fn == "geglu": |
| 77 | + act_fn = GEGLU(dim_, inner_dim_, bias=bias) |
| 78 | + elif activation_fn == "geglu-approximate": |
| 79 | + act_fn = ApproximateGELU(dim_, inner_dim_, bias=bias) |
| 80 | + elif activation_fn == "swiglu": |
| 81 | + act_fn = SwiGLU(dim_, inner_dim_, bias=bias) |
| 82 | + elif activation_fn == "linear-silu": |
| 83 | + act_fn = LinearActivation(dim_, inner_dim_, bias=bias, activation="silu") |
| 84 | + return act_fn |
| 85 | + |
| 86 | + # Split column-wise |
| 87 | + self.proj_in = torch.nn.ModuleList([get_activation_fn(dim, dim_split) for _ in range(inner_dim // dim_split)]) |
| 88 | + |
| 89 | + self.dropout = torch.nn.Dropout(dropout) |
| 90 | + |
| 91 | + # Split row-wise |
| 92 | + self.proj_out = torch.nn.ModuleList( |
| 93 | + [torch.nn.Linear(dim_split, dim_out, bias=False) for _ in range(inner_dim // dim_split)] |
| 94 | + ) |
| 95 | + |
| 96 | + self.bias = None |
| 97 | + if bias: |
| 98 | + self.bias = torch.nn.Parameter(torch.zeros(dim_out)) |
| 99 | + |
| 100 | + self.final_dropout = None |
| 101 | + if final_dropout: |
| 102 | + self.final_dropout = torch.nn.Dropout(dropout) |
| 103 | + |
| 104 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 105 | + # Output tensor for "all_reduce" operation |
| 106 | + output = hidden_states.new_zeros(hidden_states.shape) |
| 107 | + |
| 108 | + # Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU |
| 109 | + for proj_in, proj_out in zip(self.proj_in, self.proj_out): |
| 110 | + out = proj_in(hidden_states) |
| 111 | + out = self.dropout(out) |
| 112 | + out = proj_out(out) |
| 113 | + # Perform "all_reduce" |
| 114 | + output += out |
| 115 | + |
| 116 | + if self.bias is not None: |
| 117 | + output += self.bias |
| 118 | + if self.final_dropout is not None: |
| 119 | + output = self.final_dropout(output) |
| 120 | + |
| 121 | + return output |
| 122 | + |
| 123 | + |
| 124 | +def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Optional[int] = None) -> torch.nn.Module: |
| 125 | + module_dict = dict(module.named_modules()) |
| 126 | + |
| 127 | + for name, submodule in module_dict.items(): |
| 128 | + if not isinstance(submodule, FeedForward): |
| 129 | + continue |
| 130 | + |
| 131 | + logger.debug(f"Applying memory optimized feedforward to layer '{name}'") |
| 132 | + state_dict = submodule.state_dict() |
| 133 | + num_splits = submodule._mult if num_splits is None else num_splits |
| 134 | + |
| 135 | + # remap net.0.proj.weight |
| 136 | + net_0_proj = state_dict.pop("net.0.proj.weight") |
| 137 | + net_0_proj = net_0_proj.chunk(num_splits, dim=0) |
| 138 | + for i in range(num_splits): |
| 139 | + state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i] |
| 140 | + |
| 141 | + # remap net.0.proj.bias |
| 142 | + if "net.0.proj.bias" in state_dict: |
| 143 | + net_0_proj_bias = state_dict.pop("net.0.proj.bias") |
| 144 | + net_0_proj_bias = net_0_proj_bias.chunk(num_splits, dim=0) |
| 145 | + for i in range(num_splits): |
| 146 | + state_dict[f"proj_in.{i}.proj.bias"] = net_0_proj_bias[i] |
| 147 | + |
| 148 | + # remap net.2.weight |
| 149 | + net_2_weight = state_dict.pop("net.2.weight") |
| 150 | + net_2_weight = net_2_weight.chunk(num_splits, dim=1) |
| 151 | + for i in range(num_splits): |
| 152 | + state_dict[f"proj_out.{i}.weight"] = net_2_weight[i] |
| 153 | + |
| 154 | + # remap net.2.bias |
| 155 | + if "net.2.bias" in state_dict: |
| 156 | + net_2_bias = state_dict.pop("net.2.bias") |
| 157 | + state_dict["bias"] = net_2_bias |
| 158 | + |
| 159 | + with torch.device("meta"): |
| 160 | + new_ff = _MemoryOptimizedFeedForward( |
| 161 | + dim=submodule._dim, |
| 162 | + dim_out=submodule._dim_out, |
| 163 | + mult=submodule._mult, |
| 164 | + dropout=submodule._dropout, |
| 165 | + activation_fn=submodule._activation_fn, |
| 166 | + final_dropout=submodule._final_dropout, |
| 167 | + inner_dim=submodule._inner_dim, |
| 168 | + bias=submodule._bias, |
| 169 | + num_splits=num_splits, |
| 170 | + ) |
| 171 | + |
| 172 | + new_ff.load_state_dict(state_dict, strict=True, assign=True) |
| 173 | + |
| 174 | + parent_module_name, _, submodule_name = name.rpartition(".") |
| 175 | + parent_module = module_dict[parent_module_name] |
| 176 | + setattr(parent_module, submodule_name, new_ff) |
| 177 | + |
| 178 | + return module |
0 commit comments