
Commit 6e07c3e

metal : fuse add
ggml-ci
1 parent bc0a20c commit 6e07c3e
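
This commit teaches the ggml Metal backend to fuse chains of consecutive GGML_OP_ADD nodes into a single kernel dispatch: new ADD_FUSE_2 … ADD_FUSE_8 pipelines (plus _ROW variants for row broadcast) replace up to eight back-to-back adds, and the per-operand src1 byte offsets travel to the kernel in the new o1[8] field of ggml_metal_kargs_bin. A minimal sketch of the graph shape this targets, with hypothetical tensors x, a, b, c (the bias/residual additions in a transformer block look like this):

    // hypothetical chain of adds, now fusable on Metal
    struct ggml_tensor * cur = ggml_add(ctx, x,   a);
    cur                      = ggml_add(ctx, cur, b);
    cur                      = ggml_add(ctx, cur, c);
    // before: three ADD dispatches; after: one add_fuse_3 dispatch,
    // provided a, b, c share one Metal buffer and have identical layouts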

File tree

7 files changed: +234 -82 lines changed

ggml/src/ggml-alloc.c

Lines changed: 0 additions & 15 deletions
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {

ggml/src/ggml-backend.cpp

Lines changed: 0 additions & 15 deletions
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 

ggml/src/ggml-impl.h

Lines changed: 16 additions & 0 deletions
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
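
Moving ggml_are_same_layout into the shared ggml-impl.h header de-duplicates the two static copies deleted above and lets the Metal backend reuse it for its fusion check. "Same layout" means same type plus identical ne[] (shape) and nb[] (stride) arrays; an illustrative check, not taken from this commit:

    // hypothetical tensors created in some ggml context
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
    struct ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 2);
    GGML_ASSERT( ggml_are_same_layout(a, b)); // same type, ne[] and nb[]
    GGML_ASSERT(!ggml_are_same_layout(a, c)); // type differs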

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
 typedef struct {
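
The new o1 array carries up to eight byte offsets, one per fused src1 operand, all relative to a single shared MTLBuffer: o1[0] is initialized with offs_src1 and the fusion scan in ggml-metal.m fills the following slots. The matching add_fuse_N kernels live in ggml-metal.metal (changed by this commit but not shown in this view); a rough C model of what an N-way fused add computes per element, under that assumption:

    // rough model only; src1_base is the shared buffer bound at offset 0,
    // o1[] the per-operand byte offsets from ggml_metal_kargs_bin
    float acc = src0[i];
    for (int f = 0; f < n_fuse; ++f) {
        const float * src1 = (const float *) ((const char *) src1_base + o1[f]);
        acc += src1[i]; // the _row variants broadcast a single row instead
    }
    dst[i] = acc; // dst is the output of the last node in the fused chain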

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 144 additions & 39 deletions
@@ -147,7 +147,21 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
     GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,

@@ -1129,7 +1143,21 @@ @implementation GGMLMetalClass
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                add,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,         add_fuse_2,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,         add_fuse_3,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,         add_fuse_4,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,         add_fuse_5,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,         add_fuse_6,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,         add_fuse_7,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,         add_fuse_8,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,            add_row,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,     add_row_fuse_2,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3,     add_row_fuse_3,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,     add_row_fuse_4,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5,     add_row_fuse_5,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,     add_row_fuse_6,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7,     add_row_fuse_7,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,     add_row_fuse_8,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                sub,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,            sub_row,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                mul,                true);
@@ -1875,7 +1903,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder,

@@ -1885,7 +1913,10 @@
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 

@@ -1895,7 +1926,7 @@
     struct ggml_tensor * dst  = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {

@@ -1906,7 +1937,7 @@
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
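
With the return type changed from bool to int, the result now counts encoded nodes, which is why the empty-tensor and no-op paths above return 1 instead of true. The contract this implies (my reading of the diff, not an authoritative statement):

    // assumed contract of ggml_metal_encode_node after this change:
    //   0     -> encoding failed (e.g. a mem-pool allocation error below)
    //   1     -> exactly one node was encoded (also used for no-op nodes)
    //   k > 1 -> k consecutive nodes were fused into a single dispatch
    static int ggml_metal_encode_node(
            ggml_backend_t backend,
            int idx,
            id<MTLComputeCommandEncoder> encoder,
            struct ggml_metal_mem_pool * mem_pool);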
@@ -1973,6 +2004,8 @@
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2050,31 +2083,6 @@
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2101,12 +2109,106 @@
                     /*.nb2  =*/ nb2,
                     /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id<MTLBuffer> id_fuse;
+
+                    for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                if (dst->op == GGML_OP_ADD) {
+                    [encoder setBuffer:id_src1 offset:0 atIndex:2];
+                } else {
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                }
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 if (bcast_row) {
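
To make the fusion scan concrete, a hypothetical trace for three chained adds starting at idx, where a, b, c share one Metal buffer and have identical layouts:

    // node[idx+0] = ADD(x,           a)
    // node[idx+1] = ADD(node[idx+0], b)
    // node[idx+2] = ADD(node[idx+1], c)
    //
    // n_fuse = 0: ggml_can_fuse(gf, idx+0, {ADD, ADD}, 2) ok -> args.o1[1] = offs(b)
    // n_fuse = 1: ggml_can_fuse(gf, idx+1, {ADD, ADD}, 2) ok -> args.o1[2] = offs(c)
    // n_fuse = 2: the pair starting at idx+2 does not fuse   -> break
    // ++n_fuse  -> n_fuse = 3: the ADD_FUSE_3 pipeline is selected, and
    //              id_dst/offs_dst are retargeted to node[idx+2]'s buffer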
@@ -2239,6 +2341,7 @@
                     /*.nb2  =*/ pnb2,
                     /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2674,7 +2777,7 @@
                 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;

@@ -3550,7 +3653,7 @@
                 id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                 if (!h_src1) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                 }
 
                 const int64_t neh0 = ne0;

@@ -3566,15 +3669,15 @@
                 id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                 if (!h_dst) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                 }
 
                 // tokens per expert
                 const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                 if (!h_tpe) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                 }
 
                 // id map

@@ -3583,7 +3686,7 @@
                 id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                 if (!h_ids) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                 }
 
                 {

@@ -5442,7 +5545,7 @@
             }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -5948,20 +6051,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
                 struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
                 ggml_metal_mem_pool_reset(mem_pool);
 
-                for (int idx = node_start; idx < node_end; ++idx) {
+                for (int idx = node_start; idx < node_end;) {
                     if (should_capture) {
                         [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
                     }
 
-                    const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+                    const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
                     if (should_capture) {
                         [encoder popDebugGroup];
                     }
 
-                    if (!res) {
+                    if (res == 0) {
                         break;
                     }
+
+                    idx += res;
                 }
 
                 [encoder endEncoding];
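
None of the following is part of the commit, but a hypothetical standalone harness shows how the fused path would be exercised end to end (standard ggml/ggml-backend API, Metal backend assumed available):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-metal.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true, // tensor data lives in backend buffers
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n = 4096;
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);

        // chain of adds -> candidate for add_fuse_3 on Metal
        struct ggml_tensor * cur = ggml_add(ctx, x,   a);
        cur                      = ggml_add(ctx, cur, b);
        cur                      = ggml_add(ctx, cur, c);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, cur);

        ggml_backend_t        backend = ggml_backend_metal_init();
        ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // ... fill x, a, b, c with ggml_backend_tensor_set(), then:
        ggml_backend_graph_compute(backend, gf); // the encoder may fuse the 3 adds

        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }

Whether the adds actually fuse still depends on the runtime checks above: ggml_can_fuse() must accept each consecutive pair, and all three added tensors must land in the same Metal buffer with identical layouts.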

0 commit comments
