
Commit b6c05ce

rgerganov authored and ggerganov committed

yolo : add backend support (ggml/924)

* yolo : add backend support
* metal : add sub and sqrt kernels

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 52c80ca commit b6c05ce
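The commit wires two ops into the GPU backends, presumably because the ggml YOLO example's graphs contain them: GGML_OP_SUB (added to CUDA and Metal) and GGML_OP_SQRT (added to Metal). Below is a minimal sketch of a graph that exercises both ops through ggml's public C API, run on the CPU for simplicity; with this commit the same nodes can also be scheduled onto the CUDA or Metal backend. Illustrative only, error handling omitted.

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,  // scratch for tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

    // the two ops this commit wires into the GPU backends
    struct ggml_tensor * d = ggml_sub (ctx, a, b);  // GGML_OP_SUB
    struct ggml_tensor * r = ggml_sqrt(ctx, d);     // GGML_OP_SQRT

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, r);

    ggml_set_f32(a, 5.0f);
    ggml_set_f32(b, 1.0f);

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("sqrt(5 - 1) = %f\n", ggml_get_f32_1d(r, 0)); // 2.000000

    ggml_free(ctx);
    return 0;
}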

5 files changed: +105 −1 lines changed

5 files changed

+105
-1
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2181,6 +2181,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2859,6 +2862,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
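The second hunk matters for scheduling: the backend scheduler consults supports_op when placing graph nodes, so listing GGML_OP_SUB here is what actually lets subtraction run on the GPU. A hedged sketch of querying this from application code via the public ggml-backend API (assumes a CUDA build; ggml_backend_cuda_init and ggml_backend_supports_op are existing API, the rest is illustrative):

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"
#include <stdio.h>

int main(void) {
    // no_alloc: we only need op metadata to query support, not tensor data
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor * c = ggml_sub(ctx, a, b); // the node to test

    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
    if (backend) {
        printf("CUDA supports SUB: %d\n", ggml_backend_supports_op(backend, c));
        ggml_backend_free(backend);
    }

    ggml_free(ctx);
    return 0;
}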

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
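All four binary ops share one templated broadcast kernel; adding SUB is just a new device functor plus a one-line host wrapper. As a rough CPU reference of what bin_bcast_cuda<op_sub> computes (a sketch assuming contiguous float32 tensors; the bin_bcast_f32 helper is hypothetical and not part of ggml):

#include <stdint.h>

typedef float (*bin_op_f32)(float a, float b);

static float op_sub(float a, float b) { return a - b; }

// dst/src0 have shape ne[0..3]; each src1 dim must divide the corresponding
// dst dim, and src1 is broadcast by repetition via the "%" wrap, mirroring
// how the CUDA kernel indexes src1
static void bin_bcast_f32(bin_op_f32 op,
                          const float * src0, const float * src1, float * dst,
                          const int64_t ne[4], const int64_t ne1[4]) {
    for (int64_t i3 = 0; i3 < ne[3]; i3++)
    for (int64_t i2 = 0; i2 < ne[2]; i2++)
    for (int64_t i1 = 0; i1 < ne[1]; i1++)
    for (int64_t i0 = 0; i0 < ne[0]; i0++) {
        // linear index into src1 with every dimension wrapped
        const int64_t j = ((i3 % ne1[3])*ne1[2]*ne1[1] + (i2 % ne1[2])*ne1[1]
                        + (i1 % ne1[1]))*ne1[0] + (i0 % ne1[0]);
        // linear index into src0/dst
        const int64_t i = ((i3*ne[2] + i2)*ne[1] + i1)*ne[0] + i0;
        dst[i] = op(src0[i], src1[j]);
    }
}

Calling bin_bcast_f32(op_sub, ...) with ne1 = {ne[0], 1, 1, 1} reproduces the row-broadcast case that the _ROW kernels fast-path on Metal.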

ggml/src/ggml-cuda/binbcast.cuh

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-metal.m

Lines changed: 25 additions & 0 deletions
@@ -31,6 +31,8 @@
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -205,6 +207,7 @@
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -493,6 +496,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
@@ -667,6 +672,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
@@ -769,6 +775,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -777,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
             return ggml_is_contiguous(op->src[0]);
@@ -1057,6 +1065,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             case GGML_OP_ADD:
+            case GGML_OP_SUB:
             case GGML_OP_MUL:
             case GGML_OP_DIV:
                 {
@@ -1080,6 +1089,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         nb = ne00 / 4;
                         switch (dst->op) {
                             case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                            case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                             case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                             case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                             default: GGML_ABORT("fatal error");
@@ -1089,6 +1099,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     } else {
                         switch (dst->op) {
                             case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                            case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                             case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                             case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                             default: GGML_ABORT("fatal error");
@@ -1416,6 +1427,20 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SQRT:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SIN:

ggml/src/ggml-metal.metal

Lines changed: 67 additions & 1 deletion
@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
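kernel_sub addresses both sources through byte strides (the nb* arguments), which is what the header comment means by "works for non-contiguous tensors": the "% ne1x" wrap on the outer indices broadcasts src1, and each threadgroup walks one row of dst. A line-for-line C translation of the kernel body, as an illustrative reference only (the offs byte offset is omitted and assumed zero):

#include <stdint.h>

// CPU reference of the Metal kernel_sub body: one (i01, i02, i03) "row" per
// call, addressed via byte strides so non-contiguous tensors work
static void sub_row_ref(const char * src0, const char * src1, char * dst,
                        int64_t i01, int64_t i02, int64_t i03,
                        int64_t ne0, int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                        uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03,
                        uint64_t nb10, uint64_t nb11, uint64_t nb12, uint64_t nb13,
                        uint64_t nb0,  uint64_t nb1,  uint64_t nb2,  uint64_t nb3) {
    // src1 indices wrap around: broadcast across the outer dims
    const int64_t i13 = i03 % ne13;
    const int64_t i12 = i02 % ne12;
    const int64_t i11 = i01 % ne11;

    const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
    const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
    char       * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;

    for (int64_t i0 = 0; i0 < ne0; i0++) {
        const int64_t i10 = i0 % ne10; // innermost-dim broadcast
        *(float *)(dst_ptr + i0*nb0) =
            *(const float *)(src0_ptr + i0*nb00) - *(const float *)(src1_ptr + i10*nb10);
    }
}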
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
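The _ROW variant is the vectorized fast path for the common case where src1 is a single contiguous row; the selection condition is elided in the ggml-metal.m hunk above, which shows only that the host passes nb = ne00 / 4, i.e. the row length in float4 units, so tpig % nb makes src1 repeat once per row. A scalar C equivalent, as a sketch:

#include <stdint.h>

// scalar equivalent of kernel_sub_row: subtract one row (src1) from every
// row of src0; n is the total element count, row_len the length of src1
static void sub_row_f32(const float * src0, const float * src1, float * dst,
                        int64_t n, int64_t row_len) {
    for (int64_t i = 0; i < n; i++) {
        dst[i] = src0[i] - src1[i % row_len]; // src1 repeats every row_len elements
    }
}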
@@ -358,6 +417,13 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
 kernel void kernel_sin(
         device const float * src0,
         device float * dst,
