Skip to content

Commit b3ad3a0

Browse files
authored
vulkan: support SET_ROWS (#14587)
* vulkan: support SET_ROWS

  Add variants of the copy_to_quant shader that do the SET_ROWS operation. Change these shaders to spread the work across the workgroup. The memory access pattern is probably not great (one thread per quant block), but should be fine for now.

* vulkan: optimize set_rows

  Use larger workgroups for non-quant types. Set "norepeat" (there is manual repeat logic). Use fastmod.
1 parent 98197e5 commit b3ad3a0

File tree

3 files changed

+164
-22
lines changed

3 files changed

+164
-22
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ struct vk_device_struct {
437437
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
438438
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
439439
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
440+
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
440441
vk_pipeline pipeline_norm_f32;
441442
vk_pipeline pipeline_group_norm_f32;
442443
vk_pipeline pipeline_rms_norm_f32;
@@ -2749,19 +2750,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
27492750
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
27502751

27512752
if (device->float_controls_rte_fp16) {
2752-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2753-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2754-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2755-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2756-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2757-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2753+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2754+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2755+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2756+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2757+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2758+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
27582759
} else {
2759-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2760-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
2761-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
2762-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
2763-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2764-
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2760+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2761+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2762+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2763+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2764+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2765+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
2766+
}
2767+
2768+
if (device->float_controls_rte_fp16) {
2769+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2770+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2771+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2772+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2773+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2774+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2775+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2776+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2777+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2778+
} else {
2779+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2780+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2781+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2782+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2783+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2784+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2785+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2786+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
2787+
ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
27652788
}
27662789

27672790
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
@@ -6527,6 +6550,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65276550
case GGML_OP_CONT:
65286551
case GGML_OP_DUP:
65296552
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
6553+
case GGML_OP_SET_ROWS:
6554+
return ctx->device->pipeline_set_rows[dst->type];
65306555
case GGML_OP_SILU_BACK:
65316556
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
65326557
return ctx->device->pipeline_silu_back_f32;
@@ -6765,6 +6790,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
67656790
case GGML_OP_RMS_NORM:
67666791
case GGML_OP_CONV_2D_DW:
67676792
case GGML_OP_IM2COL:
6793+
case GGML_OP_SET_ROWS:
67686794
return true;
67696795
default:
67706796
return false;
@@ -7078,6 +7104,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70787104
ne *= ggml_type_size(src0->type) / 2;
70797105
}
70807106
}
7107+
// copy_to_quant has block size of 32, and each thread does QUANT_K elements.
7108+
// Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
7109+
// So divide by block size here before splitting into 512x512 groups.
7110+
if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7111+
ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
7112+
}
70817113
if (ne > 262144) {
70827114
elements = { 512, 512, CEIL_DIV(ne, 262144) };
70837115
} else if (ne > 512) {
@@ -7086,6 +7118,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
70867118
elements = { ne, 1, 1 };
70877119
}
70887120
} break;
7121+
case GGML_OP_SET_ROWS:
7122+
{
7123+
uint32_t ne = ggml_nelements(src0);
7124+
if (ggml_is_quantized(dst->type)) {
7125+
// quants run 32 threads each doing QUANT_K elements
7126+
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
7127+
} else {
7128+
// scalar types do one element per thread, running 512 threads
7129+
ne = CEIL_DIV(ne, 512);
7130+
}
7131+
if (ne > 262144) {
7132+
elements = { 512, 512, CEIL_DIV(ne, 262144) };
7133+
} else if (ne > 512) {
7134+
elements = { 512, CEIL_DIV(ne, 512), 1 };
7135+
} else {
7136+
elements = { ne, 1, 1 };
7137+
}
7138+
}
7139+
break;
70897140
default:
70907141
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
70917142
break;
@@ -7648,6 +7699,21 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
76487699
}, dryrun);
76497700
}
76507701

7702+
static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7703+
const uint32_t src0_type_size = ggml_type_size(src0->type);
7704+
const uint32_t src1_type_size = ggml_type_size(src1->type);
7705+
const uint32_t dst_type_size = ggml_type_size(dst->type);
7706+
7707+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
7708+
(uint32_t)ggml_nelements(src0),
7709+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7710+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7711+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7712+
0,
7713+
0.0f, 0.0f, 0,
7714+
}, dryrun);
7715+
}
7716+
76517717
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
76527718
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
76537719
}
@@ -8968,6 +9034,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
89689034
case GGML_OP_CLAMP:
89699035
case GGML_OP_PAD:
89709036
case GGML_OP_CPY:
9037+
case GGML_OP_SET_ROWS:
89719038
case GGML_OP_CONT:
89729039
case GGML_OP_DUP:
89739040
case GGML_OP_SILU_BACK:
@@ -9034,6 +9101,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
90349101
case GGML_OP_CLAMP:
90359102
case GGML_OP_PAD:
90369103
case GGML_OP_CPY:
9104+
case GGML_OP_SET_ROWS:
90379105
case GGML_OP_CONT:
90389106
case GGML_OP_DUP:
90399107
case GGML_OP_SILU_BACK:
@@ -9142,6 +9210,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
91429210
case GGML_OP_DUP:
91439211
ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);
91449212

9213+
break;
9214+
case GGML_OP_SET_ROWS:
9215+
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
9216+
91459217
break;
91469218
case GGML_OP_SILU_BACK:
91479219
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9357,6 +9429,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
93579429
case GGML_OP_CLAMP:
93589430
case GGML_OP_PAD:
93599431
case GGML_OP_CPY:
9432+
case GGML_OP_SET_ROWS:
93609433
case GGML_OP_CONT:
93619434
case GGML_OP_DUP:
93629435
case GGML_OP_SILU_BACK:
@@ -10422,9 +10495,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1042210495
} break;
1042310496
case GGML_OP_SET_ROWS:
1042410497
{
10425-
// TODO: add support
10426-
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
10427-
return false;
10498+
switch (op->type) {
10499+
case GGML_TYPE_F32:
10500+
case GGML_TYPE_F16:
10501+
case GGML_TYPE_BF16:
10502+
case GGML_TYPE_Q4_0:
10503+
case GGML_TYPE_Q4_1:
10504+
case GGML_TYPE_Q5_0:
10505+
case GGML_TYPE_Q5_1:
10506+
case GGML_TYPE_Q8_0:
10507+
case GGML_TYPE_IQ4_NL:
10508+
return true;
10509+
default:
10510+
return false;
10511+
}
1042810512
} break;
1042910513
case GGML_OP_CONT:
1043010514
case GGML_OP_CPY:
@@ -11039,6 +11123,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1103911123
} else {
1104011124
tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
1104111125
}
11126+
} else if (tensor->op == GGML_OP_SET_ROWS) {
11127+
tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
1104211128
} else if (tensor->op == GGML_OP_CONT) {
1104311129
tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1104411130
} else if (tensor->op == GGML_OP_RESHAPE) {

0 commit comments

Comments
 (0)