Skip to content

Commit 6efcd65

Browse files
authored
vulkan: optimize flash attention split_k_reduce (ggml-org#14554)
* vulkan: allow FA split_k with smaller KV values * vulkan: spread split_k_reduce work across more threads k_num can get rather large. Use the whole workgroup to reduce the M/L values. Launch a thread for each element in the HSV dimension of the output. Helps a lot for large HSV (like deepseek).
1 parent 699f439 commit 6efcd65

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
27062706
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
27072707

27082708
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
2709-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
2709+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
27102710
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
27112711

27122712
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62526252
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
62536253

62546254
// Try to use split_k when KV is large enough to be worth the overhead
6255-
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
6255+
if (workgroups_x == 1 && shader_core_count > 0) {
62566256
// Try to run two workgroups per SM.
62576257
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
62586258
if (split_k > 1) {
62596259
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
62606260
// of "align", so recompute split_k based on that.
6261-
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
6261+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
62626262
split_k = CEIL_DIV(KV, split_kv);
62636263
workgroups_x = split_k;
62646264
}
@@ -6392,7 +6392,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63926392
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
63936393
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
63946394
},
6395-
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
6395+
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
63966396
} else {
63976397
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
63986398
{

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
#extension GL_EXT_control_flow_attributes : enable
44

5-
#define BLOCK_SIZE 32
5+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
66

7-
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
7+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
88

99
layout (binding = 0) readonly buffer A {float data_a[];};
1010
layout (binding = 1) writeonly buffer D {float data_d[];};
@@ -16,6 +16,8 @@ layout (push_constant) uniform parameter {
1616
uint k_num;
1717
} p;
1818

19+
shared float tmpsh[BLOCK_SIZE];
20+
1921
void main() {
2022
// Each workgroup handles a row
2123
const uint n = gl_WorkGroupID.x;
@@ -32,23 +34,51 @@ void main() {
3234

3335
// Compute the max m value for the row
3436
float m_max = -1.0/0.0;
35-
[[unroll]] for (uint k = 0; k < k_num; ++k) {
36-
float m = data_a[m_offset + k * lm_stride];
37+
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
38+
float m = data_a[m_offset + (k + tid) * lm_stride];
3739
m_max = max(m_max, m);
3840
}
3941

42+
// reduce across the workgroup
43+
tmpsh[tid] = m_max;
44+
barrier();
45+
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
46+
if (tid < s) {
47+
m_max = max(m_max, tmpsh[tid + s]);
48+
tmpsh[tid] = m_max;
49+
}
50+
barrier();
51+
}
52+
m_max = tmpsh[0];
53+
54+
barrier();
55+
4056
// Compute L based on m_max
4157
float L = 0;
42-
[[unroll]] for (uint k = 0; k < k_num; ++k) {
43-
float l = data_a[l_offset + k * lm_stride];
44-
float m = data_a[m_offset + k * lm_stride];
58+
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
59+
float l = data_a[l_offset + (k + tid) * lm_stride];
60+
float m = data_a[m_offset + (k + tid) * lm_stride];
4561
L += exp(m - m_max) * l;
4662
}
4763

64+
// reduce across the workgroup
65+
tmpsh[tid] = L;
66+
barrier();
67+
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
68+
if (tid < s) {
69+
L += tmpsh[tid + s];
70+
tmpsh[tid] = L;
71+
}
72+
barrier();
73+
}
74+
L = tmpsh[0];
75+
4876
L = 1.0 / L;
4977

78+
// D dimension is split across workgroups in the y dimension
79+
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
5080
// Scale and sum the O contributions based on m_max and store the result to memory
51-
for (uint d = tid; d < D; d += BLOCK_SIZE) {
81+
if (d < D) {
5282
float O = 0.0;
5383
[[unroll]] for (uint k = 0; k < k_num; ++k) {
5484
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;

0 commit comments

Comments
 (0)