55
55
bool has_residency_sets;
56
56
bool has_bfloat;
57
57
bool use_bfloat;
58
+ bool use_fusion;
59
+
60
+ int debug_fusion;
61
+
62
+ // how many times a given op was fused
63
+ uint64_t fuse_cnt[GGML_OP_COUNT];
58
64
59
65
size_t max_size;
60
66
69
75
/* .has_residency_sets =*/ false ,
70
76
/* .has_bfloat =*/ false ,
71
77
/* .use_bfloat =*/ false ,
78
+ /* .use_fusion =*/ true ,
79
+ /* .debug_fusion =*/ 0 ,
80
+ /* .fuse_cnt =*/ { 0 },
72
81
/* .max_size =*/ 0 ,
73
82
/* .name =*/ " " ,
74
83
};
83
92
84
93
if (ctx->mtl_device == nil ) {
85
94
ctx->mtl_device = MTLCreateSystemDefaultDevice ();
86
- }
87
95
88
- if (ctx->mtl_device ) {
89
96
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily: MTLGPUFamilyApple7];
90
97
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
91
98
92
99
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily: MTLGPUFamilyApple7];
93
100
94
101
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
95
- ctx->has_residency_sets = getenv (" GGML_METAL_NO_RESIDENCY" ) == NULL ;
102
+ ctx->has_residency_sets = getenv (" GGML_METAL_NO_RESIDENCY" ) == nil ;
96
103
#endif
97
104
98
105
ctx->has_bfloat = [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
103
110
#else
104
111
ctx->use_bfloat = false ;
105
112
#endif
113
+ ctx->use_fusion = getenv (" GGML_METAL_FUSION_DISABLE" ) == nil ;
114
+
115
+ {
116
+ const char * val = getenv (" GGML_METAL_FUSION_DEBUG" );
117
+ ctx->debug_fusion = val ? atoi (val) : 0 ;
118
+ }
119
+
120
+ memset (ctx->fuse_cnt , 0 , sizeof (ctx->fuse_cnt ));
106
121
107
122
ctx->max_size = ctx->mtl_device .maxBufferLength ;
108
123
@@ -122,6 +137,17 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
122
137
ctx->mtl_device_ref_count --;
123
138
124
139
if (ctx->mtl_device_ref_count == 0 ) {
140
+ if (ctx->debug_fusion > 0 ) {
141
+ for (int i = 0 ; i < GGML_OP_COUNT; i++) {
142
+ if (ctx->fuse_cnt [i] == 0 ) {
143
+ continue ;
144
+ }
145
+
146
+ // note: cannot use ggml_log here
147
+ fprintf (stderr, " %s : %s : %" PRIu64 " \n " , __func__, ggml_op_name ((enum ggml_op) i), ctx->fuse_cnt [i]);
148
+ }
149
+ }
150
+
125
151
if (ctx->mtl_lock ) {
126
152
[ctx->mtl_lock release ];
127
153
ctx->mtl_lock = nil ;
@@ -2123,7 +2149,7 @@ static int ggml_metal_encode_node(
2123
2149
// c[1] = add(c[0], b[1])
2124
2150
// c[2] = add(c[1], b[2])
2125
2151
// ...
2126
- {
2152
+ if (ctx_dev-> use_fusion ) {
2127
2153
ops[0 ] = GGML_OP_ADD;
2128
2154
ops[1 ] = GGML_OP_ADD;
2129
2155
ops[2 ] = GGML_OP_ADD;
@@ -2156,10 +2182,16 @@ static int ggml_metal_encode_node(
2156
2182
break ;
2157
2183
}
2158
2184
2185
+ ctx_dev->fuse_cnt [nodes[n_fuse + 1 ]->op]++;
2186
+
2159
2187
args.o1 [n_fuse + 1 ] = offs_fuse;
2160
2188
}
2161
2189
2162
2190
++n_fuse;
2191
+
2192
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1 ) {
2193
+ GGML_LOG_DEBUG (" %s : fuse: ADD x %d \n " , __func__, n_fuse);
2194
+ }
2163
2195
}
2164
2196
2165
2197
// GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
@@ -4162,7 +4194,7 @@ static int ggml_metal_encode_node(
4162
4194
// d[0] = rms_norm(a)
4163
4195
// d[1] = mul(d[0], b)
4164
4196
// d[2] = add(d[1], c)
4165
- {
4197
+ if (ctx_dev-> use_fusion ) {
4166
4198
ops[0 ] = GGML_OP_RMS_NORM;
4167
4199
ops[1 ] = GGML_OP_MUL;
4168
4200
ops[2 ] = GGML_OP_ADD;
@@ -4188,6 +4220,8 @@ static int ggml_metal_encode_node(
4188
4220
break ;
4189
4221
}
4190
4222
4223
+ ctx_dev->fuse_cnt [nodes[n_fuse + 1 ]->op]++;
4224
+
4191
4225
id_fuse[n_fuse] = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse[n_fuse]);
4192
4226
4193
4227
args.nef1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [1 ];
@@ -4200,6 +4234,15 @@ static int ggml_metal_encode_node(
4200
4234
}
4201
4235
4202
4236
++n_fuse;
4237
+
4238
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1 ) {
4239
+ if (n_fuse == 2 ) {
4240
+ GGML_LOG_DEBUG (" %s : fuse: RMS_NORM + MUL\n " , __func__);
4241
+ }
4242
+ if (n_fuse == 3 ) {
4243
+ GGML_LOG_DEBUG (" %s : fuse: RMS_NORM + MUL + ADD\n " , __func__);
4244
+ }
4245
+ }
4203
4246
}
4204
4247
4205
4248
// GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
0 commit comments