diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 075f14a49e9ac..9122fca6cf99f 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
-    dst += D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col = blockIdx.x;
+    const int head = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta += j_dst_unrolled * parallel_blocks;
+    dst += j_dst_unrolled * D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
 
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 709589854f0af..6fa2e77299eb0 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
     // kbc == k block continuous, current index in continuous ijk space.
-    int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
     int kb0_start = kbc % iter_k;
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+        const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
         const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int channel = kbc / (iter_k*iter_j);
-    const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+    const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
     const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel = kb0_stop * kb_niter;
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 0c967f178e7b1..1f141328845a4 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1);
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 908c76dbdd270..a4965583cef1c 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index e78fb181919fd..b2d469938abf2 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    Q += nb02* blockIdx.z + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1);
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index b2f1724c95588..405b6f5106ea0 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    Q += nb02* blockIdx.z + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index c95ca7b1f285f..741b8781d29f5 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
-    const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
-    const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
     const half2 * mask2 = (const half2 *) maskh;
 
     const int stride_Q = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }
 
         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 72406f0af3622..ce2180877a386 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3398,12 +3398,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3416,6 +3410,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a743ad6e8c10d..fadcda9dba43d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4256,7 +4256,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
            ggml_set_name(m, "m");
         }
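The arithmetic that recurs throughout this patch is the decomposition of the flat stream-k work counter kbc into (sequence, head, jt) and the flattening of each output position into ((sequence*ne01 + col)*ne02 + head), matching the [0, 2, 1, 3] permuted layout noted in flash_attn_combine_results. Below is a minimal host-side sketch, not part of the patch, that mirrors this index math with made-up example sizes and checks that it round-trips; the variable names follow the kernel code, and "head" here is really a head-group index (the kernels use head*ncols2 for the first head of the group).

#include <cassert>
#include <cstdio>

int main() {
    // Made-up example sizes; in the kernels these come from the tensor shapes.
    const int iter_k = 4;  // KQ tiles along the KV dimension (ne11 / FATTN_KQ_STRIDE)
    const int iter_j = 3;  // Q column tiles (ceil(ne01 / ncols1))
    const int ne02   = 8;  // number of Q heads
    const int ne03   = 2;  // number of sequences
    const int ncols2 = 2;  // heads handled together per tile

    const int head_groups = ne02/ncols2;
    const int kbc_total   = iter_k*iter_j*head_groups*ne03;

    for (int kbc = 0; kbc < kbc_total; ++kbc) {
        // Same decomposition as in flash_attn_ext_f16 / flash_attn_stream_k_fixup above.
        const int sequence = kbc / (iter_k*iter_j*head_groups);
        const int head     = (kbc - iter_k*iter_j*head_groups*sequence) / (iter_k*iter_j);
        const int jt       = (kbc - iter_k*iter_j*head_groups*sequence - iter_k*iter_j*head) / iter_k;
        const int k        = kbc % iter_k;

        // Recomposing the indices must reproduce the flat counter.
        assert(((sequence*head_groups + head)*iter_j + jt)*iter_k + k == kbc);
    }
    printf("kbc decomposition round-trips for all %d blocks\n", kbc_total);
    return 0;
}

The same recomposition is what flash_attn_stream_k_fixup relies on when it walks backwards over neighboring blocks to combine partial output tiles, and the unrolled (sequence, column, head) destination index is what lets flash_attn_combine_results address dst, VKQ_parts and VKQ_meta with a single flat offset.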