Skip to content

Commit cad1eec

Browse files
JohannesGaessler authored and qnixsynapse committed
CUDA: broadcasting for FlashAttention mask (ggml-org#14500)
1 parent 55d752d commit cad1eec

File tree

7 files changed, with 43 additions and 25 deletions.

7 files changed

+43
-25
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
3232
const int ne12,
3333
const int ne13,
3434
const int ne31,
35+
const int ne32,
3536
const int nb31,
37+
const int nb32,
3638
const int nb01,
3739
const int nb02,
3840
const int nb03,
@@ -851,7 +853,8 @@ void launch_fattn(
851853
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
852854
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
853855
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
854-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
856+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
855858
Q->nb[1], Q->nb[2], Q->nb[3],
856859
nb11, nb12, nb13,
857860
nb21, nb22, nb23,

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
12231223
const int ne12,
12241224
const int ne13,
12251225
const int ne31,
1226+
const int ne32,
12261227
const int nb31,
1228+
const int nb32,
12271229
const int nb01,
12281230
const int nb02,
12291231
const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
12881290

12891291
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
12901292
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291-
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1293+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
12921295
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
12931296

12941297
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
13271330

13281331
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
13291332
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330-
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
1333+
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
13311335
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
13321336

13331337
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
13481352
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
13491353
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
13501354
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
1351-
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
1352-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
1355+
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
1356+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
13531357
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
13541358
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
13551359
GGML_UNUSED(ne2); GGML_UNUSED(ne3);

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
88
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9-
__launch_bounds__(nwarps*WARP_SIZE, 1)
9+
__launch_bounds__(nwarps*WARP_SIZE, 2)
1010
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
static __global__ void flash_attn_tile_ext_f16(
1212
const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
3030
const int ne12,
3131
const int ne13,
3232
const int ne31,
33+
const int ne32,
3334
const int nb31,
35+
const int nb32,
3436
const int nb01,
3537
const int nb02,
3638
const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
6466
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
6567
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
6668
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
67-
const half * maskh = (const half *) mask + ne11*ic0;
69+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
6870

6971
const int stride_KV2 = nb11 / sizeof(half2);
7072

@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
288290
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
289291
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
290292
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
291-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
292-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293295
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
294296
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
295297
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
88
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9-
__launch_bounds__(nwarps*WARP_SIZE, 1)
9+
__launch_bounds__(nwarps*WARP_SIZE, 2)
1010
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
static __global__ void flash_attn_tile_ext_f32(
1212
const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32(
3030
const int ne12,
3131
const int ne13,
3232
const int ne31,
33+
const int ne32,
3334
const int nb31,
35+
const int nb32,
3436
const int nb01,
3537
const int nb02,
3638
const int nb03,
@@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32(
5860
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
5961
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
6062
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
61-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
62-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
63+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
64+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
6365
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
6466
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
6567
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32(
7678
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
7779
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
7880
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
79-
const half * maskh = (const half *) mask + ne11*ic0;
81+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
8082

8183
const int stride_KV2 = nb11 / sizeof(half2);
8284

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
2727
const int ne12,
2828
const int ne13,
2929
const int ne31,
30+
const int ne32,
3031
const int nb31,
32+
const int nb32,
3133
const int nb01,
3234
const int nb02,
3335
const int nb03,
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
6870
K += nb12*(blockIdx.z / gqa_ratio);
6971
V += nb22*(blockIdx.z / gqa_ratio);
7072

71-
const half * maskh = (const half *) mask + ne11*ic0;
73+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
7274

7375
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
7476
const half slopeh = __float2half(slopef);
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
342344
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
343345
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
344346
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
345-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
346-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
348+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
347349
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
348350
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
349351
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32(
2727
const int ne12,
2828
const int ne13,
2929
const int ne31,
30+
const int ne32,
3031
const int nb31,
32+
const int nb32,
3133
const int nb01,
3234
const int nb02,
3335
const int nb03,
@@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32(
5153
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
5254
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
5355
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
54-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
55-
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
56+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
57+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
5658
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
5759
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
5860
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32(
7981
Q += nb02* blockIdx.z + nb01*ic0;
8082
K += nb12*(blockIdx.z / gqa_ratio);
8183
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
82-
const half * maskh = (const half *) mask + ne11*ic0;
84+
85+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
8386

8487
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
8588

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16(
4646
const int ne12,
4747
const int ne13,
4848
const int ne31,
49+
const int ne32,
4950
const int nb31,
51+
const int nb32,
5052
const int nb01,
5153
const int nb02,
5254
const int nb03,
@@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16(
9496
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
9597

9698
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
97-
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
98-
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
99-
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
100-
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
101-
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
99+
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
100+
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
101+
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
102+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
103+
const half2 * mask2 = (const half2 *) maskh;
102104

103105
const int stride_Q = nb01 / sizeof(float);
104106
const int stride_KV = nb11 / sizeof(half);
@@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16(
440442
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
441443
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
442444
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
443-
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
445+
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
444446
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
445447
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
446448
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

0 commit comments

Comments
 (0)