Commit bf70245

Add example for showcasing how to do multi-latent Attention (#113)
stack-info: PR: #113, branch: drisspg/stack/6
1 parent af82ef0 commit bf70245

File tree

2 files changed: +642 -0 lines changed

attn_gym/mods/latent_attention.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
"""Implementation of Multi-head Latent Attention (MLA) RoPE score modification from DeepSeek-V2.

Reference: https://arxiv.org/pdf/2405.04434 - DeepSeek-V2: A Strong, Economical, and
Efficient Mixture-of-Experts Language Model
"""

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _score_mod_signature


def generate_mla_rope_score_mod(
    query_rope: Tensor,
    key_rope: Tensor,
    num_heads: int,
    scale: float = 1.0,
) -> _score_mod_signature:
    """Returns an MLA RoPE score modification function to be used with FlexAttention.

    Args:
        query_rope: Positional embeddings for queries [batch, num_heads, seq_len, rope_head_dim]
        key_rope: Positional embeddings for keys [batch, 1, seq_len, rope_head_dim]
        num_heads: The number of query heads
        scale: Scaling factor for the positional embedding contribution

    Returns:
        mla_rope_score_mod: Score modification function for FlexAttention
    """

    def mla_rope_score_mod(
        score: Tensor, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor
    ) -> Tensor:
        # Add the RoPE contribution on top of the latent score. key_rope has a single
        # shared head, so h // num_heads maps every query head to key head index 0.
        return score + (
            scale * torch.dot(query_rope[b, h, q_idx], key_rope[b, h // num_heads, kv_idx])
        )

    mla_rope_score_mod.__name__ = f"mla_rope_score_mod_scale_{scale}"
    return mla_rope_score_mod


def main(device: str = "cuda"):
    """Visualize the attention scores with MLA RoPE modification.

    Args:
        device: Device to use for computation
    """
    from attn_gym import visualize_attention_scores

    # Example dimensions
    B, H, SEQ_LEN, LATENT_HEAD_DIM = 1, 128, 8, 512
    ROPE_HEAD_DIM = 64

    # Create random tensors for visualization
    query = torch.rand(B, H, SEQ_LEN, LATENT_HEAD_DIM, device=device)
    key = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM, device=device)

    # Create positional embeddings
    query_pe = torch.rand(B, H, SEQ_LEN, ROPE_HEAD_DIM, device=device)
    key_pe = torch.rand(B, 1, SEQ_LEN, ROPE_HEAD_DIM, device=device)

    # Generate the score modification function
    mla_rope_score_mod = generate_mla_rope_score_mod(
        query_rope=query_pe, key_rope=key_pe, num_heads=H
    )

    # Visualize attention scores with MLA RoPE modification
    visualize_attention_scores(
        query, key, score_mod=mla_rope_score_mod, device=device, name="mla_rope_score_mod"
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
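
A minimal sketch (illustrative, not part of the committed file) of the identity this score mod relies on: with scale = 1.0, adding the RoPE dot product on top of the latent scores matches plain attention over queries/keys whose head dimension is the latent part concatenated with the RoPE part. Shapes mirror main() above; the standalone tensor names here are assumptions for the example.

import torch

# Same shapes as main() above.
B, H, SEQ_LEN, LATENT_HEAD_DIM, ROPE_HEAD_DIM = 1, 128, 8, 512, 64

query = torch.rand(B, H, SEQ_LEN, LATENT_HEAD_DIM)
key = torch.rand(B, 1, SEQ_LEN, LATENT_HEAD_DIM)
query_pe = torch.rand(B, H, SEQ_LEN, ROPE_HEAD_DIM)
key_pe = torch.rand(B, 1, SEQ_LEN, ROPE_HEAD_DIM)

# Scores the way the score mod computes them: latent dot product plus the RoPE dot
# product. The single shared key head broadcasts across all H query heads.
scores_decomposed = query @ key.transpose(-1, -2) + query_pe @ key_pe.transpose(-1, -2)

# Scores from ordinary attention over the concatenated head dimension.
q_cat = torch.cat([query, query_pe], dim=-1)
k_cat = torch.cat([key, key_pe], dim=-1)
scores_concat = q_cat @ k_cat.transpose(-1, -2)

# Loose tolerance: the two summation orders differ in float32.
torch.testing.assert_close(scores_decomposed, scores_concat, rtol=1e-4, atol=1e-4)

In the FlexAttention setting, the latent dot product comes from the query/key tensors passed to the kernel and mla_rope_score_mod adds the RoPE term per (q_idx, kv_idx) pair, so the two pieces stay as separate tensors instead of being concatenated.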

0 commit comments
