Skip to content

Vulkan: iquants and flash attention split_k_reduce improvement #598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions ggml/src/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1717,7 +1717,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;

static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
if (hsv >= 512) {
return 2;
} else {
return 8;
}
}

// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
Expand All @@ -1742,7 +1749,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
return {scalar_flash_attention_num_large_rows, 32};
return {get_fa_scalar_num_large_rows(hsv), 32};
}
}

Expand All @@ -1761,7 +1768,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3

// small cols to reduce register count
if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32};
if (hsk >= 512) {
return {32, 32};
} else {
return {64, 32};
}
}
return {64, 64};
}
Expand Down Expand Up @@ -1803,7 +1814,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10];

const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;

const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
Expand Down Expand Up @@ -1928,10 +1939,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 32, 1 };

// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 64, 1 };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };

Expand Down Expand Up @@ -2688,7 +2699,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
Expand Down Expand Up @@ -5994,7 +6005,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
const uint32_t Bc = scalar_flash_attention_Bc;

const uint32_t tmpsh = wg_size * sizeof(float);
Expand Down Expand Up @@ -6118,7 +6129,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR:
case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = scalar_flash_attention_num_large_rows;
max_gqa = get_fa_scalar_num_large_rows(HSV);
break;
case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
Expand Down Expand Up @@ -6197,13 +6208,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;

// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
if (workgroups_x == 1 && shader_core_count > 0) {
// Try to run two workgroups per SM.
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
Expand Down Expand Up @@ -6336,7 +6347,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, HSV, 1 });
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
Expand Down
46 changes: 38 additions & 8 deletions ggml/src/vulkan-shaders/flash_attn_split_k_reduce.comp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#extension GL_EXT_control_flow_attributes : enable

#define BLOCK_SIZE 32
layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
Expand All @@ -15,6 +15,8 @@ layout (push_constant) uniform parameter {
uint k_num;
} p;

shared float tmpsh[BLOCK_SIZE];

void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
Expand All @@ -30,23 +32,51 @@ void main() {

// Compute the max m value for the row
float m_max = -1.0/0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float m = data_a[m_offset + k * lm_stride];
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}

// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];

barrier();

// Compute L based on m_max
float L = 0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float l = data_a[l_offset + k * lm_stride];
float m = data_a[m_offset + k * lm_stride];
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}

// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];

L = 1.0 / L;

// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
for (uint d = tid; d < D; d += BLOCK_SIZE) {
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d;
Expand Down
Loading