Commit 5e0d1b8

NATTEN example (#16)
1 parent 9128c27 · commit 5e0d1b8

File tree

2 files changed: +148, -3 lines changed

attn_gym/masks/natten.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+"""Generates a NATTEN mask"""
+
+import torch
+from torch import IntTensor, BoolTensor
+from torch.nn.attention.flex_attention import _mask_mod_signature
+from typing import Tuple
+
+
+def generate_natten(
+    canvas_w: int,
+    canvas_h: int,
+    kernel_w: int,
+    kernel_h: int,
+) -> _mask_mod_signature:
+    """Generates a NATTEN attention mask with a given kernel size.
+    Args:
+        canvas_w: The width of the canvas.
+        canvas_h: The height of the canvas.
+        kernel_w: The width of the kernel.
+        kernel_h: The height of the kernel.
+    """
+
+    def get_x_y(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
+        return idx // canvas_w, idx % canvas_w
+
+    def natten_mask_mod(
+        b: IntTensor,
+        h: IntTensor,
+        q_idx: IntTensor,
+        kv_idx: IntTensor,
+    ) -> BoolTensor:
+        q_x, q_y = get_x_y(q_idx)
+        kv_x, kv_y = get_x_y(kv_idx)
+        # kernel nominally attempts to center itself on the query, but kernel center
+        # is clamped to a fixed distance (kernel half-length) from the canvas edge
+        kernel_center_x = q_x.clamp(kernel_w // 2, (canvas_w - 1) - kernel_w // 2)
+        kernel_center_y = q_y.clamp(kernel_h // 2, (canvas_h - 1) - kernel_h // 2)
+        hori_mask = (kernel_center_x - kv_x).abs() <= kernel_w // 2
+        vert_mask = (kernel_center_y - kv_y).abs() <= kernel_h // 2
+        return hori_mask & vert_mask
+
+    natten_mask_mod.__name__ = f"natten_c{canvas_w}x{canvas_h}_k{kernel_w}x{kernel_h}"
+    return natten_mask_mod
+
+
+def main(device: str = "cpu"):
+    """Visualize the attention scores of the NATTEN mask mod.
+    Note: a more complete implementation of NATTEN would include support for kernel dilation.
+    The NATTEN unfused kernel also has features like the ability to cross-attend to register tokens.
+    This capability is possible to express in Flex Attention but not attempted here.
+    See https://github.com/SHI-Labs/NATTEN for more details.
+
+    Args:
+        device (str): Device to use for computation. Defaults to "cpu".
+    """
+    from attn_gym import visualize_attention_scores
+
+    B, H, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM = 1, 1, 6, 6, 8
+
+    def make_tensor():
+        return torch.ones(B, H, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM, device=device)
+
+    query, key = make_tensor(), make_tensor()
+
+    kernel_size = 3
+    natten_mask = generate_natten(
+        canvas_w=CANVAS_WIDTH,
+        canvas_h=CANVAS_HEIGHT,
+        kernel_w=kernel_size,
+        kernel_h=kernel_size,
+    )
+    visualize_attention_scores(
+        # TODO: update visualize_attention_scores to support 2D sequences
+        query.flatten(start_dim=2, end_dim=3),
+        key.flatten(start_dim=2, end_dim=3),
+        mask_mod=natten_mask,
+        device=device,
+        name=natten_mask.__name__,
+    )
+
+
+if __name__ == "__main__":
+    try:
+        from jsonargparse import CLI
+    except ImportError:
+        raise ImportError("Be sure to run: pip install -e .'[viz]'")
+    CLI(main)
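For context, here is a minimal usage sketch (not part of this commit) of how a generated mask mod like this is typically consumed, assuming a PyTorch build where torch.nn.attention.flex_attention provides create_block_mask and flex_attention; the canvas and kernel sizes below are illustrative assumptions:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from attn_gym.masks.natten import generate_natten

# Illustrative sizes (assumed, not from the commit): a 16x16 canvas with a 5x5 kernel.
CANVAS_H, CANVAS_W, KERNEL = 16, 16, 5
SEQ_LEN = CANVAS_H * CANVAS_W

mask_mod = generate_natten(
    canvas_w=CANVAS_W, canvas_h=CANVAS_H, kernel_w=KERNEL, kernel_h=KERNEL
)
# Build the block-sparse mask once; B=None and H=None broadcast over batch and heads.
block_mask = create_block_mask(
    mask_mod, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cpu"
)

# The 2D canvas is flattened into a 1D sequence of length CANVAS_H * CANVAS_W.
q, k, v = (torch.randn(1, 1, SEQ_LEN, 8) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 1, 256, 8])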

examples/flex_attn.ipynb

Lines changed: 61 additions & 3 deletions
@@ -560,15 +560,15 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### NATTEN MASKING\n",
+    "### Stand-Alone Self-Attention Masking\n",
     "\n",
     "In this case, imagine that we have a 2D image of size (H x W) flattened into a\n",
     "sequence of tokens. We only want to attend to tokens within 8 `pixels`, but\n",
     "from a 2D perspective.\n",
     "\n",
     "We can implement this mask_mod by first translating the 1D position into 2D coordinates. Then, we can simply check if the distance of both coordinates is within the window.\n",
     "\n",
-    "For more details check the paper's github repository [NATTEN](https://github.com/SHI-Labs/NATTEN) "
+    "For more details check the paper, [Stand-Alone Self-Attention in Vision Models](https://arxiv.org/abs/1906.05909)"
    ]
   },
   {
@@ -591,14 +591,72 @@
     "    return idx // W, idx % W\n",
     "\n",
     "\n",
-    "def natten_mask(b, h, q_idx, kv_idx):\n",
+    "def sasa_mask(b, h, q_idx, kv_idx):\n",
     "    q_x, q_y = get_x_y(q_idx)\n",
     "    kv_x, kv_y = get_x_y(kv_idx)\n",
     "    horizontal_mask = (q_x - kv_x).abs() <= WINDOW\n",
     "    vertical_mask = (q_y - kv_y).abs() <= WINDOW\n",
     "    return horizontal_mask & vertical_mask\n",
     "\n",
     "\n",
+    "test_mask(mask_mod=sasa_mask)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### NATTEN Masking\n",
+    "\n",
+    "Consider a 2D image of size (H x W) flattened into a sequence of tokens.\n",
+    "Queries attend to keys in a fixed kernel area (K_H x K_W), centered where possible\n",
+    "on the query, whilst staying within the canvas and always including the query.\n",
+    "\n",
+    "This is similar to SASA, except with extra handling to keep the kernel inside the canvas,\n",
+    "ensuring that all queries attend to a fixed number of keys.\n",
+    "Keys compare their position to the kernel center, not the query. The kernel center attempts\n",
+    "to follow the query position, but is clamped to stay a fixed distance (its half-length) away\n",
+    "from the canvas edge.\n",
+    "\n",
+    "See the [NATTEN repository](https://github.com/SHI-Labs/NATTEN) for more information.\n",
+    "_Note: a more complete implementation of NATTEN would include support for kernel dilation._\n",
+    "_The NATTEN unfused kernel also has features like the ability to cross-attend to register tokens._\n",
+    "_This capability is possible to express in Flex Attention but not attempted here._"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "H = 128\n",
+    "W = 128\n",
+    "K_H = 7\n",
+    "K_W = 7\n",
+    "\n",
+    "\n",
+    "def get_x_y(idx):\n",
+    "    return idx // W, idx % W\n",
+    "\n",
+    "\n",
+    "def natten_mask(\n",
+    "    b,\n",
+    "    h,\n",
+    "    q_idx,\n",
+    "    kv_idx,\n",
+    "):\n",
+    "    q_x, q_y = get_x_y(q_idx)\n",
+    "    kv_x, kv_y = get_x_y(kv_idx)\n",
+    "    # kernel nominally attempts to center itself on the query, but kernel center\n",
+    "    # is clamped to a fixed distance (kernel half-length) from the canvas edge\n",
+    "    kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2)\n",
+    "    kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2)\n",
+    "    hori_mask = (kernel_x - kv_x).abs() <= K_W // 2\n",
+    "    vert_mask = (kernel_y - kv_y).abs() <= K_H // 2\n",
+    "    return hori_mask & vert_mask\n",
+    "\n",
+    "\n",
     "test_mask(mask_mod=natten_mask)"
    ]
   },
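As a quick sanity check on the clamping behavior described above (my sketch, not part of the commit), one can count how many keys each query attends to under the two masks on a small canvas; only the NATTEN variant gives every query a full-size receptive field:

import torch

H, W, K = 6, 6, 3  # small illustrative canvas with a 3x3 kernel
idx = torch.arange(H * W)
q_idx, kv_idx = idx[:, None], idx[None, :]  # broadcast over all (query, key) pairs

def get_x_y(i):
    return i // W, i % W

q_x, q_y = get_x_y(q_idx)
kv_x, kv_y = get_x_y(kv_idx)

# SASA: the window is centered on the query and gets truncated at the canvas edge.
sasa = ((q_x - kv_x).abs() <= K // 2) & ((q_y - kv_y).abs() <= K // 2)

# NATTEN: the window is centered on the clamped kernel center, so it never
# leaves the canvas and always contains exactly K * K keys.
c_x = q_x.clamp(K // 2, (W - 1) - K // 2)
c_y = q_y.clamp(K // 2, (H - 1) - K // 2)
natten = ((c_x - kv_x).abs() <= K // 2) & ((c_y - kv_y).abs() <= K // 2)

print(sasa.sum(-1).unique())    # tensor([4, 6, 9]): corner/edge queries see fewer keys
print(natten.sum(-1).unique())  # tensor([9]): every query sees exactly K * K keys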
