Commit e7a0af1

dsv3 attention block
1 parent a91a784 commit e7a0af1

File tree

2 files changed: 323 additions, 0 deletions
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Literal

from torch import nn

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepseekV3ModelArgs(BaseModelArgs):
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        norm_eps (float): Epsilon used for RMSNorm.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        use_flex_attn (bool): Whether to use FlexAttention for the attention computation.
        attn_mask_type (str): Attention mask type (e.g. "causal").
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
    """

    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    norm_eps: float = 1e-5  # eps used for RMSNorm
    # MoE
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.0
    # Multi-Head Latent Attention (MLA)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    # YaRN
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.0

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """
        TODO: Placeholder for now
        """
        pass

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """
        TODO: Placeholder for now
        """
        return 0, 0
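For orientation, the short sketch below (illustrative only, not part of this commit) shows how these defaults combine into the MLA projection sizes that the attention module in the second file builds from them; it assumes only that DeepseekV3ModelArgs is importable as defined above, and the helper variable names are made up for the example.

# Illustrative sketch (not part of the commit): MLA projection sizes implied by the defaults.
args = DeepseekV3ModelArgs()

qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim  # 128 + 64 = 192
wq_out_features = args.n_heads * qk_head_dim  # wq: 16 * 192 = 3072
wkv_a_out_features = args.kv_lora_rank + args.qk_rope_head_dim  # wkv_a: 512 + 64 = 576
wkv_b_out_features = args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)  # wkv_b: 16 * 256 = 4096
print(qk_head_dim, wq_out_features, wkv_a_out_features, wkv_b_out_features)  # 192 3072 576 4096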
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch import nn

from torchtitan.models.attention import build_attention

from .args import DeepseekV3ModelArgs


# Adapted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return (
            dim
            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
            / (2 * math.log(base))
        )

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(
            beta_fast, beta_slow, dim, base, args.original_seq_len
        )
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class Attention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) module.
    """

    def __init__(self, model_args: DeepseekV3ModelArgs):
        super().__init__()
        self.dim = model_args.dim
        self.n_heads = model_args.n_heads
        self.q_lora_rank = model_args.q_lora_rank
        self.kv_lora_rank = model_args.kv_lora_rank
        self.qk_nope_head_dim = model_args.qk_nope_head_dim
        self.qk_rope_head_dim = model_args.qk_rope_head_dim
        self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim
        self.v_head_dim = model_args.v_head_dim

        if self.q_lora_rank == 0:
            self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
            self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps)
            self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps)
        self.wkv_b = nn.Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )
        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim**-0.5

        if model_args.max_seq_len > model_args.original_seq_len:
            mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        self.sdpa = build_attention(model_args.use_flex_attn, model_args.attn_mask_type)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        if self.q_lora_rank == 0:
            q = self.wq(x)  # q: (bsz, seqlen, n_heads * qk_head_dim)
        else:
            q = self.wq_b(
                self.q_norm(self.wq_a(x))
            )  # q: (bsz, seqlen, n_heads * qk_head_dim)

        q = q.view(
            bsz, seqlen, self.n_heads, self.qk_head_dim
        )  # q: (bsz, seqlen, n_heads, qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        # q_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
        # q_pe: (bsz, seqlen, n_heads, qk_rope_head_dim)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        q = torch.cat([q_nope, q_pe], dim=-1)  # q: (bsz, seqlen, n_heads, qk_head_dim)

        kv = self.wkv_a(x)  # kv: (bsz, seqlen, kv_lora_rank + qk_rope_head_dim)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # kv: (bsz, seqlen, kv_lora_rank)
        # k_pe: (bsz, seqlen, qk_rope_head_dim)
        k_pe = apply_rotary_emb(
            k_pe.unsqueeze(2), freqs_cis
        )  # k_pe: (bsz, seqlen, 1, qk_rope_head_dim)

        kv = self.wkv_b(
            self.kv_norm(kv)
        )  # kv: (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim))
        kv = kv.view(
            bsz, seqlen, self.n_heads, self.qk_nope_head_dim + self.v_head_dim
        )  # (bsz, seqlen, n_heads, qk_nope_head_dim + v_head_dim)
        k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # k_nope: (bsz, seqlen, n_heads, qk_nope_head_dim)
        # v: (bsz, seqlen, n_heads, v_head_dim)
        k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
        # k: (bsz, seqlen, n_heads, qk_head_dim)

        # The attention wrappers returned by build_attention expect
        # (bsz, n_heads, seqlen, head_dim), so move heads ahead of the sequence dim.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # TODO: Need to pass softmax_scale to sdpa() interface.
        # For the mask, DeepseekV3 uses a causal mask, so we can use the default mask in sdpa.
        # https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L17
        output = self.sdpa(q, k, v)

        output = output.transpose(1, 2).contiguous()
        output = output.view(bsz, seqlen, -1)  # (bsz, seqlen, n_heads * v_head_dim)
        return self.wo(output)  # (bsz, seqlen, dim)
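As a quick sanity check on the shape annotations above, here is a minimal end-to-end sketch (illustrative only, not part of this commit). It assumes DeepseekV3ModelArgs, precompute_freqs_cis, and Attention are importable from the two files added here, uses the default SDPA path (use_flex_attn=False), and slices freqs_cis to the input sequence length before the call.

# Illustrative sketch (not part of the commit): run the MLA block on dummy input.
import torch

model_args = DeepseekV3ModelArgs(max_seq_len=256)  # small max_seq_len; YaRN rescaling stays off
attn = Attention(model_args)

bsz, seqlen = 2, 128
x = torch.randn(bsz, seqlen, model_args.dim)

# precompute_freqs_cis covers max_seq_len positions; keep only the first seqlen rows.
freqs_cis = precompute_freqs_cis(model_args)[:seqlen]

out = attn(x, freqs_cis)
print(out.shape)  # torch.Size([2, 128, 2048]) -- same (bsz, seqlen, dim) as the input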

0 commit comments
