Skip to content

Commit e72f1a1

Browse files
committed
[fix] fix formula error
1 parent f0ac6b7 commit e72f1a1

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

docs/17_flash_attn/01_flash_attn_v1_part1.md

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,25 @@ $$
7878
这里我们以一个最简单的例子来说明更新的过程。
7979

8080

81-
我们以 **序列长度 $ N = 4 $****特征维度 $ d = 2 $** 为例,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设:
82-
- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $)。
83-
- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $)。
81+
我们以 **序列长度 $ N = 4 $特征维度 $ d = 2 $为例**,将输入矩阵 $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ 均分为 **2 块**,展示 FlashAttention 的分块计算和流式更新过程。假设:
82+
- $\mathbf{Q} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_r = 2 $ )。
83+
- $\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$, $\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$(每块行数 $ B_c = 2 $ )。
8484

8585
初始状态下:
8686
- 输出矩阵 $\mathbf{O} = \begin{bmatrix} 0 & 0 \\ 0 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}$。
8787
- 全局统计量:$\ell = [0, 0, 0, 0]^T$, $m = [-\infty, -\infty, -\infty, -\infty]^T$。
8888

8989
---
9090

91-
**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1, \mathbf{V}_1$ **
91+
**步骤 1:外层循环 $ j=1 $,处理块 $\mathbf{K}_1$ , $\mathbf{V}_1$ :**
9292

93-
1. **加载 $\mathbf{K}_1, \mathbf{V}_1$ 到 SRAM**
93+
1. **加载 $\mathbf{K}_1$, $\mathbf{V}_1$ 到 SRAM**
9494

9595
$$
9696
\mathbf{K}_1 = \begin{bmatrix} k_{11} & k_{12} \\ k_{21} & k_{22} \end{bmatrix}, \quad \mathbf{V}_1 = \begin{bmatrix} v_{11} & v_{12} \\ v_{21} & v_{22} \end{bmatrix}
9797
$$
9898

99-
2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$ **
99+
2. **内层循环 $i=1$,处理块 $\mathbf{Q}_1$ **
100100
- **加载数据**
101101
$$
102102
\mathbf{Q}_1 = \begin{bmatrix} q_{11} & q_{12} \\ q_{21} & q_{22} \end{bmatrix}, \quad \mathbf{O}_1 = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix}, \quad \ell_1 = [0, 0]^T, \quad m_1 = [-\infty, -\infty]^T
@@ -118,7 +118,7 @@ $$
118118
$$
119119
- **写回 HBM**:更新后的 $\mathbf{O}_1$ 对应前两行,$\ell_1$ 和 $m_1$ 同步更新。
120120

121-
3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$**
121+
3. **内层循环 $i=2$,处理块 $\mathbf{Q}_2$**
122122
- 类似地,加载 $\mathbf{Q}_2 = \begin{bmatrix} q_{31} & q_{32} \\ q_{41} & q_{42} \end{bmatrix}$,计算 $\mathbf{S}_{21} = \mathbf{Q}_2 \mathbf{K}_1^T$,更新后两行 $\mathbf{O}_2$。
123123

124124

@@ -129,7 +129,7 @@ $$
129129
\mathbf{K}_2 = \begin{bmatrix} k_{31} & k_{32} \\ k_{41} & k_{42} \end{bmatrix}, \quad \mathbf{V}_2 = \begin{bmatrix} v_{31} & v_{32} \\ v_{41} & v_{42} \end{bmatrix}
130130
$$
131131

132-
2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$**
132+
2. **内层循环 $ i=1 $,处理块 $\mathbf{Q}_1$**
133133
- **加载数据**:当前 $\mathbf{O}_1$ 已包含来自 $\mathbf{V}_1$ 的贡献。
134134
- **计算局部注意力分数**
135135
$$
@@ -142,12 +142,10 @@ $$
142142
$$
143143
- **结果等价于全局 Softmax**:最终 $\mathbf{O}_1$ 为前两行注意力结果的加权和。
144144

145-
3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$ **
145+
3. **内层循环 $ i=2 $,处理块 $\mathbf{Q}_2$ **
146146
- 类似地,计算 $\mathbf{S}_{22} = \mathbf{Q}_2 \mathbf{K}_2^T$,更新后两行 $\mathbf{O}_2$。
147147

148148

149-
150-
151149
通过这种分阶段、分块处理的方式,FlashAttention 在不牺牲计算精度的前提下,显著提升了注意力机制的效率,成为处理长序列任务的利器。
152150

153151

0 commit comments

Comments
 (0)