|
| 1 | +import torch |
| 2 | +from torch.nn.attention.flex_attention import _mask_mod_signature |
| 3 | + |
| 4 | + |
| 5 | +def generate_dilated_sliding_window(window_size: int, dilation: int) -> _mask_mod_signature: |
| 6 | + """Generates a dilated sliding window attention mask. |
| 7 | + Args: |
| 8 | + window_size: The size of the sliding window. |
| 9 | + dilation: The dilation factor for the sliding window. |
| 10 | +
|
| 11 | + Note: |
| 12 | + Query at position i can only attend to keys within a window of size `window_size` |
| 13 | + centered around i, where the keys are at positions j such that: |
| 14 | + * abs(i - j) <= window_size |
| 15 | + * abs(i - j) % dilation == 0 |
| 16 | + """ |
| 17 | + |
| 18 | + def dilated_sliding_window(b, h, q_idx, kv_idx): |
| 19 | + diff = torch.abs(q_idx - kv_idx) |
| 20 | + in_window = diff <= window_size |
| 21 | + is_dilated = (diff % dilation) == 0 |
| 22 | + return in_window & is_dilated |
| 23 | + |
| 24 | + dilated_sliding_window.__name__ = f"dilated_sliding_window_{window_size}_dilation_{dilation}" |
| 25 | + return dilated_sliding_window |
| 26 | + |
| 27 | + |
def main(device: str = "cpu"):
    """Visualize the attention scores produced under the dilated sliding window mask mod.

    Args:
        device (str): Device on which the example tensors are created and
            which is forwarded to the visualization helper.
    """
    from attn_gym import visualize_attention_scores

    # Example problem dimensions: (batch, heads, sequence length, head dim).
    shape = (1, 1, 24, 8)

    # All-ones inputs, so only the mask shapes the visualized scores.
    query = torch.ones(*shape, device=device)
    key = torch.ones(*shape, device=device)

    mask_mod = generate_dilated_sliding_window(window_size=4, dilation=2)
    visualize_attention_scores(
        query,
        key,
        mask_mod=mask_mod,
        device=device,
        name="dilated_sliding_window_mask",
    )
| 51 | + |
| 52 | + |
if __name__ == "__main__":
    # jsonargparse is an optional dependency that is only needed when this
    # module is run as a script.
    try:
        from jsonargparse import CLI
    except ImportError as err:
        # Chain explicitly so the traceback shows the original import failure
        # as the direct cause of this more actionable error.
        raise ImportError("Be sure to run: pip install -e .'[viz]'") from err
    # Hand `main` to jsonargparse's CLI entry point.
    CLI(main)
0 commit comments