
Commit b61796c

metal : fuse add
1 parent 28f8817 commit b61796c

File tree

6 files changed: +339 -51 lines changed


ggml/src/ggml-alloc.c

Lines changed: 0 additions & 15 deletions
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {

ggml/src/ggml-backend.cpp

Lines changed: 0 additions & 15 deletions
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 

ggml/src/ggml-impl.h

Lines changed: 16 additions & 0 deletions
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
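Note: the two copies of `ggml_are_same_layout` above are removed in favor of this shared definition in `ggml-impl.h`, which the Metal changes below use to confirm that the bias operands of consecutive `GGML_OP_ADD` nodes share type, shape (`ne`) and strides (`nb`) before fusing them. A minimal sketch of what the check implies, not part of the commit, assuming the usual ggml context setup:

```c
// illustrative only (not part of the commit): two tensors of the same type and
// shape share a layout; a type mismatch does not
#include "ggml.h"
#include "ggml-impl.h"   // ggml_are_same_layout() now lives here

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64);
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 64);

    GGML_ASSERT( ggml_are_same_layout(a, b)); // same type, ne[] and nb[] in all dims
    GGML_ASSERT(!ggml_are_same_layout(a, c)); // different type

    ggml_free(ctx);
    return 0;
}
```

Tensors created with identical type and shape pass the check; a type, shape, or stride mismatch rules fusion out.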

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 100 additions & 16 deletions
@@ -147,7 +147,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
     GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1137,15 @@ @implementation GGMLMetalClass
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,            add,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,     add_fuse_2,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,     add_fuse_4,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,     add_fuse_6,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,     add_fuse_8,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,        add_row,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,            sub,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,        sub_row,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,            mul,            true);
@@ -1875,7 +1891,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
         id<MTLComputeCommandEncoder> encoder,
@@ -1885,7 +1901,12 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf);
+    struct ggml_tensor *  node  = nodes[idx];
+
+    struct ggml_tensor ** fuse = nodes + idx + 1;
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1895,7 +1916,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1906,7 +1927,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1973,6 +1994,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2050,14 +2073,50 @@ static bool ggml_metal_encode_node(
 
             id<MTLComputePipelineState> pipeline = nil;
 
+            {
+                ops[0] = GGML_OP_ADD;
+                ops[1] = GGML_OP_ADD;
+                ops[2] = GGML_OP_ADD;
+                ops[3] = GGML_OP_ADD;
+                ops[4] = GGML_OP_ADD;
+                ops[5] = GGML_OP_ADD;
+                ops[6] = GGML_OP_ADD;
+                ops[7] = GGML_OP_ADD;
+
+                for (n_fuse = 8; n_fuse > 1; --n_fuse) {
+                    if (n_fuse % 2 == 1) {
+                        continue;
+                    }
+                    if (ggml_can_fuse(gf, idx, ops, n_fuse)) {
+                        if (ggml_are_same_layout(node->src[1], fuse[0]->src[1]) &&
+                            ggml_are_same_layout(node->src[1], fuse[n_fuse - 2]->src[1])) {
+                            break;
+                        }
+                    }
+                }
+            }
+
             if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
                 // src1 is a row
                 GGML_ASSERT(ne11 == 1);
 
                 switch (dst->op) {
-                    case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                    case GGML_OP_ADD:
+                        {
+                            switch (n_fuse) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline; break;
+                                case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline; break;
+                                case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline; break;
+                                default:
+                                    {
+                                        GGML_ASSERT(n_fuse == 1);
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline;
+                                    }
+                            }
+                        } break;
                     case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                     case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                     case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
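The first added block in this hunk is the fusion detection: starting at the current node, the encoder tries to treat up to 8 consecutive `GGML_OP_ADD` nodes as one dispatch, preferring the longest even run (8, 6, 4, 2) for which `ggml_can_fuse` holds and for which the first and last fused bias tensors have the same layout as `node->src[1]`; otherwise `n_fuse` stays 1 and the regular kernel is used. A self-contained toy sketch of the same selection loop (it only checks the run of ADD ops and stands in for the real `ggml_can_fuse`/`ggml_are_same_layout` conditions):

```c
#include <stdio.h>

// toy stand-in for the op enum: TOY_ADD plays the role of GGML_OP_ADD
enum toy_op { TOY_ADD = 0, TOY_MUL = 1 };

// mirrors the selection loop above: prefer the longest even chain (8, 6, 4, 2),
// fall back to 1 (no fusion); the real code additionally requires ggml_can_fuse()
// and matching bias layouts
static int pick_n_fuse(const enum toy_op * ops, int n_ops, int idx) {
    int run = 0; // consecutive ADDs starting at idx
    while (idx + run < n_ops && ops[idx + run] == TOY_ADD) {
        run++;
    }

    int n_fuse;
    for (n_fuse = 8; n_fuse > 1; --n_fuse) {
        if (n_fuse % 2 == 1) {
            continue; // fused kernels exist only for 2, 4, 6, 8 adds
        }
        if (run >= n_fuse) {
            break;
        }
    }
    return n_fuse; // 1 means "encode a single ADD as before"
}

int main(void) {
    enum toy_op g[] = { TOY_ADD, TOY_ADD, TOY_ADD, TOY_ADD, TOY_ADD, TOY_MUL };
    printf("n_fuse at node 0: %d\n", pick_n_fuse(g, 6, 0)); // 5 ADDs in a row -> 4
    printf("n_fuse at node 4: %d\n", pick_n_fuse(g, 6, 4)); // 1 ADD           -> 1
    return 0;
}
```

With five consecutive ADDs the loop settles on 4, leaving the fifth to be picked up on the next iteration of the graph loop.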
@@ -2067,7 +2126,21 @@ static bool ggml_metal_encode_node(
                 bcast_row = true;
             } else {
                 switch (dst->op) {
-                    case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                    case GGML_OP_ADD:
+                        {
+                            //GGML_LOG_INFO("XXXXXXXXXXXXXXXXXXXXXXXXX n_fuse = %d\n", n_fuse);
+                            switch (n_fuse) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                default:
+                                    {
+                                        GGML_ASSERT(n_fuse == 1);
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                    }
+                            }
+                        } break;
                     case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                     case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                     case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
@@ -2107,7 +2180,16 @@ static bool ggml_metal_encode_node(
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+                for (int f = 0; f < n_fuse - 1; ++f) {
+                    id_src1 = ggml_metal_get_buffer(fuse[f]->src[1], &offs_src1);
+
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:3 + f];
+
+                    if (f + 1 == n_fuse - 1) {
+                        id_dst = ggml_metal_get_buffer(fuse[f], &offs_dst);
+                    }
+                }
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2 + n_fuse];
 
                 if (bcast_row) {
                     const int64_t n = ggml_nelements(dst)/4;
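The binding loop above extends the argument table for the fused case: `args` stays at index 0, `src0` at 1, the first bias at 2, each additional fused bias at 3, 4, ..., and `id_dst` is re-pointed at the output of the last fused node and bound at index `2 + n_fuse` (index 3 when `n_fuse == 1`, i.e. unchanged for the non-fused path). A small illustrative sketch of that layout, not part of the commit:

```c
#include <stdio.h>

// prints the buffer-index plan set up by the encoder calls above
// (illustrative only; the indices come straight from the diff)
static void print_add_binding_plan(int n_fuse) {
    printf("n_fuse = %d\n", n_fuse);
    printf("  atIndex 0 -> args\n");
    printf("  atIndex 1 -> src0\n");
    printf("  atIndex 2 -> src1 of the first add\n");
    for (int f = 0; f < n_fuse - 1; ++f) {
        printf("  atIndex %d -> fuse[%d]->src[1]\n", 3 + f, f);
    }
    printf("  atIndex %d -> dst (output of the last fused node)\n", 2 + n_fuse);
}

int main(void) {
    print_add_binding_plan(1); // plain add: dst at index 3, as before
    print_add_binding_plan(4); // 4 fused adds: extra biases at 3..5, dst at 6
    return 0;
}
```

Only the final destination is bound and written; the intermediate ADD results are skipped entirely, which is why the fusion check earlier has to confirm they are not needed elsewhere in the graph.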
@@ -2674,7 +2756,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;
@@ -3550,7 +3632,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                 if (!h_src1) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                    return false;
+                    return 0;
                 }
 
                 const int64_t neh0 = ne0;
@@ -3566,15 +3648,15 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                 if (!h_dst) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                    return false;
+                    return 0;
                 }
 
                 // tokens per expert
                 const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                 if (!h_tpe) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                    return false;
+                    return 0;
                 }
 
                 // id map
@@ -3583,7 +3665,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                 if (!h_ids) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                    return false;
+                    return 0;
                 }
 
                 {
@@ -5442,7 +5524,7 @@ static bool ggml_metal_encode_node(
         }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -5948,20 +6030,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
         ggml_metal_mem_pool_reset(mem_pool);
 
-        for (int idx = node_start; idx < node_end; ++idx) {
+        for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
 
-            if (!res) {
+            if (res == 0) {
                 break;
            }
+
+            idx += res;
         }
 
         [encoder endEncoding];
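`ggml_metal_encode_node` now returns the number of graph nodes it consumed (`n_fuse`, or 1 for non-fused and no-op nodes) and 0 on failure, and the driver loop advances by that amount so fused nodes are not encoded twice. A minimal sketch of that contract with a hypothetical `encode_fn` callback, not the commit's code:

```c
// sketch of the driver-loop contract above, with a hypothetical encode_fn:
// 0 = failure, otherwise the number of graph nodes consumed (>= 1)
typedef int (*encode_fn)(int idx);

static int encode_range(encode_fn encode, int node_start, int node_end) {
    for (int idx = node_start; idx < node_end; /* advanced below */) {
        const int res = encode(idx);

        if (res == 0) {
            return -1; // encoding failed; abort this command buffer
        }

        idx += res;    // skip the nodes that were folded into a fused dispatch
    }
    return 0;
}
```

No-op nodes such as `GGML_OP_PERMUTE` return 1, so the loop always makes progress; a return of 0 stops encoding for the command buffer, matching the old `false` path.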
