|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Optional |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch import Tensor |
| 12 | +from torch.nn import functional as F |
| 13 | + |
| 14 | + |
| 15 | +def find_multiple(n: int, k: int) -> int: |
| 16 | + if n % k == 0: |
| 17 | + return n |
| 18 | + return n + k - (n % k) |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class ModelArgs: |
| 22 | + block_size: int = 2048 |
| 23 | + vocab_size: int = 32000 |
| 24 | + n_layer: int = 32 |
| 25 | + n_head: int = 32 |
| 26 | + dim: int = 4096 |
| 27 | + intermediate_size: int = None |
| 28 | + n_local_heads: int = -1 |
| 29 | + head_dim: int = 64 |
| 30 | + rope_base: float = 10000 |
| 31 | + norm_eps: float = 1e-5 |
| 32 | + |
| 33 | + def __post_init__(self): |
| 34 | + if self.n_local_heads == -1: |
| 35 | + self.n_local_heads = self.n_head |
| 36 | + if self.intermediate_size is None: |
| 37 | + hidden_dim = 4 * self.dim |
| 38 | + n_hidden = int(2 * hidden_dim / 3) |
| 39 | + self.intermediate_size = find_multiple(n_hidden, 256) |
| 40 | + self.head_dim = self.dim // self.n_head |
| 41 | + |
| 42 | + @classmethod |
| 43 | + def from_name(cls, name: str): |
| 44 | + if name in transformer_configs: |
| 45 | + return cls(**transformer_configs[name]) |
| 46 | + # fuzzy search |
| 47 | + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] |
| 48 | + |
| 49 | + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, |
| 50 | + # take longer name (as it have more symbols matched) |
| 51 | + if len(config) > 1: |
| 52 | + config.sort(key=len, reverse=True) |
| 53 | + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match |
| 54 | + |
| 55 | + return cls(**transformer_configs[config[0]]) |
| 56 | + |
| 57 | + |
| 58 | +transformer_configs = { |
| 59 | + "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), |
| 60 | + "7B": dict(n_layer=32, n_head=32, dim=4096), |
| 61 | + "13B": dict(n_layer=40, n_head=40, dim=5120), |
| 62 | + "30B": dict(n_layer=60, n_head=52, dim=6656), |
| 63 | + "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf |
| 64 | + "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), |
| 65 | + "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), |
| 66 | +} |
| 67 | + |
| 68 | +class KVCache(nn.Module): |
| 69 | + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): |
| 70 | + super().__init__() |
| 71 | + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) |
| 72 | + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) |
| 73 | + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) |
| 74 | + |
| 75 | + def update(self, input_pos, k_val, v_val): |
| 76 | + # input_pos: [S], k_val: [B, H, S, D] |
| 77 | + assert input_pos.shape[0] == k_val.shape[2] |
| 78 | + |
| 79 | + k_out = self.k_cache |
| 80 | + v_out = self.v_cache |
| 81 | + k_out[:, :, input_pos] = k_val |
| 82 | + v_out[:, :, input_pos] = v_val |
| 83 | + |
| 84 | + return k_out, v_out |
| 85 | + |
| 86 | +class Transformer(nn.Module): |
| 87 | + def __init__(self, config: ModelArgs) -> None: |
| 88 | + super().__init__() |
| 89 | + self.config = config |
| 90 | + self.vocab_size = self.config.vocab_size |
| 91 | + |
| 92 | + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| 93 | + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) |
| 94 | + self.norm = RMSNorm(config.dim, eps=config.norm_eps) |
| 95 | + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) |
| 96 | + |
| 97 | + self.freqs_cis: Optional[Tensor] = None |
| 98 | + self.mask_cache: Optional[Tensor] = None |
| 99 | + self.max_batch_size = -1 |
| 100 | + self.max_seq_length = -1 |
| 101 | + |
| 102 | + def setup_caches(self, max_batch_size, max_seq_length): |
| 103 | + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: |
| 104 | + return |
| 105 | + head_dim = self.config.dim // self.config.n_head |
| 106 | + max_seq_length = find_multiple(max_seq_length, 8) |
| 107 | + self.max_seq_length = max_seq_length |
| 108 | + self.max_batch_size = max_batch_size |
| 109 | + for b in self.layers: |
| 110 | + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) |
| 111 | + |
| 112 | + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) |
| 113 | + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) |
| 114 | + |
| 115 | + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: |
| 116 | + assert self.freqs_cis is not None, "Caches must be initialized first" |
| 117 | + mask = self.causal_mask[None, None, input_pos] |
| 118 | + freqs_cis = self.freqs_cis[input_pos] |
| 119 | + x = self.tok_embeddings(idx) |
| 120 | + |
| 121 | + for i, layer in enumerate(self.layers): |
| 122 | + x = layer(x, input_pos, freqs_cis, mask) |
| 123 | + x = self.norm(x) |
| 124 | + logits = self.output(x) |
| 125 | + return logits |
| 126 | + |
| 127 | + @classmethod |
| 128 | + def from_name(cls, name: str): |
| 129 | + return cls(ModelArgs.from_name(name)) |
| 130 | + |
| 131 | + |
| 132 | +class TransformerBlock(nn.Module): |
| 133 | + def __init__(self, config: ModelArgs) -> None: |
| 134 | + super().__init__() |
| 135 | + self.attention = Attention(config) |
| 136 | + self.feed_forward = FeedForward(config) |
| 137 | + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) |
| 138 | + self.attention_norm = RMSNorm(config.dim, config.norm_eps) |
| 139 | + |
| 140 | + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: |
| 141 | + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) |
| 142 | + out = h + self.feed_forward(self.ffn_norm(h)) |
| 143 | + return out |
| 144 | + |
| 145 | + |
| 146 | +class Attention(nn.Module): |
| 147 | + def __init__(self, config: ModelArgs): |
| 148 | + super().__init__() |
| 149 | + assert config.dim % config.n_head == 0 |
| 150 | + |
| 151 | + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim |
| 152 | + # key, query, value projections for all heads, but in a batch |
| 153 | + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) |
| 154 | + self.wo = nn.Linear(config.dim, config.dim, bias=False) |
| 155 | + self.kv_cache = None |
| 156 | + |
| 157 | + self.n_head = config.n_head |
| 158 | + self.head_dim = config.head_dim |
| 159 | + self.n_local_heads = config.n_local_heads |
| 160 | + self.dim = config.dim |
| 161 | + self._register_load_state_dict_pre_hook(self.load_hook) |
| 162 | + |
| 163 | + def load_hook(self, state_dict, prefix, *args): |
| 164 | + if prefix + "wq.weight" in state_dict: |
| 165 | + wq = state_dict.pop(prefix + "wq.weight") |
| 166 | + wk = state_dict.pop(prefix + "wk.weight") |
| 167 | + wv = state_dict.pop(prefix + "wv.weight") |
| 168 | + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) |
| 169 | + |
| 170 | + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: |
| 171 | + bsz, seqlen, _ = x.shape |
| 172 | + |
| 173 | + kv_size = self.n_local_heads * self.head_dim |
| 174 | + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) |
| 175 | + |
| 176 | + q = q.view(bsz, seqlen, self.n_head, self.head_dim) |
| 177 | + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 178 | + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 179 | + |
| 180 | + q = apply_rotary_emb(q, freqs_cis) |
| 181 | + k = apply_rotary_emb(k, freqs_cis) |
| 182 | + |
| 183 | + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) |
| 184 | + |
| 185 | + if self.kv_cache is not None: |
| 186 | + k, v = self.kv_cache.update(input_pos, k, v) |
| 187 | + |
| 188 | + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) |
| 189 | + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) |
| 190 | + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) |
| 191 | + |
| 192 | + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) |
| 193 | + |
| 194 | + y = self.wo(y) |
| 195 | + return y |
| 196 | + |
| 197 | + |
| 198 | +class FeedForward(nn.Module): |
| 199 | + def __init__(self, config: ModelArgs) -> None: |
| 200 | + super().__init__() |
| 201 | + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) |
| 202 | + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) |
| 203 | + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) |
| 204 | + |
| 205 | + def forward(self, x: Tensor) -> Tensor: |
| 206 | + return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
| 207 | + |
| 208 | + |
| 209 | +class RMSNorm(nn.Module): |
| 210 | + def __init__(self, dim: int, eps: float = 1e-5): |
| 211 | + super().__init__() |
| 212 | + self.eps = eps |
| 213 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 214 | + |
| 215 | + def _norm(self, x): |
| 216 | + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) |
| 217 | + |
| 218 | + def forward(self, x: Tensor) -> Tensor: |
| 219 | + output = self._norm(x.float()).type_as(x) |
| 220 | + return output * self.weight |
| 221 | + |
| 222 | + |
| 223 | +def precompute_freqs_cis( |
| 224 | + seq_len: int, n_elem: int, base: int = 10000 |
| 225 | +) -> Tensor: |
| 226 | + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) |
| 227 | + t = torch.arange(seq_len, device=freqs.device) |
| 228 | + freqs = torch.outer(t, freqs) |
| 229 | + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| 230 | + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) |
| 231 | + return cache.to(dtype=torch.bfloat16) |
| 232 | + |
| 233 | + |
| 234 | +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: |
| 235 | + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) |
| 236 | + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) |
| 237 | + x_out2 = torch.stack( |
| 238 | + [ |
| 239 | + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], |
| 240 | + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], |
| 241 | + ], |
| 242 | + -1, |
| 243 | + ) |
| 244 | + |
| 245 | + x_out2 = x_out2.flatten(3) |
| 246 | + return x_out2.type_as(x) |
0 commit comments