Skip to content

Commit 7f0c759

Browse files
committed
metal : env var + stats
ggml-ci
1 parent e4db33e commit 7f0c759

File tree

3 files changed, with 53 additions and 8 deletions.

3 files changed, with 53 additions and 8 deletions.

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ typedef struct {
241241
float max_bias;
242242
float m0;
243243
float m1;
244-
int16_t n_head_log2;
244+
int32_t n_head_log2;
245245
float logit_softcap;
246246
} ggml_metal_kargs_flash_attn_ext;
247247

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 48 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,12 @@
5555
bool has_residency_sets;
5656
bool has_bfloat;
5757
bool use_bfloat;
58+
bool use_fusion;
59+
60+
int debug_fusion;
61+
62+
// how many times a given op was fused
63+
uint64_t fuse_cnt[GGML_OP_COUNT];
5864

5965
size_t max_size;
6066

@@ -69,6 +75,9 @@
6975
/*.has_residency_sets =*/ false,
7076
/*.has_bfloat =*/ false,
7177
/*.use_bfloat =*/ false,
78+
/*.use_fusion =*/ true,
79+
/*.debug_fusion =*/ 0,
80+
/*.fuse_cnt =*/ { 0 },
7281
/*.max_size =*/ 0,
7382
/*.name =*/ "",
7483
};
@@ -83,16 +92,14 @@
8392

8493
if (ctx->mtl_device == nil) {
8594
ctx->mtl_device = MTLCreateSystemDefaultDevice();
86-
}
8795

88-
if (ctx->mtl_device) {
8996
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
9097
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
9198

9299
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
93100

94101
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
95-
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
102+
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
96103
#endif
97104

98105
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@
103110
#else
104111
ctx->use_bfloat = false;
105112
#endif
113+
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
114+
115+
{
116+
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
117+
ctx->debug_fusion = val ? atoi(val) : 0;
118+
}
119+
120+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
106121

107122
ctx->max_size = ctx->mtl_device.maxBufferLength;
108123

@@ -122,6 +137,17 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
122137
ctx->mtl_device_ref_count--;
123138

124139
if (ctx->mtl_device_ref_count == 0) {
140+
if (ctx->debug_fusion > 0) {
141+
for (int i = 0; i < GGML_OP_COUNT; i++) {
142+
if (ctx->fuse_cnt[i] == 0) {
143+
continue;
144+
}
145+
146+
// note: cannot use ggml_log here
147+
fprintf(stderr, "%s: %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
148+
}
149+
}
150+
125151
if (ctx->mtl_lock) {
126152
[ctx->mtl_lock release];
127153
ctx->mtl_lock = nil;
@@ -2123,7 +2149,7 @@ static int ggml_metal_encode_node(
21232149
// c[1] = add(c[0], b[1])
21242150
// c[2] = add(c[1], b[2])
21252151
// ...
2126-
{
2152+
if (ctx_dev->use_fusion) {
21272153
ops[0] = GGML_OP_ADD;
21282154
ops[1] = GGML_OP_ADD;
21292155
ops[2] = GGML_OP_ADD;
@@ -2156,10 +2182,16 @@ static int ggml_metal_encode_node(
21562182
break;
21572183
}
21582184

2185+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
2186+
21592187
args.o1[n_fuse + 1] = offs_fuse;
21602188
}
21612189

21622190
++n_fuse;
2191+
2192+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
2193+
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2194+
}
21632195
}
21642196

21652197
//GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
@@ -4162,7 +4194,7 @@ static int ggml_metal_encode_node(
41624194
// d[0] = rms_norm(a)
41634195
// d[1] = mul(d[0], b)
41644196
// d[2] = add(d[1], c)
4165-
{
4197+
if (ctx_dev->use_fusion) {
41664198
ops[0] = GGML_OP_RMS_NORM;
41674199
ops[1] = GGML_OP_MUL;
41684200
ops[2] = GGML_OP_ADD;
@@ -4188,6 +4220,8 @@ static int ggml_metal_encode_node(
41884220
break;
41894221
}
41904222

4223+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
4224+
41914225
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
41924226

41934227
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
@@ -4200,6 +4234,15 @@ static int ggml_metal_encode_node(
42004234
}
42014235

42024236
++n_fuse;
4237+
4238+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
4239+
if (n_fuse == 2) {
4240+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
4241+
}
4242+
if (n_fuse == 3) {
4243+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
4244+
}
4245+
}
42034246
}
42044247

42054248
//GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2249,9 +2249,11 @@ kernel void kernel_rms_norm_fuse_impl(
22492249
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
22502250
if (F == 1) {
22512251
y[i00] = (x[i00]*scale);
2252-
} else if (F == 2) {
2252+
}
2253+
if (F == 2) {
22532254
y[i00] = (x[i00]*scale)*f0[i00];
2254-
} else if (F == 3) {
2255+
}
2256+
if (F == 3) {
22552257
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
22562258
}
22572259
}

0 commit comments

Comments
 (0)