|
78 | 78 | 这里我们以一个最简单的例子来说明更新的过程。
|
79 | 79 |
|
80 | 80 |
|
81 |
| -我们以 **序列长度 $ N = 4 $**、**特征维度 $ d = 2 $** 为例,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设: |
82 |
| -- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $)。 |
83 |
| -- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $)。 |
| 81 | +我们以 **序列长度 $ N = 4 $ 、特征维度 $ d = 2 $为例**,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设: |
| 82 | +- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $ )。 |
| 83 | +- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $ )。 |
84 | 84 |
|
85 | 85 | 初始状态下:
|
86 | 86 | - 输出矩阵 $\mathbf{O} = \begin{bmatrix} 0 & 0 \\ 0 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}$。
|
87 | 87 | - 全局统计量:$\ell = [0, 0, 0, 0]^T$, $m = [-\infty, -\infty, -\infty, -\infty]^T$。
|
88 | 88 |
|
89 | 89 | ---
|
90 | 90 |
|
91 |
| -**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1, \mathbf{V}_1$ ** |
| 91 | +**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1$ , $\mathbf{V}_1$ :** |
92 | 92 |
|
93 |
| -1. **加载 $\mathbf{K}_1, \mathbf{V}_1$ 到 SRAM**: |
| 93 | +1. **加载 $\mathbf{K}_1$, $\mathbf{V}_1$ 到 SRAM**: |
94 | 94 |
|
95 | 95 | $$
|
96 | 96 | \mathbf{K}_1 = \begin{bmatrix} k_{11} & k_{12} \\ k_{21} & k_{22} \end{bmatrix}, \quad \mathbf{V}_1 = \begin{bmatrix} v_{11} & v_{12} \\ v_{21} & v_{22} \end{bmatrix}
|
97 | 97 | $$
|
98 | 98 |
|
99 |
| -2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$ **: |
| 99 | +2. **内层循环 $i=1$,处理块 $\mathbf{Q}_1$ :** |
100 | 100 | - **加载数据**:
|
101 | 101 | $$
|
102 | 102 | \mathbf{Q}_1 = \begin{bmatrix} q_{11} & q_{12} \\ q_{21} & q_{22} \end{bmatrix}, \quad \mathbf{O}_1 = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix}, \quad \ell_1 = [0, 0]^T, \quad m_1 = [-\infty, -\infty]^T
|
|
118 | 118 | $$
|
119 | 119 | - **写回 HBM**:更新后的 $\mathbf{O}_1$ 对应前两行,$\ell_1$ 和 $m_1$ 同步更新。
|
120 | 120 |
|
121 |
| -3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$**: |
| 121 | +3. **内层循环 $i=2$,处理块 $\mathbf{Q}_2$ :** |
122 | 122 | - 类似地,加载 $\mathbf{Q}_2 = \begin{bmatrix} q_{31} & q_{32} \\ q_{41} & q_{42} \end{bmatrix}$,计算 $\mathbf{S}_{21} = \mathbf{Q}_2 \mathbf{K}_1^T$,更新后两行 $\mathbf{O}_2$。
|
123 | 123 |
|
124 | 124 |
|
|
129 | 129 | \mathbf{K}_2 = \begin{bmatrix} k_{31} & k_{32} \\ k_{41} & k_{42} \end{bmatrix}, \quad \mathbf{V}_2 = \begin{bmatrix} v_{31} & v_{32} \\ v_{41} & v_{42} \end{bmatrix}
|
130 | 130 | $$
|
131 | 131 |
|
132 |
| -2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$**: |
| 132 | +2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$ :** |
133 | 133 | - **加载数据**:当前 $\mathbf{O}_1$ 已包含来自 $\mathbf{V}_1$ 的贡献。
|
134 | 134 | - **计算局部注意力分数**:
|
135 | 135 | $$
|
|
142 | 142 | $$
|
143 | 143 | - **结果等价于全局 Softmax**:最终 $\mathbf{O}_1$ 为前两行注意力结果的加权和。
|
144 | 144 |
|
145 |
| -3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$ **: |
| 145 | +3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$ :** |
146 | 146 | - 类似地,计算 $\mathbf{S}_{22} = \mathbf{Q}_2 \mathbf{K}_2^T$,更新后两行 $\mathbf{O}_2$。
|
147 | 147 |
|
148 | 148 |
|
149 |
| - |
150 |
| - |
151 | 149 | 通过这种分阶段、分块处理的方式,FlashAttention 在不牺牲计算精度的前提下,显著提升了注意力机制的效率,成为处理长序列任务的利器。
|
152 | 150 |
|
153 | 151 |
|
|
0 commit comments