
Commit 582c219

[code] add flash v1 code (#67)
1 parent 80411bb commit 582c219


docs/17_flash_attn/flash_attn_v1.cu

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
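
// Build/run (assumed toolchain; the commit itself does not include build instructions):
//   nvcc -O2 -o flash_attn_v1 flash_attn_v1.cu && ./flash_attn_v1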

// Release a host or device buffer. The pointer is taken by reference so the
// caller's copy can be reset to nullptr. Host buffers below are allocated
// with malloc, so they are released with free rather than delete[].
void free_resource(float *&ptr, int is_cuda = 1)
{
    if (nullptr != ptr)
    {
        if (is_cuda)
        {
            cudaFree(ptr);
        }
        else
        {
            free(ptr);
        }
    }
    ptr = nullptr;
}

void randomize_data(float *mat, int N)
{
    for (int i = 0; i < N; i++)
    {
        mat[i] = rand() % 100;
    }
}

void fill_data(float *mat, int N, float value)
{
    for (int i = 0; i < N; i++)
    {
        mat[i] = value;
    }
}

// Copy from https://github.com/tspeterkim/flash-attention-minimal
__global__ void flash_attn_v1_kernel(const float *Q,
                                     const float *K,
                                     const float *V,
                                     const int N,
                                     const int d,
                                     const int Tc,
                                     const int Tr,
                                     const int Bc,
                                     const int Br,
                                     const float softmax_scale,
                                     float *l,
                                     float *m,
                                     float *O)
{
    int tx = threadIdx.x;
    int bx = blockIdx.x;
    int by = blockIdx.y; // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);          // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d; // size of Qi, Kj, Vj
    float *Qi = sram;
    float *Kj = &sram[tile_size];
    float *Vj = &sram[tile_size * 2];
    float *S = &sram[tile_size * 3];

    // outer loop
    for (int j = 0; j < Tc; j++)
    {
        // Load Kj, Vj from HBM to SRAM
        for (int x = 0; x < d; x++)
        {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();

        for (int i = 0; i < Tr; i++)
        {
            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++)
            {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++)
            {
                float sum = 0;
                for (int x = 0; x < d; x++)
                {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++)
            {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

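            // Online softmax update used below (FlashAttention v1 rescaling):
            //   m_new = max(m_prev, m_cur)
            //   l_new = exp(m_prev - m_new) * l_prev + exp(m_cur - m_new) * l_cur
            //   O     = (1 / l_new) * (l_prev * exp(m_prev - m_new) * O_prev
            //                          + exp(m_cur - m_new) * (P @ Vj))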
            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) +
                              (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++)
            {
                float pv = 0; // Pij * Vj
                for (int y = 0; y < Bc; y++)
                {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] =
                    (1 / row_l_new) *
                    ((row_l_prev * __expf(row_m_prev - row_m_new) *
                      O[qkv_offset + (tile_size * i) + (tx * d) + x]) +
                     (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();
    }
}
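
// Note: as written, the kernel launches Bc threads per block and indexes rows
// with both Bc and Br, so it effectively assumes Br == Bc and that N is a
// multiple of the tile size; there is no bounds handling for partial tiles.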

// Naive CPU implementation of attention
void attn_cpu(float *Q,
              float *K,
              float *V,
              int B,
              int nh,
              int N,
              int D,
              float softmax_scale,
              float *O)
{
    // Iterate over batch size
    for (int b = 0; b < B; ++b)
    {
        // Iterate over number of attention heads
        for (int h = 0; h < nh; ++h)
        {
            // Iterate over query tokens (index i)
            for (int i = 0; i < N; ++i)
            {
                // Allocate memory for attention scores for this query token (shape N)
                float *scores = (float *)malloc(N * sizeof(float));
                if (scores == NULL)
                {
                    fprintf(stderr, "Memory allocation failed\n");
                    return;
                }

                // Calculate attention scores between the current query token and all
                // key tokens (index j)
                for (int j = 0; j < N; ++j)
                {
                    float score = 0.0f;
                    // Calculate dot product over the dimension D (index d)
                    for (int d = 0; d < D; ++d)
                    {
                        score += Q[((b * nh + h) * N + i) * D + d] *
                                 K[((b * nh + h) * N + j) * D + d];
                    }
                    scores[j] = score * softmax_scale; // Use the provided softmax_scale
                }

                // Apply safe softmax
                // Find the maximum score
                float max_score = scores[0];
                for (int j = 1; j < N; ++j)
                {
                    if (scores[j] > max_score)
                    {
                        max_score = scores[j];
                    }
                }

                // Calculate exponentiated values and their sum
                float sum_exp = 0.0f;
                float *weights = (float *)malloc(N * sizeof(float));
                if (weights == NULL)
                {
                    fprintf(stderr, "Memory allocation failed\n");
                    free(scores);
                    return;
                }
                for (int j = 0; j < N; ++j)
                {
                    weights[j] = expf(scores[j] - max_score);
                    sum_exp += weights[j];
                }

                // Normalize to get attention weights
                for (int j = 0; j < N; ++j)
                {
                    weights[j] /= sum_exp;
                }

                // Calculate the weighted sum of value vectors and store in O
                for (int d = 0; d < D; ++d)
                {
                    O[((b * nh + h) * N + i) * D + d] = 0.0f;
                    for (int j = 0; j < N; ++j)
                    {
                        O[((b * nh + h) * N + i) * D + d] +=
                            weights[j] * V[((b * nh + h) * N + j) * D + d];
                    }
                }

                // Free temporary memory
                free(scores);
                free(weights);
            }
        }
    }
}

int main()
{
    const int B = 4;   // batch size
    const int nh = 8;  // number of attention heads
    const int N = 128; // sequence length
    const int D = 64;  // embedding dimension

    // Split the K/V sequence into Tc tiles of Bc rows and the Q sequence into
    // Tr tiles of Br rows
    const int Bc = 32;
    const int Br = 32;
    const int Tc = ceil((float)N / Bc);
    const int Tr = ceil((float)N / Br);

    const float softmax_scale = 1.0 / sqrt(D);

    // Allocate host memory
    float *Q = (float *)malloc(B * nh * N * D * sizeof(float));
    float *K = (float *)malloc(B * nh * N * D * sizeof(float));
    float *V = (float *)malloc(B * nh * N * D * sizeof(float));
    float *O = (float *)malloc(B * nh * N * D * sizeof(float));
    float *O_cpu = (float *)malloc(B * nh * N * D * sizeof(float));
    float *l = (float *)malloc(B * nh * N * sizeof(float));
    float *m = (float *)malloc(B * nh * N * sizeof(float));

    // Initialize data
    randomize_data(Q, B * nh * N * D);
    randomize_data(K, B * nh * N * D);
    randomize_data(V, B * nh * N * D);
    fill_data(O, B * nh * N * D, 0.0f);
    fill_data(l, B * nh * N, 0.0f);
    fill_data(m, B * nh * N, -INFINITY);

    // Allocate device memory
    float *d_Q, *d_K, *d_V, *d_O, *d_l, *d_m;
    cudaMalloc((void **)&d_Q, B * nh * N * D * sizeof(float));
    cudaMalloc((void **)&d_K, B * nh * N * D * sizeof(float));
    cudaMalloc((void **)&d_V, B * nh * N * D * sizeof(float));
    cudaMalloc((void **)&d_O, B * nh * N * D * sizeof(float));
    cudaMalloc((void **)&d_l, B * nh * N * sizeof(float));
    cudaMalloc((void **)&d_m, B * nh * N * sizeof(float));

    // Copy matrices to device
    cudaMemcpy(d_Q, Q, B * nh * N * D * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_K, K, B * nh * N * D * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_V, V, B * nh * N * D * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_O, O, B * nh * N * D * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_l, l, B * nh * N * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_m, m, B * nh * N * sizeof(float), cudaMemcpyHostToDevice);

    // Calculate SRAM size needed per block
    const int sram_size =
        (3 * Bc * D * sizeof(float)) + (Bc * Br * sizeof(float));
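    // With Bc = Br = 32 and D = 64 this requests (3*32*64 + 32*32) * 4 bytes
    // = 28672 bytes (28 KB) of shared memory per block, which fits within the
    // default 48 KB per-block limit.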
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \n",
           max_sram_size,
           sram_size);

    dim3 grid_dim(B, nh); // batch_size x num_heads
    dim3 block_dim(Bc);   // Bc threads per block

    // Launch kernel
    flash_attn_v1_kernel<<<grid_dim, block_dim, sram_size>>>(
        d_Q, d_K, d_V, N, D, Tc, Tr, Bc, Br, softmax_scale, d_l, d_m, d_O);

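    // Optional (not in the original commit): surface kernel launch/runtime
    // errors before reading back the result.
    // cudaError_t err = cudaGetLastError();
    // if (err != cudaSuccess)
    //     printf("Kernel launch failed: %s\n", cudaGetErrorString(err));
    // cudaDeviceSynchronize();
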
    // Copy result to host
    cudaMemcpy(O, d_O, B * nh * N * D * sizeof(float), cudaMemcpyDeviceToHost);

    // Run the naive CPU attention as a reference
    attn_cpu(Q, K, V, B, nh, N, D, softmax_scale, O_cpu);

    // Check results
    float max_diff = 0.0f;
    for (int i = 0; i < B * nh * N * D; i++)
    {
        max_diff = fmaxf(max_diff, fabsf(O[i] - O_cpu[i]));
    }

    if (max_diff < 0.0001)
    {
        printf("Results are correct!\n");
    }
    else
    {
        printf("Results are incorrect! Max diff: %f\n", max_diff);
    }

    // Free memory
    free_resource(Q, 0);
    free_resource(K, 0);
    free_resource(V, 0);
    free_resource(O, 0);
    free_resource(O_cpu, 0);
    free_resource(l, 0);
    free_resource(m, 0);
    free_resource(d_Q);
    free_resource(d_K);
    free_resource(d_V);
    free_resource(d_O);
    free_resource(d_l);
    free_resource(d_m);

    return 0;
}
