|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch import nn

from torchtitan.models.attention import build_attention

from .args import DeepseekV3ModelArgs


# Adapted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
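        # Inverts the RoPE frequency formula freq_i = base^(-2i / dim): returns the
        # (fractional) pair index i whose frequency completes `num_rotations` full
        # rotations over `max_seq_len` positions.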
        return (
            dim
            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
            / (2 * math.log(base))
        )

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
            clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

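    # Base RoPE inverse frequencies: freq_i = 1 / base^(2i / dim) for pair index i.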
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
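    # YaRN-style extension: when training beyond the original context length, blend
    # interpolated frequencies (freqs / factor) with the originals per dimension.
    # Fast-rotating dims (more than beta_fast rotations over the original context)
    # keep their original frequencies, slow-rotating dims (fewer than beta_slow
    # rotations) are fully interpolated, and dims in between are linearly ramped.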
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(
            beta_fast, beta_slow, dim, base, args.original_seq_len
        )
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

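    # Outer product of positions and per-pair frequencies gives the rotation angle for
    # every (position, pair); torch.polar converts the angles to unit complex numbers.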
    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class Attention(nn.Module):
    """
    Multi-head Latent Attention (MLA) module.
    """

    def __init__(self, model_args: DeepseekV3ModelArgs):
        super().__init__()
        self.dim = model_args.dim
        self.n_heads = model_args.n_heads
        self.q_lora_rank = model_args.q_lora_rank
        self.kv_lora_rank = model_args.kv_lora_rank
        self.qk_nope_head_dim = model_args.qk_nope_head_dim
        self.qk_rope_head_dim = model_args.qk_rope_head_dim
        self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim
        self.v_head_dim = model_args.v_head_dim

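        # q_lora_rank == 0 means the query uses a plain full-rank projection;
        # otherwise the query goes through a low-rank bottleneck
        # (wq_a -> RMSNorm -> wq_b).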
        if self.q_lora_rank == 0:
            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
            self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps)
            self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps)
        self.wkv_b = nn.Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim**-0.5

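        # YaRN attention-temperature correction: when the context window is extended
        # beyond the original training length, enlarge the softmax scale by mscale**2
        # to compensate for the interpolated rotary frequencies.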
        if model_args.max_seq_len > model_args.original_seq_len:
            mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        if self.q_lora_rank == 0:
            q = self.wq(x) # q: (bsz, seqlen, n_heads * qk_head_dim)
        else:
            q = self.wq_b(
                self.q_norm(self.wq_a(x))
            ) # q: (bsz, seqlen, n_heads * qk_head_dim)

        q = q.view(
            bsz, seqlen, self.n_heads, self.qk_head_dim
        ) # q: (bsz, seqlen, n_heads, qk_head_dim)
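        # Decoupled RoPE: the query is split into a content part (q_nope) and a
        # positional part (q_pe); rotary embedding is applied only to the latter.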
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
        # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        q = torch.cat([q_nope, q_pe], dim=-1) # q: (bsz, seqlen, n_heads, qk_head_dim)

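        # MLA KV path: wkv_a compresses the input into a low-rank latent
        # (kv_lora_rank) plus a single decoupled RoPE key (k_pe) shared across heads;
        # the normalized latent is then expanded by wkv_b into per-head content keys
        # (k_nope) and values.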
        kv = self.wkv_a(x) # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # kv: (bsz, seqlen, kv_lora_rank)
        # k_pe: (bsz, seqlen, qk_rope_head_dim)
        k_pe = apply_rotary_emb(
            k_pe.unsqueeze(2), freqs_cis
        ) # k_pe: (bsz, seqlen, 1, qk_rope_head_dim)

        kv = self.wkv_b(
            self.kv_norm(kv)
        ) # kv: (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
        kv = kv.view(
            bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
        ) # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
        # v: (bsz, seqlen, n_heads, v_head_dim)
        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
        # k: (bsz, seqlen, n_heads, qk_head_dim)

        # TODO: Need to pass softmax_scale to sdpa() interface.
        # For mask, DeepseekV3 uses causal mask, so we can use the default mask in sdpa
        # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
        output = self.sdpa(q, k, v)

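        # The attention output comes back head-major, (bsz, n_heads, seqlen, v_head_dim);
        # move heads back next to seqlen and merge them before the output projection.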
        output = output.transpose(1, 2).contiguous()
        output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim)
        return self.wo(output) # (bsz, seqlen, dim)