diff --git a/attn_gym/masks/__init__.py b/attn_gym/masks/__init__.py
index 731acc7..a668a8c 100644
--- a/attn_gym/masks/__init__.py
+++ b/attn_gym/masks/__init__.py
@@ -1,5 +1,6 @@
 from attn_gym.masks.causal import causal_mask
 from attn_gym.masks.sliding_window import generate_sliding_window
+from attn_gym.masks.sink_attn import generate_sink_mask
 from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
 from attn_gym.masks.document_mask import generate_doc_mask_mod
 from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
diff --git a/attn_gym/masks/sink_attn.py b/attn_gym/masks/sink_attn.py
new file mode 100644
index 0000000..14b3d27
--- /dev/null
+++ b/attn_gym/masks/sink_attn.py
@@ -0,0 +1,70 @@
+"""Generates a sliding window attention mask with sink tokens"""
+
+import torch
+from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks, or_masks
+from attn_gym.masks import causal_mask
+
+def generate_sink_mask(window_size: int, sink_size: int = 4) -> _mask_mod_signature:
+    """Generates a sliding window attention mask with sink tokens.
+
+    Args:
+        window_size: The size of the sliding window.
+        sink_size: The number of initial tokens that are always visible (sink tokens). Defaults to 4.
+
+    Note:
+        We assume that the window size represents the lookback size and mask out all future tokens,
+        similar to causal masking; additionally, all tokens can attend to the first `sink_size` tokens.
+    """
+
+    def sink_mask(b, h, q_idx, kv_idx):
+        # The sink tokens: the first `sink_size` tokens are always visible
+        return kv_idx < sink_size
+
+    def sliding_window(b, h, q_idx, kv_idx):
+        # The sliding window constraint: kv_idx is at most `window_size` tokens behind q_idx
+        return q_idx - kv_idx <= window_size
+
+    # Combine: (sliding window OR sink) AND causal
+    combined_mask = and_masks(
+        or_masks(sliding_window, sink_mask),
+        causal_mask,
+    )
+
+    combined_mask.__name__ = f"sink_window_{window_size}_sink_{sink_size}"
+    return combined_mask
+
+
+def main(device: str = "cpu", mask_type: str = "sink", window_size: int = 3, sink_size: int = 4):
+    """Visualize the attention scores of the sink mask.
+
+    Args:
+        device: Device to use for computation. Defaults to "cpu".
+        mask_type: Type of mask to use (only "sink" is supported). Defaults to "sink".
+        window_size: The size of the sliding window. Defaults to 3.
+        sink_size: The number of initial tokens that are always visible (sink tokens). Defaults to 4.
+ """ + from attn_gym import visualize_attention_scores + + B, H, SEQ_LEN, HEAD_DIM = 1, 1, 12, 8 + + def make_tensor(): + return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device) + + query, key = make_tensor(), make_tensor() + + if mask_type != "sink": + raise ValueError("This module only supports 'sink' mask type") + + mask_mod = generate_sink_mask(window_size, sink_size) + + visualize_attention_scores( + query, key, mask_mod=mask_mod, device=device, name=mask_mod.__name__ + ) + + +if __name__ == "__main__": + try: + from jsonargparse import CLI + except ImportError: + raise ImportError("Be sure to run: pip install -e .'[viz]'") + CLI(main) diff --git a/examples/benchmark.py b/examples/benchmark.py index 50debe2..aa15a1e 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -22,6 +22,7 @@ generate_sliding_window, generate_prefix_lm_mask, generate_doc_mask_mod, + generate_sink_mask, ) from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap @@ -38,6 +39,7 @@ "softcap_approx": lambda: test_mask( score_mod=generate_tanh_softcap(30, approx=True), skip_correctness=True ), + "sink_attn": lambda: test_mask(mask_mod=generate_sink_mask(window_size=1024)), }