
Commit bbf437e

Add Dilated Sliding Window mask_mod (#85)
stack-info: PR: #85, branch: drisspg/stack/3
Co-authored-by: sngkim <sangkim.dev@gmail.com>
1 parent 36f8bd5 commit bbf437e
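
For context (not part of the commit message): in FlexAttention, a mask_mod is a callable over index tensors (b, h, q_idx, kv_idx) that returns a boolean tensor marking which query/key pairs may attend. A minimal illustrative example of the contract, using the familiar causal rule rather than this commit's mask:

def causal_mask_mod(b, h, q_idx, kv_idx):
    # True where attention is allowed: queries may only look backwards.
    return q_idx >= kv_idx

The commit below adds a generator that builds such a predicate for a dilated sliding window.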

File tree

attn_gym/masks/__init__.py
attn_gym/masks/dilated_sliding_window.py

2 files changed: +59 -0 lines

attn_gym/masks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from attn_gym.masks.sliding_window import generate_sliding_window
 from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
 from attn_gym.masks.document_mask import generate_doc_mask_mod
+from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
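
With this re-export in place, the generator is importable straight from the package's masks namespace. A quick check (a sketch assuming attn-gym is installed locally, e.g. via pip install -e .):

from attn_gym.masks import generate_dilated_sliding_window

mask_mod = generate_dilated_sliding_window(window_size=4, dilation=2)
print(mask_mod.__name__)  # dilated_sliding_window_4_dilation_2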
attn_gym/masks/dilated_sliding_window.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import torch
from torch.nn.attention.flex_attention import _mask_mod_signature


def generate_dilated_sliding_window(window_size: int, dilation: int) -> _mask_mod_signature:
    """Generates a dilated sliding window attention mask.

    Args:
        window_size: The size of the sliding window.
        dilation: The dilation factor for the sliding window.

    Note:
        Query at position i can only attend to keys within a window of size `window_size`
        centered around i, where the keys are at positions j such that:
        * abs(i - j) <= window_size
        * abs(i - j) % dilation == 0
    """

    def dilated_sliding_window(b, h, q_idx, kv_idx):
        diff = torch.abs(q_idx - kv_idx)
        in_window = diff <= window_size
        is_dilated = (diff % dilation) == 0
        return in_window & is_dilated

    dilated_sliding_window.__name__ = f"dilated_sliding_window_{window_size}_dilation_{dilation}"
    return dilated_sliding_window


def main(device: str = "cpu"):
    """Visualize the attention scores of dilated sliding window mask mod.

    Args:
        device (str): Device to use for computation.
    """
    from attn_gym import visualize_attention_scores

    B, H, SEQ_LEN, HEAD_DIM = 1, 1, 24, 8

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()

    dilated_sliding_window_mask = generate_dilated_sliding_window(window_size=4, dilation=2)
    visualize_attention_scores(
        query,
        key,
        mask_mod=dilated_sliding_window_mask,
        device=device,
        name="dilated_sliding_window_mask",
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
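
The module's own main renders the pattern via visualize_attention_scores. As an additional, purely illustrative sketch (not part of the commit), the generated mask_mod can be materialized densely for inspection and compiled into a BlockMask for FlexAttention; this assumes PyTorch >= 2.5, where flex_attention and create_block_mask ship in torch.nn.attention.flex_attention, and CPU support for flex_attention varies by version:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from attn_gym.masks import generate_dilated_sliding_window

mask_mod = generate_dilated_sliding_window(window_size=4, dilation=2)

# Dense boolean mask for an 8-token sequence: position i may attend to j only
# when abs(i - j) <= 4 and abs(i - j) % 2 == 0, i.e. relative offsets {0, 2, 4}.
# b and h are unused by this mask, so None placeholders are fine here.
idx = torch.arange(8)
print(mask_mod(None, None, idx[:, None], idx[None, :]).int())

# Compile the predicate into a BlockMask and run FlexAttention with it;
# B=None and H=None broadcast the mask across batch and heads.
B, H, S, D = 1, 1, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cpu")
out = flex_attention(q, k, v, block_mask=block_mask)

Because fully masked blocks are skipped at the BlockMask level, the sparsity of the dilated pattern translates into real compute savings at longer sequence lengths.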
