"""Generates a Sliding Tile Attention (STA) mask."""

import torch
from torch import IntTensor, BoolTensor
from torch.nn.attention.flex_attention import _mask_mod_signature
from typing import Tuple


def generate_sta_mask_mod_2d(
    canvas_hw: Tuple[int, int],
    kernel_hw: Tuple[int, int],
    tile_hw: Tuple[int, int],
    text_seq_len: int = 0,
) -> _mask_mod_signature:
    """Generates a 2D STA mask with a given kernel size.

    Args:
        canvas_hw (Tuple[int, int]): The shape of the canvas (height, width).
        kernel_hw (Tuple[int, int]): The shape of the kernel (height, width).
        tile_hw (Tuple[int, int]): The shape of the tile (height, width).
        text_seq_len (int): The length of the text sequence for masking.
    """
    canvas_h, canvas_w = canvas_hw
    kernel_h, kernel_w = kernel_hw
    tile_h, tile_w = tile_hw
    tile_numel = tile_h * tile_w
    assert canvas_h % tile_h == 0, (
        f"Canvas height {canvas_h} is not divisible by tile height {tile_h}"
    )
    assert canvas_w % tile_w == 0, (
        f"Canvas width {canvas_w} is not divisible by tile width {tile_w}"
    )
    assert kernel_h % tile_h == 0, (
        f"Kernel height {kernel_h} is not divisible by tile height {tile_h}"
    )
    assert kernel_w % tile_w == 0, (
        f"Kernel width {kernel_w} is not divisible by tile width {tile_w}"
    )
    canvas_tile_h, canvas_tile_w = canvas_h // tile_h, canvas_w // tile_w
    kernel_tile_h, kernel_tile_w = kernel_h // tile_h, kernel_w // tile_w
    vision_seq_len = canvas_h * canvas_w

    def get_h_w_idx_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
        # Tokens are stored tile by tile; tiles cover the canvas in row-major
        # order, so a flat token index maps to a (tile row, tile column) pair.
        tile_id = idx // tile_numel
        tile_h_idx = tile_id // canvas_tile_w
        tile_w_idx = tile_id % canvas_tile_w
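        # Illustrative check (assuming a 24x24 canvas with 4x4 tiles, so
        # tile_numel=16 and canvas_tile_w=6): idx=35 -> tile_id=2 -> (0, 2).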
        return tile_h_idx, tile_w_idx

    def get_border(kernel_size: int) -> Tuple[int, int]:
        # How far the window extends to each side of its center, in tiles.
        left_border = kernel_size // 2
        right_border = kernel_size // 2 + (kernel_size % 2 - 1)
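        # Illustrative values: kernel_size=3 -> (1, 1), a centered window;
        # kernel_size=4 -> (2, 1), i.e. an even kernel reaches one tile further left.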
        return left_border, right_border

    def sta_mask_mod_2d(
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> BoolTensor:
        q_tile_h, q_tile_w = get_h_w_idx_tiled(q_idx)
        kv_tile_h, kv_tile_w = get_h_w_idx_tiled(kv_idx)
        left_border_h, right_border_h = get_border(kernel_tile_h)
        left_border_w, right_border_w = get_border(kernel_tile_w)
        # Clamp the window center so the kernel always stays fully inside the
        # canvas; queries near an edge share the window of the nearest valid center.
        kernel_center_h = q_tile_h.clamp(left_border_h, (canvas_tile_h - 1) - right_border_h)
        kernel_center_w = q_tile_w.clamp(left_border_w, (canvas_tile_w - 1) - right_border_w)
        h_mask = (kv_tile_h >= kernel_center_h - left_border_h) & (
            kv_tile_h <= kernel_center_h + right_border_h
        )
        w_mask = (kv_tile_w >= kernel_center_w - left_border_w) & (
            kv_tile_w <= kernel_center_w + right_border_w
        )
        vision_mask = (q_idx < vision_seq_len) & (kv_idx < vision_seq_len)
        vision_to_text_mask = (
            (q_idx < vision_seq_len)
            & (kv_idx >= vision_seq_len)
            & (kv_idx < vision_seq_len + text_seq_len)
        )
        text_to_all_mask = (q_idx >= vision_seq_len) & (kv_idx < vision_seq_len + text_seq_len)
        return (vision_mask & h_mask & w_mask) | vision_to_text_mask | text_to_all_mask

    sta_mask_mod_2d.__name__ = (
        f"sta_2d_c{canvas_h}x{canvas_w}_k{kernel_h}x{kernel_w}_t{tile_h}x{tile_w}"
    )
    return sta_mask_mod_2d
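

# Editorial sketch, not part of the original file: one way to materialize the
# mask mod as a BlockMask and run FlexAttention with it. The canvas/kernel/tile
# sizes, the CPU device, and the 128-divisible sequence length (32 * 32 = 1024)
# are illustrative assumptions.
def _example_flex_attention_2d():
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    canvas_h, canvas_w = 32, 32
    mask_mod = generate_sta_mask_mod_2d(
        canvas_hw=(canvas_h, canvas_w), kernel_hw=(12, 12), tile_hw=(4, 4)
    )
    seq_len = canvas_h * canvas_w  # vision tokens only; text_seq_len defaults to 0
    block_mask = create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
    )
    q = k = v = torch.randn(1, 1, seq_len, 8)
    return flex_attention(q, k, v, block_mask=block_mask)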


def generate_sta_mask_mod_3d(
    canvas_twh: Tuple[int, int, int],
    kernel_twh: Tuple[int, int, int],
    tile_twh: Tuple[int, int, int],
    text_seq_len: int = 0,
) -> _mask_mod_signature:
    """Generates a 3D STA mask with a given kernel size.

    Args:
        canvas_twh (Tuple[int, int, int]): The shape of the canvas (time, height, width).
        kernel_twh (Tuple[int, int, int]): The shape of the kernel (time, height, width).
        tile_twh (Tuple[int, int, int]): The shape of the tile (time, height, width).
        text_seq_len (int): The length of the text sequence for masking.
    """
    canvas_t, canvas_h, canvas_w = canvas_twh
    kernel_t, kernel_h, kernel_w = kernel_twh
    tile_t, tile_h, tile_w = tile_twh
    tile_numel = tile_t * tile_h * tile_w
    assert canvas_t % tile_t == 0, f"Canvas time {canvas_t} is not divisible by tile time {tile_t}"
    assert canvas_h % tile_h == 0, (
        f"Canvas height {canvas_h} is not divisible by tile height {tile_h}"
    )
    assert canvas_w % tile_w == 0, (
        f"Canvas width {canvas_w} is not divisible by tile width {tile_w}"
    )
    assert kernel_t % tile_t == 0, f"Kernel time {kernel_t} is not divisible by tile time {tile_t}"
    assert kernel_h % tile_h == 0, (
        f"Kernel height {kernel_h} is not divisible by tile height {tile_h}"
    )
    assert kernel_w % tile_w == 0, (
        f"Kernel width {kernel_w} is not divisible by tile width {tile_w}"
    )
    canvas_tile_t, canvas_tile_h, canvas_tile_w = (
        canvas_t // tile_t,
        canvas_h // tile_h,
        canvas_w // tile_w,
    )
    kernel_tile_t, kernel_tile_h, kernel_tile_w = (
        kernel_t // tile_t,
        kernel_h // tile_h,
        kernel_w // tile_w,
    )
    vision_seq_len = canvas_t * canvas_h * canvas_w

    def get_t_h_w_idx_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]:
        # Tiles cover the video in time-major, then row-major order, so a flat
        # token index decomposes into (tile time, tile row, tile column).
        tile_id = idx // tile_numel
        tile_t_idx = tile_id // (canvas_tile_h * canvas_tile_w)
        tile_h_idx = (tile_id % (canvas_tile_h * canvas_tile_w)) // canvas_tile_w
        tile_w_idx = tile_id % canvas_tile_w
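        # Illustrative check (assuming a 6x6x6 tile grid with tile_numel=64):
        # idx=2600 -> tile_id=40 -> (tile_t_idx, tile_h_idx, tile_w_idx) = (1, 0, 4).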
        return tile_t_idx, tile_h_idx, tile_w_idx

    def get_border(kernel_size: int) -> Tuple[int, int]:
        left_border = kernel_size // 2
        right_border = kernel_size // 2 + (kernel_size % 2 - 1)
        return left_border, right_border

    def sta_mask_mod_3d(
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> BoolTensor:
        q_tile_t, q_tile_h, q_tile_w = get_t_h_w_idx_tiled(q_idx)
        kv_tile_t, kv_tile_h, kv_tile_w = get_t_h_w_idx_tiled(kv_idx)
        left_border_t, right_border_t = get_border(kernel_tile_t)
        left_border_h, right_border_h = get_border(kernel_tile_h)
        left_border_w, right_border_w = get_border(kernel_tile_w)
        # Clamp the window center so the kernel always stays fully inside the canvas.
        kernel_center_t = q_tile_t.clamp(left_border_t, (canvas_tile_t - 1) - right_border_t)
        kernel_center_h = q_tile_h.clamp(left_border_h, (canvas_tile_h - 1) - right_border_h)
        kernel_center_w = q_tile_w.clamp(left_border_w, (canvas_tile_w - 1) - right_border_w)
        t_mask = (kv_tile_t >= kernel_center_t - left_border_t) & (
            kv_tile_t <= kernel_center_t + right_border_t
        )
        h_mask = (kv_tile_h >= kernel_center_h - left_border_h) & (
            kv_tile_h <= kernel_center_h + right_border_h
        )
        w_mask = (kv_tile_w >= kernel_center_w - left_border_w) & (
            kv_tile_w <= kernel_center_w + right_border_w
        )
        vision_mask = (q_idx < vision_seq_len) & (kv_idx < vision_seq_len)
        vision_to_text_mask = (
            (q_idx < vision_seq_len)
            & (kv_idx >= vision_seq_len)
            & (kv_idx < vision_seq_len + text_seq_len)
        )
        text_to_all_mask = (q_idx >= vision_seq_len) & (kv_idx < vision_seq_len + text_seq_len)
        return (vision_mask & t_mask & h_mask & w_mask) | vision_to_text_mask | text_to_all_mask

    sta_mask_mod_3d.__name__ = (
        f"sta_3d_c{canvas_t}x{canvas_h}x{canvas_w}"
        f"_k{kernel_t}x{kernel_h}x{kernel_w}"
        f"_t{tile_t}x{tile_h}x{tile_w}"
    )
    return sta_mask_mod_3d

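
# Editorial sketch, not part of the original file: when text_seq_len > 0 the
# token layout is [vision tokens | text tokens], so the BlockMask must span
# both segments. The shapes below are assumptions; 8 * 16 * 16 + 512 = 2560
# keeps the total a multiple of the default 128-token block size.
def _example_block_mask_3d_with_text():
    from torch.nn.attention.flex_attention import create_block_mask

    mask_mod = generate_sta_mask_mod_3d(
        canvas_twh=(8, 16, 16),
        kernel_twh=(4, 8, 8),
        tile_twh=(2, 4, 4),
        text_seq_len=512,
    )
    total_len = 8 * 16 * 16 + 512  # vision tokens followed by text tokens
    return create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=total_len, KV_LEN=total_len, device="cpu"
    )
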

def main(device: str = "cpu"):
    """Visualize the attention scores of the STA mask mod.

    Original repo: https://github.com/hao-ai-lab/FastVideo
    See blog: https://hao-ai-lab.github.io/blogs/sta/
    For reference on how to use a Sliding Tile Attention (STA) module, check out:
    1. https://github.com/hao-ai-lab/FastVideo/blob/6ef8fcb61d5046d22b51a6ef5ef312731cef503d/fastvideo/v1/attention/backends/sliding_tile_attn.py#L105
    2. https://github.com/fla-org/fla-zoo/blob/main/flazoo/models/attentions.py#L702

    Note that this version alters some of the original code for readability and includes a 2D use case.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
    """
    from attn_gym import visualize_attention_scores

    B, H, CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM = 1, 1, 24, 24, 24, 8
    KERNEL_T, KERNEL_H, KERNEL_W = 12, 12, 12
    TILE_T, TILE_H, TILE_W = 4, 4, 4
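    # With these settings the flattened 2D sequence has 24 * 24 = 576 tokens
    # and the 3D sequence has 24 * 24 * 24 = 13824 tokens; text_seq_len is 0.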

    def make_tensor():
        return torch.ones(B, H, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()

    sta_mask_2d = generate_sta_mask_mod_2d(
        canvas_hw=(CANVAS_HEIGHT, CANVAS_WIDTH),
        kernel_hw=(KERNEL_H, KERNEL_W),
        tile_hw=(TILE_H, TILE_W),
    )

    visualize_attention_scores(
        query.flatten(start_dim=2, end_dim=3),
        key.flatten(start_dim=2, end_dim=3),
        mask_mod=sta_mask_2d,
        device=device,
        name=sta_mask_2d.__name__,
    )

    def make_3d_tensor():
        return torch.ones(B, H, CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH, HEAD_DIM, device=device)

    query_3d, key_3d = make_3d_tensor(), make_3d_tensor()

    sta_mask_3d = generate_sta_mask_mod_3d(
        canvas_twh=(CANVAS_TIME, CANVAS_HEIGHT, CANVAS_WIDTH),
        kernel_twh=(KERNEL_T, KERNEL_H, KERNEL_W),
        tile_twh=(TILE_T, TILE_H, TILE_W),
    )

    visualize_attention_scores(
        query_3d.flatten(start_dim=2, end_dim=4),
        key_3d.flatten(start_dim=2, end_dim=4),
        mask_mod=sta_mask_3d,
        device=device,
        name=sta_mask_3d.__name__,
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)