
add sink_attn mask in example #148


Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions attn_gym/masks/__init__.py
@@ -1,5 +1,6 @@
from attn_gym.masks.causal import causal_mask
from attn_gym.masks.sliding_window import generate_sliding_window
from attn_gym.masks.sink_attn import generate_sink_mask
from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
from attn_gym.masks.document_mask import generate_doc_mask_mod
from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
70 changes: 70 additions & 0 deletions attn_gym/masks/sink_attn.py
@@ -0,0 +1,70 @@
"""Generates a sliding window attention mask"""

import torch
from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks, or_masks
from attn_gym.masks import causal_mask

def generate_sink_mask(window_size: int, sink_size: int = 4) -> _mask_mod_signature:
"""Generates a sliding window with sink attention mask.

Args:
window_size: The size of the sliding window.
sink_size: The number of initial tokens that are always visible (sink tokens). Defaults to 4.

    Note:
        The window size is the lookback size: future tokens are masked out as in causal masking,
        but in addition every token can attend to the first `sink_size` tokens.
"""

def sink_mask(b, h, q_idx, kv_idx):
# The sink tokens: the first `sink_size` tokens are always visible
return kv_idx < sink_size

def sliding_window(b, h, q_idx, kv_idx):
# The sliding window constraint: within the window
return q_idx - kv_idx <= window_size

# Combine: (sliding window OR sink) AND causal
combined_mask = and_masks(
or_masks(sliding_window, sink_mask),
causal_mask
)

combined_mask.__name__ = f"sink_window_{window_size}_sink_{sink_size}"
return combined_mask


def main(device: str = "cpu", mask_type: str = "sink", window_size: int = 3, sink_size: int = 4):
"""Visualize the attention scores of sink mask.

Args:
device: Device to use for computation. Defaults to "cpu".
mask_type: Type of mask to use (only "sink" is supported). Defaults to "sink".
window_size: The size of the sliding window. Defaults to 3.
sink_size: The number of initial tokens that are always visible (sink tokens). Defaults to 4.
"""
from attn_gym import visualize_attention_scores

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 12, 8

def make_tensor():
return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

query, key = make_tensor(), make_tensor()

if mask_type != "sink":
raise ValueError("This module only supports 'sink' mask type")

mask_mod = generate_sink_mask(window_size, sink_size)

visualize_attention_scores(
query, key, mask_mod=mask_mod, device=device, name=mask_mod.__name__
)


if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
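Below is a minimal sketch (not part of the diff) of how the new mask_mod can be materialized for inspection, using `create_mask` from `torch.nn.attention.flex_attention`. The window size, sink size, and sequence length here are illustrative assumptions chosen so the pattern is easy to read.

```python
from torch.nn.attention.flex_attention import create_mask

from attn_gym.masks.sink_attn import generate_sink_mask

# Small, illustrative sizes so the pattern fits on screen.
mask_mod = generate_sink_mask(window_size=3, sink_size=2)

# Materialize the mask_mod into a dense boolean tensor of shape (B, H, Q_LEN, KV_LEN).
dense = create_mask(mask_mod, 1, 1, 8, 8, device="cpu")

# Rows are query positions, columns are key positions; 1 means "may attend".
# The first `sink_size` columns stay visible to every query, while the rest
# follow a causal sliding window of width `window_size`.
print(dense[0, 0].int())
```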
2 changes: 2 additions & 0 deletions examples/benchmark.py
@@ -22,6 +22,7 @@
generate_sliding_window,
generate_prefix_lm_mask,
generate_doc_mask_mod,
generate_sink_mask,
)
from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap

@@ -38,6 +39,7 @@
"softcap_approx": lambda: test_mask(
score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True
),
"sink_attn": lambda: test_mask(mask_mod=generate_sink_mask(window_size=1024)),
}


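For context, a minimal end-to-end sketch of how the benchmarked mask plugs into FlexAttention via `create_block_mask` and `flex_attention`; the shapes, dtype, and CUDA device below are illustrative assumptions rather than the benchmark's actual settings.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from attn_gym.masks import generate_sink_mask

# Illustrative shapes/dtype/device; the benchmark entry above only fixes window_size=1024.
B, H, SEQ_LEN, HEAD_DIM = 1, 8, 4096, 64
q = torch.randn(B, H, SEQ_LEN, HEAD_DIM, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

mask_mod = generate_sink_mask(window_size=1024, sink_size=4)

# Build the block-sparse mask once; it can be reused across forward passes
# as long as the sequence length stays the same.
block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN, device="cuda")

out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 4096, 64])
```

Because the combined mask is block-sparse (only the sink columns and a diagonal band are active), the precomputed `BlockMask` lets FlexAttention skip most key/value blocks at long sequence lengths.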