@@ -147,7 +147,21 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147
147
148
148
enum ggml_metal_kernel_type {
149
149
GGML_METAL_KERNEL_TYPE_ADD,
150
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
152
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
153
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
154
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150
157
GGML_METAL_KERNEL_TYPE_ADD_ROW,
158
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
159
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3,
160
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
161
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5,
162
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
163
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7,
164
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151
165
GGML_METAL_KERNEL_TYPE_SUB,
152
166
GGML_METAL_KERNEL_TYPE_SUB_ROW,
153
167
GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1143,21 @@ @implementation GGMLMetalClass
1129
1143
// simd_sum and simd_max requires MTLGPUFamilyApple7
1130
1144
1131
1145
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1146
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1147
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true );
1148
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1149
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true );
1150
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1151
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1152
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1132
1153
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
1154
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true );
1155
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3, add_row_fuse_3, true );
1156
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true );
1157
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5, add_row_fuse_5, true );
1158
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true );
1159
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7, add_row_fuse_7, true );
1160
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true );
1133
1161
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1134
1162
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
1135
1163
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
@@ -1875,7 +1903,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1875
1903
}
1876
1904
}
1877
1905
1878
- static bool ggml_metal_encode_node (
1906
+ static int ggml_metal_encode_node (
1879
1907
ggml_backend_t backend,
1880
1908
int idx,
1881
1909
id <MTLComputeCommandEncoder > encoder,
@@ -1885,7 +1913,10 @@ static bool ggml_metal_encode_node(
1885
1913
1886
1914
struct ggml_cgraph * gf = ctx->gf ;
1887
1915
1888
- struct ggml_tensor * node = ggml_graph_node (gf, idx);
1916
+ enum ggml_op ops[8 ];
1917
+
1918
+ struct ggml_tensor ** nodes = ggml_graph_nodes (gf) + idx;
1919
+ struct ggml_tensor * node = nodes[0 ];
1889
1920
1890
1921
// GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
1891
1922
@@ -1895,7 +1926,7 @@ static bool ggml_metal_encode_node(
1895
1926
struct ggml_tensor * dst = node;
1896
1927
1897
1928
if (ggml_is_empty (dst)) {
1898
- return true ;
1929
+ return 1 ;
1899
1930
}
1900
1931
1901
1932
switch (dst->op ) {
@@ -1906,7 +1937,7 @@ static bool ggml_metal_encode_node(
1906
1937
case GGML_OP_PERMUTE:
1907
1938
{
1908
1939
// noop -> next node
1909
- } return true ;
1940
+ } return 1 ;
1910
1941
default :
1911
1942
{
1912
1943
} break ;
@@ -1973,6 +2004,8 @@ static bool ggml_metal_encode_node(
1973
2004
id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
1974
2005
id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
1975
2006
2007
+ int n_fuse = 1 ;
2008
+
1976
2009
#if 0
1977
2010
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1978
2011
if (src0) {
@@ -2050,31 +2083,6 @@ static bool ggml_metal_encode_node(
2050
2083
2051
2084
id <MTLComputePipelineState > pipeline = nil ;
2052
2085
2053
- if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2054
- GGML_ASSERT (ggml_is_contiguous (src0));
2055
-
2056
- // src1 is a row
2057
- GGML_ASSERT (ne11 == 1 );
2058
-
2059
- switch (dst->op ) {
2060
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
2061
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2062
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2063
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
2064
- default : GGML_ABORT (" fatal error" );
2065
- }
2066
-
2067
- bcast_row = true ;
2068
- } else {
2069
- switch (dst->op ) {
2070
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2071
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2072
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2073
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2074
- default : GGML_ABORT (" fatal error" );
2075
- }
2076
- }
2077
-
2078
2086
ggml_metal_kargs_bin args = {
2079
2087
/* .ne00 =*/ ne00,
2080
2088
/* .ne01 =*/ ne01,
@@ -2101,12 +2109,106 @@ static bool ggml_metal_encode_node(
2101
2109
/* .nb2 =*/ nb2,
2102
2110
/* .nb3 =*/ nb3,
2103
2111
/* .offs =*/ offs,
2112
+ /* .o1 =*/ { offs_src1 },
2104
2113
};
2105
2114
2115
+ {
2116
+ ops[0 ] = GGML_OP_ADD;
2117
+ ops[1 ] = GGML_OP_ADD;
2118
+ ops[2 ] = GGML_OP_ADD;
2119
+ ops[3 ] = GGML_OP_ADD;
2120
+ ops[4 ] = GGML_OP_ADD;
2121
+ ops[5 ] = GGML_OP_ADD;
2122
+ ops[6 ] = GGML_OP_ADD;
2123
+ ops[7 ] = GGML_OP_ADD;
2124
+
2125
+ size_t offs_fuse;
2126
+ id <MTLBuffer > id_fuse;
2127
+
2128
+ for (n_fuse = 0 ; n_fuse <= 6 ; ++n_fuse) {
2129
+ if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
2130
+ break ;
2131
+ }
2132
+
2133
+ if (!ggml_are_same_layout (nodes[n_fuse]->src [1 ], nodes[n_fuse + 1 ]->src [1 ])) {
2134
+ break ;
2135
+ }
2136
+
2137
+ // only fuse nodes if src1 is in the same Metal buffer
2138
+ id_fuse = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse);
2139
+ if (id_fuse != id_src1) {
2140
+ break ;
2141
+ }
2142
+
2143
+ args.o1 [n_fuse + 1 ] = offs_fuse;
2144
+ }
2145
+
2146
+ ++n_fuse;
2147
+ }
2148
+
2149
+ if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2150
+ GGML_ASSERT (ggml_is_contiguous (src0));
2151
+
2152
+ // src1 is a row
2153
+ GGML_ASSERT (ne11 == 1 );
2154
+
2155
+ switch (dst->op ) {
2156
+ case GGML_OP_ADD:
2157
+ {
2158
+ switch (n_fuse) {
2159
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline ; break ;
2160
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline ; break ;
2161
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3].pipeline ; break ;
2162
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline ; break ;
2163
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5].pipeline ; break ;
2164
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline ; break ;
2165
+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7].pipeline ; break ;
2166
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline ; break ;
2167
+ default : GGML_ABORT (" fatal error" );
2168
+ }
2169
+ } break ;
2170
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2171
+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2172
+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
2173
+ default : GGML_ABORT (" fatal error" );
2174
+ }
2175
+
2176
+ bcast_row = true ;
2177
+ } else {
2178
+ switch (dst->op ) {
2179
+ case GGML_OP_ADD:
2180
+ {
2181
+ switch (n_fuse) {
2182
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD ].pipeline ; break ;
2183
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2184
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline ; break ;
2185
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2186
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline ; break ;
2187
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2188
+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline ; break ;
2189
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2190
+ default : GGML_ABORT (" fatal error" );
2191
+ }
2192
+ } break ;
2193
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2194
+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2195
+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
2196
+ default : GGML_ABORT (" fatal error" );
2197
+ }
2198
+ }
2199
+
2200
+ if (n_fuse > 1 ) {
2201
+ id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
2202
+ }
2203
+
2106
2204
[encoder setComputePipelineState: pipeline];
2107
2205
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2108
2206
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2109
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2207
+ if (dst->op == GGML_OP_ADD) {
2208
+ [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2209
+ } else {
2210
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2211
+ }
2110
2212
[encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2111
2213
2112
2214
if (bcast_row) {
@@ -2239,6 +2341,7 @@ static bool ggml_metal_encode_node(
2239
2341
/* .nb2 =*/ pnb2,
2240
2342
/* .nb3 =*/ pnb3,
2241
2343
/* .offs =*/ offs,
2344
+ /* .o1 =*/ { offs_src1 },
2242
2345
};
2243
2346
2244
2347
[encoder setComputePipelineState: pipeline];
@@ -2674,7 +2777,7 @@ static bool ggml_metal_encode_node(
2674
2777
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2675
2778
if (!h_src0) {
2676
2779
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677
- return false ;
2780
+ return 0 ;
2678
2781
}
2679
2782
2680
2783
offs_src0 = 0;
@@ -3550,7 +3653,7 @@ static bool ggml_metal_encode_node(
3550
3653
id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
3551
3654
if (!h_src1) {
3552
3655
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3553
- return false ;
3656
+ return 0 ;
3554
3657
}
3555
3658
3556
3659
const int64_t neh0 = ne0;
@@ -3566,15 +3669,15 @@ static bool ggml_metal_encode_node(
3566
3669
id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
3567
3670
if (!h_dst) {
3568
3671
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3569
- return false ;
3672
+ return 0 ;
3570
3673
}
3571
3674
3572
3675
// tokens per expert
3573
3676
const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
3574
3677
id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
3575
3678
if (!h_tpe) {
3576
3679
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
3577
- return false ;
3680
+ return 0 ;
3578
3681
}
3579
3682
3580
3683
// id map
@@ -3583,7 +3686,7 @@ static bool ggml_metal_encode_node(
3583
3686
id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
3584
3687
if (!h_ids) {
3585
3688
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3586
- return false ;
3689
+ return 0 ;
3587
3690
}
3588
3691
3589
3692
{
@@ -5442,7 +5545,7 @@ static bool ggml_metal_encode_node(
5442
5545
}
5443
5546
}
5444
5547
5445
- return true ;
5548
+ return n_fuse ;
5446
5549
}
5447
5550
5448
5551
static enum ggml_status ggml_metal_graph_compute (
@@ -5948,20 +6051,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5948
6051
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs [cb_idx].mem_pool ;
5949
6052
ggml_metal_mem_pool_reset (mem_pool);
5950
6053
5951
- for (int idx = node_start; idx < node_end; ++idx ) {
6054
+ for (int idx = node_start; idx < node_end;) {
5952
6055
if (should_capture) {
5953
6056
[encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5954
6057
}
5955
6058
5956
- const bool res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6059
+ const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
5957
6060
5958
6061
if (should_capture) {
5959
6062
[encoder popDebugGroup ];
5960
6063
}
5961
6064
5962
- if (! res) {
6065
+ if (res == 0 ) {
5963
6066
break ;
5964
6067
}
6068
+
6069
+ idx += res;
5965
6070
}
5966
6071
5967
6072
[encoder endEncoding ];
0 commit comments