Skip to content

Commit 3bb8c0c

Browse files
rgerganovggerganov
authored andcommitted
ggml : add ggml_set_rows (ggml-org#14274)
* ggml : add ggml_set_rows Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'. ref: ggml-org#8366 * use I64 for indices * ggml : add repeat impl for i64 * ggml : add ggml_is_contiguous_rows * ggml : ggml_set_rows support broadcast * ggml : ggml_set_rows support quantized dst ggml-ci * ggml : support GGML_TYPE_F32 ".from_float" trait * ggml : ggml_set_rows update comment + better index name * tests : add ggml_set_rows * metal : add ggml_set_rows implementation ggml-ci * ggml : simplify forward_dup_f32 * ggml : fix supports_op * tests : add comment to set_rows * ggml : leave the repeat_i64 for a separate PR ggml-ci * ggml : set_rows use std::min instead of MIN * ggml : better error message for set_rows unsupported type * metal : perform op->type check only once * tests : more consistent implementation + more tests ggml-ci --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 0db1d5e commit 3bb8c0c

File tree

4 files changed

+524
-185
lines changed

4 files changed

+524
-185
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 108 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205214
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206215
GGML_METAL_KERNEL_TYPE_L2_NORM,
207216
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ @implementation GGMLMetalClass
11691178
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
11701179
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
11711180
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1181+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1182+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1183+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1184+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1185+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1186+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1187+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1188+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1189+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
11721190
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11731191
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11741192
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16351653
const bool use_bfloat = ctx_dev->use_bfloat;
16361654

16371655
if (!use_bfloat) {
1656+
if (op->type == GGML_TYPE_BF16) {
1657+
return false;
1658+
}
1659+
16381660
for (size_t i = 0, n = 3; i < n; ++i) {
16391661
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
16401662
return false;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18041826
{
18051827
return op->ne[3] == 1;
18061828
}
1829+
case GGML_OP_SET_ROWS:
1830+
{
1831+
if (op->src[0]->type != GGML_TYPE_F32) {
1832+
return false;
1833+
}
1834+
1835+
switch (op->type) {
1836+
case GGML_TYPE_F32:
1837+
case GGML_TYPE_F16:
1838+
case GGML_TYPE_BF16:
1839+
case GGML_TYPE_Q8_0:
1840+
case GGML_TYPE_Q4_0:
1841+
case GGML_TYPE_Q4_1:
1842+
case GGML_TYPE_Q5_0:
1843+
case GGML_TYPE_Q5_1:
1844+
case GGML_TYPE_IQ4_NL:
1845+
return true;
1846+
default:
1847+
return false;
1848+
};
1849+
}
18071850
default:
18081851
return false;
18091852
}
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
37773820
};
37783821

37793822
[encoder setComputePipelineState:pipeline];
3780-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3781-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3782-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3783-
[encoder setBytes:&args length:sizeof(args) atIndex:3];
3823+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3824+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3825+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3826+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
37843827

37853828
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
37863829
} break;
3830+
case GGML_OP_SET_ROWS:
3831+
{
3832+
id<MTLComputePipelineState> pipeline = nil;
3833+
3834+
switch (dst->type) {
3835+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3836+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3837+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3838+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3839+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3840+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3841+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3842+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3843+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3844+
default: GGML_ABORT("not implemented");
3845+
}
3846+
3847+
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3848+
3849+
int nth = 32; // SIMD width
3850+
3851+
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3852+
nth *= 2;
3853+
}
3854+
3855+
int nrptg = 1;
3856+
if (nth > nk0) {
3857+
nrptg = (nth + nk0 - 1)/nk0;
3858+
nth = nk0;
3859+
3860+
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3861+
nrptg--;
3862+
}
3863+
}
3864+
3865+
nth = MIN(nth, nk0);
3866+
3867+
ggml_metal_kargs_set_rows args = {
3868+
/*.nk0 =*/ nk0,
3869+
/*.ne01 =*/ ne01,
3870+
/*.nb01 =*/ nb01,
3871+
/*.nb02 =*/ nb02,
3872+
/*.nb03 =*/ nb03,
3873+
/*.ne11 =*/ ne11,
3874+
/*.ne12 =*/ ne12,
3875+
/*.nb10 =*/ nb10,
3876+
/*.nb11 =*/ nb11,
3877+
/*.nb12 =*/ nb12,
3878+
/*.nb1 =*/ nb1,
3879+
/*.nb2 =*/ nb2,
3880+
/*.nb3 =*/ nb3,
3881+
};
3882+
3883+
[encoder setComputePipelineState:pipeline];
3884+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3885+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3886+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3887+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3888+
3889+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3890+
} break;
37873891
case GGML_OP_RMS_NORM:
37883892
{
37893893
GGML_ASSERT(ne00 % 4 == 0);

0 commit comments

Comments
 (0)