@@ -147,7 +147,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147
147
148
148
enum ggml_metal_kernel_type {
149
149
GGML_METAL_KERNEL_TYPE_ADD,
150
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
152
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
153
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150
154
GGML_METAL_KERNEL_TYPE_ADD_ROW,
155
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
156
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
157
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
158
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151
159
GGML_METAL_KERNEL_TYPE_SUB,
152
160
GGML_METAL_KERNEL_TYPE_SUB_ROW,
153
161
GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1137,15 @@ @implementation GGMLMetalClass
1129
1137
// simd_sum and simd_max require MTLGPUFamilyApple7
1130
1138
1131
1139
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1140
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1141
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1142
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1143
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1132
1144
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
1145
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true );
1146
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true );
1147
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true );
1148
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true );
1133
1149
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1134
1150
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
1135
1151
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
@@ -1875,7 +1891,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1875
1891
}
1876
1892
}
1877
1893
1878
- static bool ggml_metal_encode_node (
1894
+ static int ggml_metal_encode_node (
1879
1895
ggml_backend_t backend,
1880
1896
int idx,
1881
1897
id <MTLComputeCommandEncoder > encoder,
@@ -1885,7 +1901,12 @@ static bool ggml_metal_encode_node(
1885
1901
1886
1902
struct ggml_cgraph * gf = ctx->gf ;
1887
1903
1888
- struct ggml_tensor * node = ggml_graph_node (gf, idx);
1904
+ enum ggml_op ops[8 ];
1905
+
1906
+ struct ggml_tensor ** nodes = ggml_graph_nodes (gf);
1907
+ struct ggml_tensor * node = nodes[idx];
1908
+
1909
+ struct ggml_tensor ** fuse = nodes + idx + 1 ;
1889
1910
1890
1911
// GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
1891
1912
@@ -1895,7 +1916,7 @@ static bool ggml_metal_encode_node(
1895
1916
struct ggml_tensor * dst = node;
1896
1917
1897
1918
if (ggml_is_empty (dst)) {
1898
- return true ;
1919
+ return 1 ;
1899
1920
}
1900
1921
1901
1922
switch (dst->op ) {
@@ -1906,7 +1927,7 @@ static bool ggml_metal_encode_node(
1906
1927
case GGML_OP_PERMUTE:
1907
1928
{
1908
1929
// noop -> next node
1909
- } return true ;
1930
+ } return 1 ;
1910
1931
default :
1911
1932
{
1912
1933
} break ;
@@ -1973,6 +1994,8 @@ static bool ggml_metal_encode_node(
1973
1994
id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
1974
1995
id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
1975
1996
1997
+ int n_fuse = 1 ;
1998
+
1976
1999
#if 0
1977
2000
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1978
2001
if (src0) {
@@ -2050,14 +2073,50 @@ static bool ggml_metal_encode_node(
2050
2073
2051
2074
id <MTLComputePipelineState > pipeline = nil ;
2052
2075
2076
+ {
2077
+ ops[0 ] = GGML_OP_ADD;
2078
+ ops[1 ] = GGML_OP_ADD;
2079
+ ops[2 ] = GGML_OP_ADD;
2080
+ ops[3 ] = GGML_OP_ADD;
2081
+ ops[4 ] = GGML_OP_ADD;
2082
+ ops[5 ] = GGML_OP_ADD;
2083
+ ops[6 ] = GGML_OP_ADD;
2084
+ ops[7 ] = GGML_OP_ADD;
2085
+
2086
+ for (n_fuse = 8 ; n_fuse > 1 ; --n_fuse) {
2087
+ if (n_fuse % 2 == 1 ) {
2088
+ continue ;
2089
+ }
2090
+ if (ggml_can_fuse (gf, idx, ops, n_fuse)) {
2091
+ if (ggml_are_same_layout (node->src [1 ], fuse[0 ]->src [1 ]) &&
2092
+ ggml_are_same_layout (node->src [1 ], fuse[n_fuse - 2 ]->src [1 ])) {
2093
+ break ;
2094
+ }
2095
+ }
2096
+ }
2097
+ }
2098
+
2053
2099
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2054
2100
GGML_ASSERT (ggml_is_contiguous (src0));
2055
2101
2056
2102
// src1 is a row
2057
2103
GGML_ASSERT (ne11 == 1 );
2058
2104
2059
2105
switch (dst->op ) {
2060
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
2106
+ case GGML_OP_ADD:
2107
+ {
2108
+ switch (n_fuse) {
2109
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline ; break ;
2110
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline ; break ;
2111
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline ; break ;
2112
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline ; break ;
2113
+ default :
2114
+ {
2115
+ GGML_ASSERT (n_fuse == 1 );
2116
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ;
2117
+ }
2118
+ }
2119
+ } break ;
2061
2120
case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2062
2121
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2063
2122
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
@@ -2067,7 +2126,21 @@ static bool ggml_metal_encode_node(
2067
2126
bcast_row = true ;
2068
2127
} else {
2069
2128
switch (dst->op ) {
2070
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2129
+ case GGML_OP_ADD:
2130
+ {
2131
+ // GGML_LOG_INFO("XXXXXXXXXXXXXXXXXXXXXXXXX n_fuse = %d\n", n_fuse);
2132
+ switch (n_fuse) {
2133
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2134
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2135
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2136
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2137
+ default :
2138
+ {
2139
+ GGML_ASSERT (n_fuse == 1 );
2140
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2141
+ }
2142
+ }
2143
+ } break ;
2071
2144
case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2072
2145
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2073
2146
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
@@ -2107,7 +2180,16 @@ static bool ggml_metal_encode_node(
2107
2180
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2108
2181
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2109
2182
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2110
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2183
+ for (int f = 0 ; f < n_fuse - 1 ; ++f) {
2184
+ id_src1 = ggml_metal_get_buffer (fuse[f]->src [1 ], &offs_src1);
2185
+
2186
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 3 + f];
2187
+
2188
+ if (f + 1 == n_fuse - 1 ) {
2189
+ id_dst = ggml_metal_get_buffer (fuse[f], &offs_dst);
2190
+ }
2191
+ }
2192
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 + n_fuse];
2111
2193
2112
2194
if (bcast_row) {
2113
2195
const int64_t n = ggml_nelements (dst)/4 ;
@@ -2674,7 +2756,7 @@ static bool ggml_metal_encode_node(
2674
2756
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2675
2757
if (!h_src0) {
2676
2758
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677
- return false ;
2759
+ return 0 ;
2678
2760
}
2679
2761
2680
2762
offs_src0 = 0;
@@ -3550,7 +3632,7 @@ static bool ggml_metal_encode_node(
3550
3632
id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
3551
3633
if (!h_src1) {
3552
3634
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3553
- return false ;
3635
+ return 0 ;
3554
3636
}
3555
3637
3556
3638
const int64_t neh0 = ne0;
@@ -3566,15 +3648,15 @@ static bool ggml_metal_encode_node(
3566
3648
id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
3567
3649
if (!h_dst) {
3568
3650
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3569
- return false ;
3651
+ return 0 ;
3570
3652
}
3571
3653
3572
3654
// tokens per expert
3573
3655
const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
3574
3656
id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
3575
3657
if (!h_tpe) {
3576
3658
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
3577
- return false ;
3659
+ return 0 ;
3578
3660
}
3579
3661
3580
3662
// id map
@@ -3583,7 +3665,7 @@ static bool ggml_metal_encode_node(
3583
3665
id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
3584
3666
if (!h_ids) {
3585
3667
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3586
- return false ;
3668
+ return 0 ;
3587
3669
}
3588
3670
3589
3671
{
@@ -5442,7 +5524,7 @@ static bool ggml_metal_encode_node(
5442
5524
}
5443
5525
}
5444
5526
5445
- return true ;
5527
+ return n_fuse ;
5446
5528
}
5447
5529
5448
5530
static enum ggml_status ggml_metal_graph_compute (
@@ -5948,20 +6030,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5948
6030
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs [cb_idx].mem_pool ;
5949
6031
ggml_metal_mem_pool_reset (mem_pool);
5950
6032
5951
- for (int idx = node_start; idx < node_end; ++idx ) {
6033
+ for (int idx = node_start; idx < node_end;) {
5952
6034
if (should_capture) {
5953
6035
[encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5954
6036
}
5955
6037
5956
- const bool res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6038
+ const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
5957
6039
5958
6040
if (should_capture) {
5959
6041
[encoder popDebugGroup ];
5960
6042
}
5961
6043
5962
- if (! res) {
6044
+ if (res == 0 ) {
5963
6045
break ;
5964
6046
}
6047
+
6048
+ idx += res;
5965
6049
}
5966
6050
5967
6051
[encoder endEncoding ];
0 commit comments