Skip to content

Commit 064cdc2

Browse files
authored
vulkan : fix Quantized Mat-Vec Mul on AMD GPUs for ncols < 64 (#8855)
* Fix invalid Vulkan mul mat vec results when ncols < warp size
* Only run the backend ops mul mat vec block size test if that block size is not already covered
1 parent 5587e57 commit 064cdc2

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

ggml/src/vulkan-shaders/mul_mat_vec.comp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,22 @@ void main() {
1616
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
1717
const uint tid = gl_LocalInvocationID.x;
1818

19+
// There are not enough cols to use all threads
20+
if (tid >= p.ncols) {
21+
return;
22+
}
23+
24+
const uint block_size = min(p.ncols, BLOCK_SIZE);
25+
1926
uint a_offset, b_offset, d_offset;
2027
get_offsets(a_offset, b_offset, d_offset);
2128

2229
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
2330

2431
tmp[tid] = FLOAT_TYPE(0.0f);
2532

26-
[[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
27-
const uint col = i*BLOCK_SIZE + 2*tid;
33+
[[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
34+
const uint col = i*block_size + 2*tid;
2835
const uint ib = (row*p.ncols + col)/QUANT_K; // block index
2936
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
3037
const uint iybs = col - col%QUANT_K; // y block start index
@@ -38,7 +45,7 @@ void main() {
3845

3946
// sum up partial sums and write back result
4047
barrier();
41-
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
48+
[[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
4249
if (tid < s) {
4350
tmp[tid] += tmp[tid + s];
4451
}

tests/test-backend-ops.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,9 +2271,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22712271

22722272
for (ggml_type type_a : other_types) {
22732273
for (ggml_type type_b : {GGML_TYPE_F32}) {
2274-
2275-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), { 1, 1}, {1, 1}));
2276-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
2274+
if (ggml_blck_size(type_a) != 256) {
2275+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
2276+
}
2277+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
22772278
}
22782279
}
22792280

0 commit comments

Comments
 (0)