Skip to content

Commit 98197e5

Browse files
authored
vulkan: optimizations for deepseek prompt processing (#14555)
* vulkan: allow unclamped loads in coopmat2 mul_mat_id shader
* vulkan: increase coopmat2 mul_mat_id tile size
* vulkan: optimize mul_mat_id row_ids search to batch loads, and port to coopmat1 path
* vulkan: use smaller FA row size when head size is large. Applies to both scalar and CM2 paths (CM1 isn't used due to shared memory limits)
1 parent f5e96b3 commit 98197e5

File tree

3 files changed

+104
-18
lines changed

3 files changed

+104
-18
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
17351735
// number of rows/cols for flash attention shader
17361736
static constexpr uint32_t flash_attention_num_small_rows = 32;
17371737
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1738-
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1738+
1739+
// Returns the "large rows" count for the scalar flash attention shader,
// chosen from the head size argument hsv: 2 when hsv >= 512, otherwise 8.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
1740+
if (hsv >= 512) {
1741+
return 2;
1742+
} else {
1743+
return 8;
1744+
}
1745+
}
17391746

17401747
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
17411748
// 128 threads split into four subgroups, each subgroup does 1/4
@@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17601767
if (small_rows) {
17611768
return {scalar_flash_attention_num_small_rows, 64};
17621769
} else {
1763-
return {scalar_flash_attention_num_large_rows, 32};
1770+
return {get_fa_scalar_num_large_rows(hsv), 32};
17641771
}
17651772
}
17661773

@@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17791786

17801787
// small cols to reduce register count
17811788
if (ggml_is_quantized(type) || hsk >= 256) {
1782-
return {64, 32};
1789+
if (hsk >= 512) {
1790+
return {32, 32};
1791+
} else {
1792+
return {64, 32};
1793+
}
17831794
}
17841795
return {64, 64};
17851796
}
@@ -1821,7 +1832,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
18211832
const uint32_t warps = warptile[0] / warptile[10];
18221833

18231834
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1824-
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1835+
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
18251836
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
18261837

18271838
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1946,10 +1957,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
19461957
s_mmq_wg_denoms_k = { 32, 32, 1 };
19471958

19481959
// spec constants and tile sizes for quant matmul_id
1949-
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
1960+
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
19501961
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
19511962
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
1952-
l_mmqid_wg_denoms = { 128, 64, 1 };
1963+
l_mmqid_wg_denoms = { 128, 128, 1 };
19531964
m_mmqid_wg_denoms = { 128, 64, 1 };
19541965
s_mmqid_wg_denoms = { 128, 64, 1 };
19551966

@@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
60486059
// Needs to be kept up to date on shader changes
60496060
GGML_UNUSED(hsv);
60506061
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6051-
const uint32_t Br = scalar_flash_attention_num_large_rows;
6062+
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
60526063
const uint32_t Bc = scalar_flash_attention_Bc;
60536064

60546065
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61736184
case FA_SCALAR:
61746185
case FA_COOPMAT1:
61756186
// We may switch from coopmat1 to scalar, so use the scalar limit for both
6176-
max_gqa = scalar_flash_attention_num_large_rows;
6187+
max_gqa = get_fa_scalar_num_large_rows(HSV);
61776188
break;
61786189
case FA_COOPMAT2:
61796190
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#extension GL_KHR_cooperative_matrix : enable
1919
#extension GL_KHR_memory_scope_semantics : enable
2020
#extension GL_KHR_shader_subgroup_basic : enable
21+
#extension GL_KHR_shader_subgroup_ballot : enable
2122
#endif
2223

2324
#ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104105

105106
#ifdef MUL_MAT_ID
106107
shared u16vec2 row_ids[4096];
108+
uint _ne1;
109+
#ifdef COOPMAT
110+
shared uint _ne1_sh;
111+
#endif
107112
#endif // MUL_MAT_ID
108113

109114
#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
172177
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
173178

174179
#ifdef MUL_MAT_ID
175-
uint _ne1 = 0;
180+
#ifdef COOPMAT
181+
// Spread the search across all elements in the first subgroup
182+
if (gl_SubgroupID == 0) {
183+
_ne1 = 0;
184+
uint num_elements = p.nei1 * p.nei0;
185+
186+
uint ids[16];
187+
uint iter = 0;
188+
189+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
190+
// prefetch up to 16 elements
191+
if (iter == 0) {
192+
[[unroll]] for (uint k = 0; k < 16; ++k) {
193+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
194+
bool in_range = i < num_elements;
195+
uint ii1 = i / p.nei0;
196+
uint ii0 = i % p.nei0;
197+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
198+
}
199+
}
200+
uint i = j + gl_SubgroupInvocationID;
201+
bool in_range = i < num_elements;
202+
uint ii1 = i / p.nei0;
203+
uint ii0 = i % p.nei0;
204+
uint id = ids[iter++];
205+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
206+
uint idx = subgroupBallotExclusiveBitCount(ballot);
207+
if (in_range && id == expert_idx) {
208+
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
209+
}
210+
_ne1 += subgroupBallotBitCount(ballot);
211+
iter &= 15;
212+
}
213+
_ne1_sh = _ne1;
214+
}
215+
216+
barrier();
217+
218+
_ne1 = _ne1_sh;
219+
#else
220+
_ne1 = 0;
176221
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
177222
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
178223
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
183228
}
184229

185230
barrier();
231+
#endif
186232

187233
// Workgroup has no work
188234
if (ic * BN >= _ne1) return;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,32 @@ void main() {
162162
_ne1 = 0;
163163
uint num_elements = p.nei1 * p.nei0;
164164

165-
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
165+
uint ids[16];
166+
uint iter = 0;
167+
168+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169+
// prefetch up to 16 elements
170+
if (iter == 0) {
171+
[[unroll]] for (uint k = 0; k < 16; ++k) {
172+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173+
bool in_range = i < num_elements;
174+
uint ii1 = i / p.nei0;
175+
uint ii0 = i % p.nei0;
176+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177+
}
178+
}
179+
uint i = j + gl_SubgroupInvocationID;
166180
bool in_range = i < num_elements;
167-
uint ii0 = i % p.nei0;
168181
uint ii1 = i / p.nei0;
169-
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
182+
uint ii0 = i % p.nei0;
183+
uint id = ids[iter++];
170184
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
171185
uint idx = subgroupBallotExclusiveBitCount(ballot);
172186
if (in_range && id == expert_idx) {
173187
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
174188
}
175189
_ne1 += subgroupBallotBitCount(ballot);
190+
iter &= 15;
176191
}
177192
_ne1_sh = _ne1;
178193
}
@@ -414,17 +429,31 @@ void main() {
414429
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415430
}
416431

417-
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418-
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
432+
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
433+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
434+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419435

420-
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
436+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
421437
#ifdef MUL_MAT_ID
422-
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
438+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423439
#else
424-
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
440+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425441
#endif
426442

427-
sum = coopMatMulAdd(mat_a, mat_b, sum);
443+
sum = coopMatMulAdd(mat_a, mat_b, sum);
444+
} else {
445+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
446+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
447+
448+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
449+
#ifdef MUL_MAT_ID
450+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
451+
#else
452+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
453+
#endif
454+
455+
sum = coopMatMulAdd(mat_a, mat_b, sum);
456+
}
428457
}
429458

430459
// Convert from ACC_TYPE to D_TYPE

0 commit comments

Comments
 (0)