Skip to content

Commit e7233c8

Browse files
committed
metal : fuse rms_norm + mul + add
ggml-ci
1 parent 303f2ae commit e7233c8

File tree

3 files changed

+142
-29
lines changed

3 files changed

+142
-29
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ typedef struct {
241241
float max_bias;
242242
float m0;
243243
float m1;
244-
uint16_t n_head_log2;
244+
int16_t n_head_log2;
245245
float logit_softcap;
246246
} ggml_metal_kargs_flash_attn_ext;
247247

@@ -378,8 +378,16 @@ typedef struct {
378378
typedef struct {
379379
int32_t ne00;
380380
int32_t ne00_4;
381-
uint64_t nb01;
381+
uint64_t nb1;
382+
uint64_t nb2;
383+
uint64_t nb3;
382384
float eps;
385+
int32_t nef1[3];
386+
int32_t nef2[3];
387+
int32_t nef3[3];
388+
uint64_t nbf1[3];
389+
uint64_t nbf2[3];
390+
uint64_t nbf3[3];
383391
} ggml_metal_kargs_rms_norm;
384392

385393
typedef struct {
@@ -485,7 +493,7 @@ typedef struct {
485493
float max_bias;
486494
float m0;
487495
float m1;
488-
uint32_t n_head_log2;
496+
int32_t n_head_log2;
489497
} ggml_metal_kargs_soft_max;
490498

491499
typedef struct {

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
232232
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
233233
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
234234
GGML_METAL_KERNEL_TYPE_RMS_NORM,
235+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
236+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
235237
GGML_METAL_KERNEL_TYPE_L2_NORM,
236238
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
237239
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1234,6 +1236,8 @@ @implementation GGMLMetalClass
12341236
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
12351237
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
12361238
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1239+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1240+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
12371241
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12381242
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12391243
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -2133,6 +2137,10 @@ static int ggml_metal_encode_node(
21332137
/*.o1 =*/ { offs_src1 },
21342138
};
21352139

2140+
// c[0] = add(a, b[0])
2141+
// c[1] = add(c[0], b[1])
2142+
// c[2] = add(c[1], b[2])
2143+
// ...
21362144
{
21372145
ops[0] = GGML_OP_ADD;
21382146
ops[1] = GGML_OP_ADD;
@@ -2151,6 +2159,11 @@ static int ggml_metal_encode_node(
21512159
break;
21522160
}
21532161

2162+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
2163+
break;
2164+
}
2165+
2166+
// b[0] === b[1] === ...
21542167
if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
21552168
break;
21562169
}
@@ -4213,12 +4226,86 @@ static int ggml_metal_encode_node(
42134226
case GGML_OP_RMS_NORM:
42144227
{
42154228
GGML_ASSERT(ne00 % 4 == 0);
4216-
GGML_ASSERT(ggml_is_contiguous_1(src0));
4229+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
42174230

42184231
float eps;
42194232
memcpy(&eps, dst->op_params, sizeof(float));
42204233

4221-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
4234+
ggml_metal_kargs_rms_norm args = {
4235+
/*.ne00 =*/ ne00,
4236+
/*.ne00_4 =*/ ne00/4,
4237+
/*.nb1 =*/ nb1,
4238+
/*.nb2 =*/ nb2,
4239+
/*.nb3 =*/ nb3,
4240+
/*.eps =*/ eps,
4241+
/*.nef1 =*/ { ne01 },
4242+
/*.nef2 =*/ { ne02 },
4243+
/*.nef3 =*/ { ne03 },
4244+
/*.nbf1 =*/ { nb01 },
4245+
/*.nbf2 =*/ { nb02 },
4246+
/*.nbf3 =*/ { nb03 },
4247+
};
4248+
4249+
size_t offs_fuse[2] = { 0, 0 };
4250+
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
4251+
4252+
// d[0] = rms_norm(a)
4253+
// d[1] = mul(d[0], b)
4254+
// d[2] = add(d[1], c)
4255+
{
4256+
ops[0] = GGML_OP_RMS_NORM;
4257+
ops[1] = GGML_OP_MUL;
4258+
ops[2] = GGML_OP_ADD;
4259+
4260+
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
4261+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4262+
break;
4263+
}
4264+
4265+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
4266+
break;
4267+
}
4268+
4269+
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
4270+
break;
4271+
}
4272+
4273+
if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
4274+
break;
4275+
}
4276+
4277+
if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
4278+
break;
4279+
}
4280+
4281+
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
4282+
4283+
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
4284+
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
4285+
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
4286+
4287+
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
4288+
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
4289+
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
4290+
}
4291+
4292+
++n_fuse;
4293+
}
4294+
4295+
//GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
4296+
4297+
if (n_fuse > 1) {
4298+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
4299+
}
4300+
4301+
id<MTLComputePipelineState> pipeline;
4302+
4303+
switch (n_fuse) {
4304+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
4305+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
4306+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
4307+
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
4308+
}
42224309

42234310
int nth = 32; // SIMD width
42244311

@@ -4229,23 +4316,16 @@ static int ggml_metal_encode_node(
42294316
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
42304317
nth = MIN(nth, ne00/4);
42314318

4232-
ggml_metal_kargs_rms_norm args = {
4233-
/*.ne00 =*/ ne00,
4234-
/*.ne00_4 =*/ ne00/4,
4235-
/*.nb01 =*/ nb01,
4236-
/*.eps =*/ eps,
4237-
};
4238-
42394319
[encoder setComputePipelineState:pipeline];
4240-
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4241-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4242-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4320+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4321+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4322+
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
4323+
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
4324+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
42434325

42444326
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
42454327

4246-
const int64_t nrows = ggml_nrows(src0);
4247-
4248-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4328+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
42494329
} break;
42504330
case GGML_OP_L2_NORM:
42514331
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,26 +2239,39 @@ kernel void kernel_norm(
22392239
}
22402240
}
22412241

2242-
kernel void kernel_rms_norm(
2242+
// F == 1 : rms_norm (no fuse)
2243+
// F == 2 : rms_norm + mul
2244+
// F == 3 : rms_norm + mul + add
2245+
template <short F>
2246+
kernel void kernel_rms_norm_fuse_impl(
22432247
constant ggml_metal_kargs_rms_norm & args,
22442248
device const char * src0,
2249+
device const char * src1_0,
2250+
device const char * src1_1,
22452251
device char * dst,
22462252
threadgroup float * shmem_f32 [[threadgroup(0)]],
2247-
uint tgpig[[threadgroup_position_in_grid]],
2248-
ushort tpitg[[thread_position_in_threadgroup]],
2249-
ushort sgitg[[simdgroup_index_in_threadgroup]],
2250-
ushort tiisg[[thread_index_in_simdgroup]],
2251-
ushort ntg[[threads_per_threadgroup]]) {
2253+
uint3 tgpig[[threadgroup_position_in_grid]],
2254+
ushort3 tpitg[[thread_position_in_threadgroup]],
2255+
ushort sgitg[[simdgroup_index_in_threadgroup]],
2256+
ushort tiisg[[thread_index_in_simdgroup]],
2257+
ushort3 ntg[[threads_per_threadgroup]]) {
22522258
if (sgitg == 0) {
22532259
shmem_f32[tiisg] = 0.0f;
22542260
}
22552261

2256-
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
2262+
const int i01 = tgpig.x;
2263+
const int i02 = tgpig.y;
2264+
const int i03 = tgpig.z;
2265+
2266+
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
2267+
2268+
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
2269+
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
22572270

22582271
float sumf = 0.0f;
22592272

22602273
// parallel sum
2261-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2274+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
22622275
sumf += dot(x[i00], x[i00]);
22632276
}
22642277
sumf = simd_sum(sumf);
@@ -2277,12 +2290,24 @@ kernel void kernel_rms_norm(
22772290
const float mean = sumf/args.ne00;
22782291
const float scale = 1.0f/sqrt(mean + args.eps);
22792292

2280-
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
2281-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2282-
y[i00] = x[i00] * scale;
2293+
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2294+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
2295+
if (F == 1) {
2296+
y[i00] = (x[i00]*scale);
2297+
} else if (F == 2) {
2298+
y[i00] = (x[i00]*scale)*f0[i00];
2299+
} else if (F == 3) {
2300+
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
2301+
}
22832302
}
22842303
}
22852304

2305+
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
2306+
2307+
template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
2308+
template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
2309+
template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
2310+
22862311
kernel void kernel_l2_norm(
22872312
constant ggml_metal_kargs_l2_norm & args,
22882313
device const char * src0,

0 commit comments

Comments
 (0)