Skip to content

Commit 4d568cb

Browse files
committed
metal : env var + stats
ggml-ci
1 parent e7233c8 commit 4d568cb

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ typedef struct {
241241
float max_bias;
242242
float m0;
243243
float m1;
244-
int16_t n_head_log2;
244+
int32_t n_head_log2;
245245
float logit_softcap;
246246
} ggml_metal_kargs_flash_attn_ext;
247247

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 48 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,12 @@
5555
bool has_residency_sets;
5656
bool has_bfloat;
5757
bool use_bfloat;
58+
bool use_fusion;
59+
60+
int debug_fusion;
61+
62+
// how many times a given op was fused
63+
uint64_t fuse_cnt[GGML_OP_COUNT];
5864

5965
size_t max_size;
6066

@@ -69,6 +75,9 @@
6975
/*.has_residency_sets =*/ false,
7076
/*.has_bfloat =*/ false,
7177
/*.use_bfloat =*/ false,
78+
/*.use_fusion =*/ true,
79+
/*.debug_fusion =*/ 0,
80+
/*.fuse_cnt =*/ { 0 },
7281
/*.max_size =*/ 0,
7382
/*.name =*/ "",
7483
};
@@ -83,16 +92,14 @@
8392

8493
if (ctx->mtl_device == nil) {
8594
ctx->mtl_device = MTLCreateSystemDefaultDevice();
86-
}
8795

88-
if (ctx->mtl_device) {
8996
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
9097
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
9198

9299
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
93100

94101
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
95-
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
102+
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
96103
#endif
97104

98105
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@
103110
#else
104111
ctx->use_bfloat = false;
105112
#endif
113+
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
114+
115+
{
116+
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
117+
ctx->debug_fusion = val ? atoi(val) : 0;
118+
}
119+
120+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
106121

107122
ctx->max_size = ctx->mtl_device.maxBufferLength;
108123

@@ -122,6 +137,17 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
122137
ctx->mtl_device_ref_count--;
123138

124139
if (ctx->mtl_device_ref_count == 0) {
140+
if (ctx->debug_fusion > 0) {
141+
for (int i = 0; i < GGML_OP_COUNT; i++) {
142+
if (ctx->fuse_cnt[i] == 0) {
143+
continue;
144+
}
145+
146+
// note: cannot use ggml_log here
147+
fprintf(stderr, "%s: %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
148+
}
149+
}
150+
125151
if (ctx->mtl_lock) {
126152
[ctx->mtl_lock release];
127153
ctx->mtl_lock = nil;
@@ -2141,7 +2167,7 @@ static int ggml_metal_encode_node(
21412167
// c[1] = add(c[0], b[1])
21422168
// c[2] = add(c[1], b[2])
21432169
// ...
2144-
{
2170+
if (ctx_dev->use_fusion) {
21452171
ops[0] = GGML_OP_ADD;
21462172
ops[1] = GGML_OP_ADD;
21472173
ops[2] = GGML_OP_ADD;
@@ -2174,10 +2200,16 @@ static int ggml_metal_encode_node(
21742200
break;
21752201
}
21762202

2203+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
2204+
21772205
args.o1[n_fuse + 1] = offs_fuse;
21782206
}
21792207

21802208
++n_fuse;
2209+
2210+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
2211+
GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2212+
}
21812213
}
21822214

21832215
//GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
@@ -4252,7 +4284,7 @@ static int ggml_metal_encode_node(
42524284
// d[0] = rms_norm(a)
42534285
// d[1] = mul(d[0], b)
42544286
// d[2] = add(d[1], c)
4255-
{
4287+
if (ctx_dev->use_fusion) {
42564288
ops[0] = GGML_OP_RMS_NORM;
42574289
ops[1] = GGML_OP_MUL;
42584290
ops[2] = GGML_OP_ADD;
@@ -4278,6 +4310,8 @@ static int ggml_metal_encode_node(
42784310
break;
42794311
}
42804312

4313+
ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
4314+
42814315
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
42824316

42834317
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
@@ -4290,6 +4324,15 @@ static int ggml_metal_encode_node(
42904324
}
42914325

42924326
++n_fuse;
4327+
4328+
if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
4329+
if (n_fuse == 2) {
4330+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
4331+
}
4332+
if (n_fuse == 3) {
4333+
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
4334+
}
4335+
}
42934336
}
42944337

42954338
//GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2294,9 +2294,11 @@ kernel void kernel_rms_norm_fuse_impl(
22942294
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
22952295
if (F == 1) {
22962296
y[i00] = (x[i00]*scale);
2297-
} else if (F == 2) {
2297+
}
2298+
if (F == 2) {
22982299
y[i00] = (x[i00]*scale)*f0[i00];
2299-
} else if (F == 3) {
2300+
}
2301+
if (F == 3) {
23002302
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
23012303
}
23022304
}

0 commit comments

Comments
 (0)