Commit 1de7070

Add Sliding Tile Attention impl (#140)
* Add Sliding Tile Attention impl
* Fix format
1 parent 24cfc76 commit 1de7070

2 files changed: +242 -0 lines changed

attn_gym/masks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 from attn_gym.masks.document_mask import generate_doc_mask_mod
 from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
 from attn_gym.masks.natten import generate_natten, generate_tiled_natten, generate_morton_natten
+from attn_gym.masks.sta import generate_sta_mask_mod_2d, generate_sta_mask_mod_3d

attn_gym/masks/sta.py

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
"""Generates an STA mask"""

import torch
from torch import IntTensor, BoolTensor
from torch.nn.attention.flex_attention import _mask_mod_signature
from typing import Tuple


def generate_sta_mask_mod_2d(
    canvas_hw: Tuple[int, int],
    kernel_hw: Tuple[int, int],
    tile_hw: Tuple[int, int],
    text_seq_len: int = 0,
) -> _mask_mod_signature:
    """Generates a 2D STA mask with a given kernel size.

    Args:
        canvas_hw (Tuple[int, int]): The shape of the canvas (height, width).
        kernel_hw (Tuple[int, int]): The shape of the kernel (height, width).
        tile_hw (Tuple[int, int]): The shape of the tile (height, width).
        text_seq_len (int): The length of the text sequence for masking.
    """
    canvas_h, canvas_w = canvas_hw
    kernel_h, kernel_w = kernel_hw
    tile_h, tile_w = tile_hw
    tile_numel = tile_h * tile_w
    assert canvas_h % tile_h == 0, (
        f"Canvas height {canvas_h} is not divisible by tile height {tile_h}"
    )
    assert canvas_w % tile_w == 0, (
        f"Canvas width {canvas_w} is not divisible by tile width {tile_w}"
    )
    assert kernel_h % tile_h == 0, (
        f"Kernel height {kernel_h} is not divisible by tile height {tile_h}"
    )
    assert kernel_w % tile_w == 0, (
        f"Kernel width {kernel_w} is not divisible by tile width {tile_w}"
    )
    canvas_tile_h, canvas_tile_w = canvas_h // tile_h, canvas_w // tile_w
    kernel_tile_h, kernel_tile_w = kernel_h // tile_h, kernel_w // tile_w
    vision_seq_len = canvas_h * canvas_w

    def get_h_w_idx_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
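        # Assumes tokens are stored tile-contiguously: idx // tile_numel yields the
        # tile id, and tile ids run in row-major order over the tile grid.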
        tile_id = idx // tile_numel
        tile_h_idx = tile_id // canvas_tile_w
        tile_w_idx = tile_id % canvas_tile_w
        return tile_h_idx, tile_w_idx

    def get_border(kernel_size: IntTensor) -> Tuple[IntTensor, IntTensor]:
        left_border = kernel_size // 2
        right_border = kernel_size // 2 + (kernel_size % 2 - 1)
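        # E.g. a kernel of 3 tiles gives borders (1, 1), a centered window, while
        # an even kernel of 4 tiles gives (2, 1): the extra tile lands on the left.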
        return left_border, right_border

    def sta_mask_mod_2d(
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> BoolTensor:
        q_tile_h, q_tile_w = get_h_w_idx_tiled(q_idx)
        kv_tile_h, kv_tile_w = get_h_w_idx_tiled(kv_idx)
        left_border_h, right_border_h = get_border(kernel_tile_h)
        left_border_w, right_border_w = get_border(kernel_tile_w)
        kernel_center_h = q_tile_h.clamp(left_border_h, (canvas_tile_h - 1) - right_border_h)
        kernel_center_w = q_tile_w.clamp(left_border_w, (canvas_tile_w - 1) - right_border_w)
        h_mask = (kv_tile_h >= kernel_center_h - left_border_h) & (
            kv_tile_h <= kernel_center_h + right_border_h
        )
        w_mask = (kv_tile_w >= kernel_center_w - left_border_w) & (
            kv_tile_w <= kernel_center_w + right_border_w
        )
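        # With text_seq_len > 0, trailing text tokens attend to all vision and text
        # tokens, and every vision token attends to the text tokens; vision-to-vision
        # attention stays restricted to the tile window computed above.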
        vision_mask = (q_idx < vision_seq_len) & (kv_idx < vision_seq_len)
        vision_to_text_mask = (
            (q_idx < vision_seq_len)
            & (kv_idx >= vision_seq_len)
            & (kv_idx < vision_seq_len + text_seq_len)
        )
        text_to_all_mask = (q_idx >= vision_seq_len) & (kv_idx < vision_seq_len + text_seq_len)
        return (vision_mask & h_mask & w_mask) | vision_to_text_mask | text_to_all_mask

    sta_mask_mod_2d.__name__ = (
        f"sta_2d_c{canvas_h}x{canvas_w}_k{kernel_h}x{kernel_w}_t{tile_h}x{tile_w}"
    )
    return sta_mask_mod_2d
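
# Usage sketch (illustrative; the sizes below are hypothetical but satisfy the
# divisibility asserts above). The returned mask_mod plugs directly into
# flex_attention's create_block_mask:
#
#   from torch.nn.attention.flex_attention import create_block_mask
#
#   mask_mod = generate_sta_mask_mod_2d(canvas_hw=(64, 64), kernel_hw=(32, 32), tile_hw=(8, 8))
#   seq_len = 64 * 64
#   block_mask = create_block_mask(mask_mod, None, None, seq_len, seq_len)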


def generate_sta_mask_mod_3d(
    canvas_twh: Tuple[int, int, int],
    kernel_twh: Tuple[int, int, int],
    tile_twh: Tuple[int, int, int],
    text_seq_len: int = 0,
) -> _mask_mod_signature:
    """Generates a 3D STA mask with a given kernel size.

    Args:
        canvas_twh (Tuple[int, int, int]): The shape of the canvas (time, height, width).
        kernel_twh (Tuple[int, int, int]): The shape of the kernel (time, height, width).
        tile_twh (Tuple[int, int, int]): The shape of the tile (time, height, width).
        text_seq_len (int): The length of the text sequence for masking.
    """
    canvas_t, canvas_h, canvas_w = canvas_twh
    kernel_t, kernel_h, kernel_w = kernel_twh
    tile_t, tile_h, tile_w = tile_twh
    tile_numel = tile_t * tile_h * tile_w
    assert canvas_t % tile_t == 0, f"Canvas time {canvas_t} is not divisible by tile time {tile_t}"
    assert canvas_h % tile_h == 0, (
        f"Canvas height {canvas_h} is not divisible by tile height {tile_h}"
    )
    assert canvas_w % tile_w == 0, (
        f"Canvas width {canvas_w} is not divisible by tile width {tile_w}"
    )
    assert kernel_t % tile_t == 0, f"Kernel time {kernel_t} is not divisible by tile time {tile_t}"
    assert kernel_h % tile_h == 0, (
        f"Kernel height {kernel_h} is not divisible by tile height {tile_h}"
    )
    assert kernel_w % tile_w == 0, (
        f"Kernel width {kernel_w} is not divisible by tile width {tile_w}"
    )
    canvas_tile_t, canvas_tile_h, canvas_tile_w = (
        canvas_t // tile_t,
        canvas_h // tile_h,
        canvas_w // tile_w,
    )
    kernel_tile_t, kernel_tile_h, kernel_tile_w = (
        kernel_t // tile_t,
        kernel_h // tile_h,
        kernel_w // tile_w,
    )
    vision_seq_len = canvas_t * canvas_h * canvas_w

    def get_t_h_w_idx_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]:
        tile_id = idx // tile_numel
        tile_t_idx = tile_id // (canvas_tile_h * canvas_tile_w)
        tile_h_idx = (tile_id % (canvas_tile_h * canvas_tile_w)) // canvas_tile_w
        tile_w_idx = tile_id % canvas_tile_w
        return tile_t_idx, tile_h_idx, tile_w_idx

    def get_border(kernel_size: IntTensor) -> Tuple[IntTensor, IntTensor]:
        left_border = kernel_size // 2
        right_border = kernel_size // 2 + (kernel_size % 2 - 1)
        return left_border, right_border

    def sta_mask_mod_3d(
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> BoolTensor:
        q_tile_t, q_tile_h, q_tile_w = get_t_h_w_idx_tiled(q_idx)
        kv_tile_t, kv_tile_h, kv_tile_w = get_t_h_w_idx_tiled(kv_idx)
        left_border_t, right_border_t = get_border(kernel_tile_t)
        left_border_h, right_border_h = get_border(kernel_tile_h)
        left_border_w, right_border_w = get_border(kernel_tile_w)
        kernel_center_t = q_tile_t.clamp(left_border_t, (canvas_tile_t - 1) - right_border_t)
        kernel_center_h = q_tile_h.clamp(left_border_h, (canvas_tile_h - 1) - right_border_h)
        kernel_center_w = q_tile_w.clamp(left_border_w, (canvas_tile_w - 1) - right_border_w)
        t_mask = (kv_tile_t >= kernel_center_t - left_border_t) & (
            kv_tile_t <= kernel_center_t + right_border_t
        )
        h_mask = (kv_tile_h >= kernel_center_h - left_border_h) & (
            kv_tile_h <= kernel_center_h + right_border_h
        )
        w_mask = (kv_tile_w >= kernel_center_w - left_border_w) & (
            kv_tile_w <= kernel_center_w + right_border_w
        )
        vision_mask = (q_idx < vision_seq_len) & (kv_idx < vision_seq_len)
        vision_to_text_mask = (
            (q_idx < vision_seq_len)
            & (kv_idx >= vision_seq_len)
            & (kv_idx < vision_seq_len + text_seq_len)
        )
        text_to_all_mask = (q_idx >= vision_seq_len) & (kv_idx < vision_seq_len + text_seq_len)
        return (vision_mask & t_mask & h_mask & w_mask) | vision_to_text_mask | text_to_all_mask

    sta_mask_mod_3d.__name__ = f"sta_3d_c{canvas_t}x{canvas_h}x{canvas_w}_k{kernel_t}x{kernel_h}x{kernel_w}_t{tile_t}x{tile_h}x{tile_w}"
    return sta_mask_mod_3d
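
# End-to-end sketch (illustrative; mirrors the demo sizes used in main() below):
#
#   from torch.nn.attention.flex_attention import create_block_mask, flex_attention
#
#   mask_mod = generate_sta_mask_mod_3d((24, 24, 24), (12, 12, 12), (4, 4, 4))
#   seq_len = 24 * 24 * 24
#   block_mask = create_block_mask(mask_mod, None, None, seq_len, seq_len)
#   # q, k, v: (B, H, seq_len, head_dim), flattened from (B, H, T, H, W, head_dim)
#   out = flex_attention(q, k, v, block_mask=block_mask)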


def main(device: str = "cpu"):
    """Visualize the attention scores of the STA mask mod.

    Original repo: https://github.com/hao-ai-lab/FastVideo
    See blog: https://hao-ai-lab.github.io/blogs/sta/
    For reference on how to use a Sliding Tile Attention (STA) module, check out:
    1. https://github.com/hao-ai-lab/FastVideo/blob/6ef8fcb61d5046d22b51a6ef5ef312731cef503d/fastvideo/v1/attention/backends/sliding_tile_attn.py#L105
    2. https://github.com/fla-org/fla-zoo/blob/main/flazoo/models/attentions.py#L702

    Note that this version alters some of the original code for better readability
    and includes a 2D use case.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
    """
    from attn_gym import visualize_attention_scores

    B, H, CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM = 1, 1, 24, 24, 24, 8
    KERNEL_T, KERNEL_H, KERNEL_W = 12, 12, 12
    TILE_T, TILE_H, TILE_W = 4, 4, 4

    def make_tensor():
        return torch.ones(B, H, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()

    sta_mask_2d = generate_sta_mask_mod_2d(
        canvas_hw=(CANVAS_HEIGHT, CANVAS_WIDTH),
        kernel_hw=(KERNEL_H, KERNEL_W),
        tile_hw=(TILE_H, TILE_W),
    )

    visualize_attention_scores(
        query.flatten(start_dim=2, end_dim=3),
        key.flatten(start_dim=2, end_dim=3),
        mask_mod=sta_mask_2d,
        device=device,
        name=sta_mask_2d.__name__,
    )

    def make_3d_tensor():
        return torch.ones(B, H, CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM, device=device)

    query_3d, key_3d = make_3d_tensor(), make_3d_tensor()

    sta_mask_3d = generate_sta_mask_mod_3d(
        canvas_twh=(CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH),
        kernel_twh=(KERNEL_T, KERNEL_H, KERNEL_W),
        tile_twh=(TILE_T, TILE_H, TILE_W),
    )

    visualize_attention_scores(
        query_3d.flatten(start_dim=2, end_dim=4),
        key_3d.flatten(start_dim=2, end_dim=4),
        mask_mod=sta_mask_3d,
        device=device,
        name=sta_mask_3d.__name__,
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
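
Assuming the package's visualization extras are installed (pip install -e .'[viz]'), jsonargparse's CLI should expose the device argument as a flag, e.g. python attn_gym/masks/sta.py --device cpu.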
