55
55
bool has_residency_sets;
56
56
bool has_bfloat;
57
57
bool use_bfloat;
58
+ bool use_fusion;
59
+
60
+ int debug_fusion;
61
+
62
+ // how many times a given op was fused
63
+ uint64_t fuse_cnt[GGML_OP_COUNT];
58
64
59
65
size_t max_size;
60
66
69
75
/* .has_residency_sets =*/ false ,
70
76
/* .has_bfloat =*/ false ,
71
77
/* .use_bfloat =*/ false ,
78
+ /* .use_fusion =*/ true ,
79
+ /* .debug_fusion =*/ 0 ,
80
+ /* .fuse_cnt =*/ { 0 },
72
81
/* .max_size =*/ 0 ,
73
82
/* .name =*/ " " ,
74
83
};
83
92
84
93
if (ctx->mtl_device == nil ) {
85
94
ctx->mtl_device = MTLCreateSystemDefaultDevice ();
86
- }
87
95
88
- if (ctx->mtl_device ) {
89
96
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily: MTLGPUFamilyApple7];
90
97
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
91
98
92
99
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily: MTLGPUFamilyApple7];
93
100
94
101
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
95
- ctx->has_residency_sets = getenv (" GGML_METAL_NO_RESIDENCY" ) == NULL ;
102
+ ctx->has_residency_sets = getenv (" GGML_METAL_NO_RESIDENCY" ) == nil ;
96
103
#endif
97
104
98
105
ctx->has_bfloat = [ctx->mtl_device supportsFamily: MTLGPUFamilyMetal3_GGML];
103
110
#else
104
111
ctx->use_bfloat = false ;
105
112
#endif
113
+ ctx->use_fusion = getenv (" GGML_METAL_FUSION_DISABLE" ) == nil ;
114
+
115
+ {
116
+ const char * val = getenv (" GGML_METAL_FUSION_DEBUG" );
117
+ ctx->debug_fusion = val ? atoi (val) : 0 ;
118
+ }
119
+
120
+ memset (ctx->fuse_cnt , 0 , sizeof (ctx->fuse_cnt ));
106
121
107
122
ctx->max_size = ctx->mtl_device .maxBufferLength ;
108
123
@@ -122,6 +137,17 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
122
137
ctx->mtl_device_ref_count --;
123
138
124
139
if (ctx->mtl_device_ref_count == 0 ) {
140
+ if (ctx->debug_fusion > 0 ) {
141
+ for (int i = 0 ; i < GGML_OP_COUNT; i++) {
142
+ if (ctx->fuse_cnt [i] == 0 ) {
143
+ continue ;
144
+ }
145
+
146
+ // note: cannot use ggml_log here
147
+ fprintf (stderr, " %s : %s : %" PRIu64 " \n " , __func__, ggml_op_name ((enum ggml_op) i), ctx->fuse_cnt [i]);
148
+ }
149
+ }
150
+
125
151
if (ctx->mtl_lock ) {
126
152
[ctx->mtl_lock release ];
127
153
ctx->mtl_lock = nil ;
@@ -2141,7 +2167,7 @@ static int ggml_metal_encode_node(
2141
2167
// c[1] = add(c[0], b[1])
2142
2168
// c[2] = add(c[1], b[2])
2143
2169
// ...
2144
- {
2170
+ if (ctx_dev-> use_fusion ) {
2145
2171
ops[0 ] = GGML_OP_ADD;
2146
2172
ops[1 ] = GGML_OP_ADD;
2147
2173
ops[2 ] = GGML_OP_ADD;
@@ -2174,10 +2200,16 @@ static int ggml_metal_encode_node(
2174
2200
break ;
2175
2201
}
2176
2202
2203
+ ctx_dev->fuse_cnt [nodes[n_fuse + 1 ]->op]++;
2204
+
2177
2205
args.o1 [n_fuse + 1 ] = offs_fuse;
2178
2206
}
2179
2207
2180
2208
++n_fuse;
2209
+
2210
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1 ) {
2211
+ GGML_LOG_DEBUG (" %s : fuse: ADD x %d \n " , __func__, n_fuse);
2212
+ }
2181
2213
}
2182
2214
2183
2215
// GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
@@ -4252,7 +4284,7 @@ static int ggml_metal_encode_node(
4252
4284
// d[0] = rms_norm(a)
4253
4285
// d[1] = mul(d[0], b)
4254
4286
// d[2] = add(d[1], c)
4255
- {
4287
+ if (ctx_dev-> use_fusion ) {
4256
4288
ops[0 ] = GGML_OP_RMS_NORM;
4257
4289
ops[1 ] = GGML_OP_MUL;
4258
4290
ops[2 ] = GGML_OP_ADD;
@@ -4278,6 +4310,8 @@ static int ggml_metal_encode_node(
4278
4310
break ;
4279
4311
}
4280
4312
4313
+ ctx_dev->fuse_cnt [nodes[n_fuse + 1 ]->op]++;
4314
+
4281
4315
id_fuse[n_fuse] = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse[n_fuse]);
4282
4316
4283
4317
args.nef1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [1 ];
@@ -4290,6 +4324,15 @@ static int ggml_metal_encode_node(
4290
4324
}
4291
4325
4292
4326
++n_fuse;
4327
+
4328
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1 ) {
4329
+ if (n_fuse == 2 ) {
4330
+ GGML_LOG_DEBUG (" %s : fuse: RMS_NORM + MUL\n " , __func__);
4331
+ }
4332
+ if (n_fuse == 3 ) {
4333
+ GGML_LOG_DEBUG (" %s : fuse: RMS_NORM + MUL + ADD\n " , __func__);
4334
+ }
4335
+ }
4293
4336
}
4294
4337
4295
4338
// GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
0 commit comments