Skip to content

Commit ba1ceb3

Browse files
authored
vulkan: fix noncontig check for mat_mul_id splitting (ggml-org#14683)
* vulkan: fix noncontig check for mat_mul_id splitting. Remove supports_op check for > 4096 (splitting fixes this).
* vulkan: fix batched matmul dequant for Q*_K.
1 parent 10a0351 commit ba1ceb3

File tree

6 files changed

+6
-10
lines changed

6 files changed

+6
-10
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4922,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
49224922
return
49234923
tensor->nb[0] == ggml_type_size(tensor->type) &&
49244924
tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
4925-
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
4925+
(tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
49264926
}
49274927

49284928
static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
@@ -10356,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1035610356
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
1035710357
return false;
1035810358
}
10359-
// Check against size of shared memory variable
10360-
if (op->src[2]->ne[0] > 4096) {
10361-
return false;
10362-
}
1036310359
}
1036410360
switch (src0_type) {
1036510361
case GGML_TYPE_F32:

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = gl_WorkGroupID.x * 256 + wgy;
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint ib = gl_WorkGroupID.x * 256 + wgy;
13-
if (ib >= p.M * p.K / QUANT_K) {
13+
if (ib >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint ib = gl_WorkGroupID.x * 256 + wgy;
13-
if (ib >= p.M * p.K / QUANT_K) {
13+
if (ib >= p.nel / QUANT_K) {
1414
return;
1515
}
1616

ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
1212
const uint i = gl_WorkGroupID.x * 256 + wgy;
13-
if (i >= p.M * p.K / QUANT_K) {
13+
if (i >= p.nel / QUANT_K) {
1414
return;
1515
}
1616
const uint tid = gl_LocalInvocationID.x;

0 commit comments

Comments
 (0)