|
78 | 78 | 这里我们以一个最简单的例子来说明更新的过程。
|
79 | 79 |
|
80 | 80 |
|
81 |
| -我们以 **序列长度 $ N = 4 $ 、特征维度 $ d = 2 $为例**,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设: |
82 |
| -- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $ )。 |
83 |
| -- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $ )。 |
| 81 | +我们以 **序列长度 $N = 4$ 、特征维度 $d = 2$ 为例**,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设: |
| 82 | +- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $B_r = 2$ )。 |
| 83 | +- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $B_c = 2$ )。 |
84 | 84 |
|
85 | 85 | 初始状态下:
|
86 | 86 | - 输出矩阵 $\mathbf{O} = \begin{bmatrix} 0 & 0 \\ 0 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}$。
|
87 | 87 | - 全局统计量:$\ell = [0, 0, 0, 0]^T$, $m = [-\infty, -\infty, -\infty, -\infty]^T$。
|
88 | 88 |
|
89 | 89 | ---
|
90 | 90 |
|
91 |
| -**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1$ , $\mathbf{V}_1$ :** |
| 91 | +**步骤 1:外层循环 $j=1$,处理块 $\mathbf{K}_1$, $\mathbf{V}_1$ :** |
92 | 92 |
|
93 | 93 | 1. **加载 $\mathbf{K}_1$, $\mathbf{V}_1$ 到 SRAM**:
|
94 | 94 |
|
|
122 | 122 | - 类似地,加载 $\mathbf{Q}_2 = \begin{bmatrix} q_{31} & q_{32} \\ q_{41} & q_{42} \end{bmatrix}$,计算 $\mathbf{S}_{21} = \mathbf{Q}_2 \mathbf{K}_1^T$,更新后两行 $\mathbf{O}_2$。
|
123 | 123 |
|
124 | 124 |
|
125 |
| -**步骤 2:外层循环 $ j=2 $,处理块 $\mathbf{K}_2, \mathbf{V}_2$ ** |
| 125 | +**步骤 2:外层循环 $j=2$,处理块 $\mathbf{K}_2$, $\mathbf{V}_2$ :** |
126 | 126 |
|
127 |
| -1. **加载 $\mathbf{K}_2, \mathbf{V}_2$ 到 SRAM**: |
| 127 | +1. **加载 $\mathbf{K}_2$, $\mathbf{V}_2$ 到 SRAM**: |
128 | 128 | $$
|
129 | 129 | \mathbf{K}_2 = \begin{bmatrix} k_{31} & k_{32} \\ k_{41} & k_{42} \end{bmatrix}, \quad \mathbf{V}_2 = \begin{bmatrix} v_{31} & v_{32} \\ v_{41} & v_{42} \end{bmatrix}
|
130 | 130 | $$
|
131 | 131 |
|
132 |
| -2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$ :** |
| 132 | +2. **内层循环 $i=1$,处理块 $\mathbf{Q}_1$:** |
133 | 133 | - **加载数据**:当前 $\mathbf{O}_1$ 已包含来自 $\mathbf{V}_1$ 的贡献。
|
134 | 134 | - **计算局部注意力分数**:
|
135 | 135 | $$
|
|
142 | 142 | $$
|
143 | 143 | - **结果等价于全局 Softmax**:最终 $\mathbf{O}_1$ 为前两行注意力结果的加权和。
|
144 | 144 |
|
145 |
| -3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$ :** |
| 145 | +3. **内层循环 $i=2$,处理块 $\mathbf{Q}_2$:** |
146 | 146 | - 类似地,计算 $\mathbf{S}_{22} = \mathbf{Q}_2 \mathbf{K}_2^T$,更新后两行 $\mathbf{O}_2$。
|
147 | 147 |
|
148 | 148 |
|
|
0 commit comments