@@ -437,6 +437,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
@@ -2749,19 +2750,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

     if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0) , 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1) , 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0) , 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1) , 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0) , 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL) , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32 , 1, 1}, {}, 1);
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
+    }
+
+    if (device->float_controls_rte_fp16) {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
     }

     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
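Both the f32-to-quant copy shaders and the new SET_ROWS shaders are loaded in two flavors: a plain build and an `_rte` build, the latter selected when the device reports `float_controls_rte_fp16`, i.e. it can commit to round-to-nearest-even when narrowing f32 to f16. A self-contained illustration (not ggml code) of the tie-breaking difference this guards against, using the fact that f16 values in [2048, 4096) are spaced 2.0 apart:

```cpp
#include <cmath>
#include <cstdio>

// Why the *_rte shader variants exist: when narrowing f32 to f16, a value that
// sits exactly between two representable halves can round either way unless
// round-to-nearest-even is guaranteed. 2051.0f is such a tie (between 2050 and
// 2052, since the f16 ulp in this range is 2.0).
int main() {
    const float x = 2051.0f, ulp = 2.0f;
    float rtz = std::trunc(x / ulp) * ulp;      // round toward zero     -> 2050
    float rte = std::nearbyint(x / ulp) * ulp;  // default FP mode is
                                                // round-to-nearest-even -> 2052
    printf("toward-zero: %.0f, nearest-even: %.0f\n", rtz, rte);
}
```

Without the device-level guarantee, a driver may break such ties differently from the CPU reference, which would shift quantized scale values; the two shader builds keep results consistent either way.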
@@ -6527,6 +6550,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_CONT:
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+    case GGML_OP_SET_ROWS:
+        return ctx->device->pipeline_set_rows[dst->type];
     case GGML_OP_SILU_BACK:
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_silu_back_f32;
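Unlike the copy path, which routes through `ggml_vk_get_cpy_pipeline`, SET_ROWS resolves its pipeline with a direct O(1) index into the per-type array declared in the first hunk; types that never got a shader keep a null slot, and the `supports_op` change at the end of the patch keeps those from being reached. A standalone toy of that table pattern (names illustrative, not ggml's):

```cpp
#include <array>
#include <cstdio>

// Toy model (not ggml code) of a per-type pipeline table: one slot per tensor
// type, with unsupported types left null and gated by a separate support check.
enum toy_type { TOY_F32, TOY_F16, TOY_Q8_0, TOY_TYPE_COUNT };
struct toy_pipeline { const char * shader = nullptr; };

int main() {
    std::array<toy_pipeline, TOY_TYPE_COUNT> set_rows{};  // all slots start null
    set_rows[TOY_F32]  = {"set_rows_f32"};
    set_rows[TOY_Q8_0] = {"set_rows_q8_0"};
    for (int t = 0; t < TOY_TYPE_COUNT; ++t) {
        printf("type %d -> %s\n", t, set_rows[t].shader ? set_rows[t].shader : "(unsupported)");
    }
}
```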
@@ -6765,6 +6790,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_RMS_NORM:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_IM2COL:
+    case GGML_OP_SET_ROWS:
         return true;
     default:
         return false;
@@ -7078,6 +7104,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 ne *= ggml_type_size(src0->type) / 2;
             }
         }
+        // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
+        // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
+        // So divide by block size here before splitting into 512x512 groups.
+        if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+            ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
+        }
         if (ne > 262144) {
             elements = { 512, 512, CEIL_DIV(ne, 262144) };
         } else if (ne > 512) {
@@ -7086,6 +7118,25 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { ne, 1, 1 };
         }
     } break;
+    case GGML_OP_SET_ROWS:
+        {
+            uint32_t ne = ggml_nelements(src0);
+            if (ggml_is_quantized(dst->type)) {
+                // quants run 32 threads each doing QUANT_K elements
+                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
+            } else {
+                // scalar types do one element per thread, running 512 threads
+                ne = CEIL_DIV(ne, 512);
+            }
+            if (ne > 262144) {
+                elements = { 512, 512, CEIL_DIV(ne, 262144) };
+            } else if (ne > 512) {
+                elements = { 512, CEIL_DIV(ne, 512), 1 };
+            } else {
+                elements = { ne, 1, 1 };
+            }
+        }
+        break;
     default:
         elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
         break;
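The SET_ROWS sizing mirrors the CPY-to-quant logic above: for quantized destinations the element count is first collapsed to workgroup granularity (32 threads per group, each handling one QUANT_K-sized block), and only then split into the 512x512xZ dispatch shape; since the set_rows pipelines were created with `{1, 1, 1}` workgroup denominators, the resulting triple appears to map one-to-one onto dispatched workgroups. A quick standalone check of the arithmetic for a Q8_0 destination (block size 32), with an illustrative element count not taken from the patch:

```cpp
#include <cstdint>
#include <cstdio>

// Same CEIL_DIV the file uses: round-up integer division.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main() {
    uint32_t ne = 1000000;            // f32 elements in src0 (illustrative)
    const uint32_t blck = 32;         // ggml_blck_size(GGML_TYPE_Q8_0)
    ne = CEIL_DIV(ne, 32 * blck);     // 32 threads x 32 elements/group -> 977
    uint32_t elements[3];
    if (ne > 262144)   { elements[0] = 512; elements[1] = 512; elements[2] = CEIL_DIV(ne, 262144); }
    else if (ne > 512) { elements[0] = 512; elements[1] = CEIL_DIV(ne, 512); elements[2] = 1; }
    else               { elements[0] = ne;  elements[1] = 1;   elements[2] = 1; }
    printf("{%u, %u, %u}\n", elements[0], elements[1], elements[2]);  // {512, 2, 1}
    return 0;
}
```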
@@ -7648,6 +7699,21 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
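`ggml_vk_set_rows` reuses `vk_op_binary_push_constants`: the total source element count, then shape and stride quadruples for src0 (the values), src1 (the row indices), and dst, with byte strides pre-divided by the type size so the shader can work in element units. A compilable sketch of the layout the initializer implies; the field names are guesses for illustration, not the struct's real declaration:

```cpp
#include <cstdint>

// Layout implied by the braced initializer above: one element count, then
// ne/nb quadruples for src0, src1, and dst, then trailing misc/param slots
// that SET_ROWS leaves zeroed. Field names are illustrative guesses; only the
// order and widths follow from the initializer.
struct vk_op_binary_push_constants_sketch {
    uint32_t ne;                          // ggml_nelements(src0)
    uint32_t ne00, ne01, ne02, ne03;      // src0 shape
    uint32_t nb00, nb01, nb02, nb03;      // src0 strides, in elements
    uint32_t ne10, ne11, ne12, ne13;      // src1 (row indices) shape
    uint32_t nb10, nb11, nb12, nb13;      // src1 strides, in elements
    uint32_t ne20, ne21, ne22, ne23;      // dst shape
    uint32_t nb20, nb21, nb22, nb23;      // dst strides, in elements
    uint32_t d_offset;                    // 0 here
    float    param1, param2;              // 0.0f here
    int32_t  param3;                      // 0 here
};
// 29 four-byte fields = 116 bytes, within Vulkan's guaranteed 128-byte
// push-constant minimum.
static_assert(sizeof(vk_op_binary_push_constants_sketch) == 116, "no padding expected");
```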
@@ -8968,6 +9034,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -9034,6 +9101,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -9142,6 +9210,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DUP:
         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_SET_ROWS:
+        ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_SILU_BACK:
         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9357,6 +9429,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
+    case GGML_OP_SET_ROWS:
     case GGML_OP_CONT:
     case GGML_OP_DUP:
     case GGML_OP_SILU_BACK:
@@ -10422,9 +10495,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             } break;
         case GGML_OP_SET_ROWS:
             {
-                // TODO: add support
-                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
-                return false;
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_IQ4_NL:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;
         case GGML_OP_CONT:
         case GGML_OP_CPY:
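The advertised destination types are the three float formats plus exactly the quantized formats that have `cpy_f32_quant` shaders, so SET_ROWS can quantize on the fly, e.g. when scattering rows into a quantized cache. A minimal sketch of probing this from the public backend API; `ggml_backend_dev_supports_op` is the real entry point, while the device/tensor setup around it is elided:

```cpp
#include "ggml-backend.h"

// Ask a device whether it can run a given SET_ROWS node. After this patch, a
// Vulkan device answers true when op->type is one of F32/F16/BF16/Q4_0/Q4_1/
// Q5_0/Q5_1/Q8_0/IQ4_NL, and false otherwise.
static bool device_can_set_rows(ggml_backend_dev_t dev, const struct ggml_tensor * set_rows_op) {
    return ggml_backend_dev_supports_op(dev, set_rows_op);
}
```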
@@ -11039,6 +11123,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else {
             tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
         }
+    } else if (tensor->op == GGML_OP_SET_ROWS) {
+        tensor_clone = ggml_set_rows(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONT) {
         tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_RESHAPE) {