Skip to content

Commit e4db33e

Browse files
committed
metal : fuse rms_norm + mul + add
ggml-ci
1 parent 17224b6 commit e4db33e

File tree

3 files changed

+142
-29
lines changed

3 files changed

+142
-29
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ typedef struct {
241241
float max_bias;
242242
float m0;
243243
float m1;
244-
uint16_t n_head_log2;
244+
int16_t n_head_log2;
245245
float logit_softcap;
246246
} ggml_metal_kargs_flash_attn_ext;
247247

@@ -378,8 +378,16 @@ typedef struct {
378378
typedef struct {
379379
int32_t ne00;
380380
int32_t ne00_4;
381-
uint64_t nb01;
381+
uint64_t nb1;
382+
uint64_t nb2;
383+
uint64_t nb3;
382384
float eps;
385+
int32_t nef1[3];
386+
int32_t nef2[3];
387+
int32_t nef3[3];
388+
uint64_t nbf1[3];
389+
uint64_t nbf2[3];
390+
uint64_t nbf3[3];
383391
} ggml_metal_kargs_rms_norm;
384392

385393
typedef struct {
@@ -485,7 +493,7 @@ typedef struct {
485493
float max_bias;
486494
float m0;
487495
float m1;
488-
uint32_t n_head_log2;
496+
int32_t n_head_log2;
489497
} ggml_metal_kargs_soft_max;
490498

491499
typedef struct {

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
226226
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
227227
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
228228
GGML_METAL_KERNEL_TYPE_RMS_NORM,
229+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
230+
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
229231
GGML_METAL_KERNEL_TYPE_L2_NORM,
230232
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
231233
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1222,6 +1224,8 @@ @implementation GGMLMetalClass
12221224
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
12231225
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
12241226
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1227+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1228+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
12251229
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
12261230
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
12271231
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
@@ -2115,6 +2119,10 @@ static int ggml_metal_encode_node(
21152119
/*.o1 =*/ { offs_src1 },
21162120
};
21172121

2122+
// c[0] = add(a, b[0])
2123+
// c[1] = add(c[0], b[1])
2124+
// c[2] = add(c[1], b[2])
2125+
// ...
21182126
{
21192127
ops[0] = GGML_OP_ADD;
21202128
ops[1] = GGML_OP_ADD;
@@ -2133,6 +2141,11 @@ static int ggml_metal_encode_node(
21332141
break;
21342142
}
21352143

2144+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
2145+
break;
2146+
}
2147+
2148+
// b[0] === b[1] === ...
21362149
if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
21372150
break;
21382151
}
@@ -4123,12 +4136,86 @@ static int ggml_metal_encode_node(
41234136
case GGML_OP_RMS_NORM:
41244137
{
41254138
GGML_ASSERT(ne00 % 4 == 0);
4126-
GGML_ASSERT(ggml_is_contiguous_1(src0));
4139+
GGML_ASSERT(ggml_is_contiguous_rows(src0));
41274140

41284141
float eps;
41294142
memcpy(&eps, dst->op_params, sizeof(float));
41304143

4131-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
4144+
ggml_metal_kargs_rms_norm args = {
4145+
/*.ne00 =*/ ne00,
4146+
/*.ne00_4 =*/ ne00/4,
4147+
/*.nb1 =*/ nb1,
4148+
/*.nb2 =*/ nb2,
4149+
/*.nb3 =*/ nb3,
4150+
/*.eps =*/ eps,
4151+
/*.nef1 =*/ { ne01 },
4152+
/*.nef2 =*/ { ne02 },
4153+
/*.nef3 =*/ { ne03 },
4154+
/*.nbf1 =*/ { nb01 },
4155+
/*.nbf2 =*/ { nb02 },
4156+
/*.nbf3 =*/ { nb03 },
4157+
};
4158+
4159+
size_t offs_fuse[2] = { 0, 0 };
4160+
id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
4161+
4162+
// d[0] = rms_norm(a)
4163+
// d[1] = mul(d[0], b)
4164+
// d[2] = add(d[1], c)
4165+
{
4166+
ops[0] = GGML_OP_RMS_NORM;
4167+
ops[1] = GGML_OP_MUL;
4168+
ops[2] = GGML_OP_ADD;
4169+
4170+
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
4171+
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
4172+
break;
4173+
}
4174+
4175+
if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
4176+
break;
4177+
}
4178+
4179+
if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
4180+
break;
4181+
}
4182+
4183+
if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
4184+
break;
4185+
}
4186+
4187+
if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
4188+
break;
4189+
}
4190+
4191+
id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
4192+
4193+
args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
4194+
args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
4195+
args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
4196+
4197+
args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
4198+
args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
4199+
args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
4200+
}
4201+
4202+
++n_fuse;
4203+
}
4204+
4205+
//GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
4206+
4207+
if (n_fuse > 1) {
4208+
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
4209+
}
4210+
4211+
id<MTLComputePipelineState> pipeline;
4212+
4213+
switch (n_fuse) {
4214+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
4215+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
4216+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
4217+
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
4218+
}
41324219

41334220
int nth = 32; // SIMD width
41344221

@@ -4139,23 +4226,16 @@ static int ggml_metal_encode_node(
41394226
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
41404227
nth = MIN(nth, ne00/4);
41414228

4142-
ggml_metal_kargs_rms_norm args = {
4143-
/*.ne00 =*/ ne00,
4144-
/*.ne00_4 =*/ ne00/4,
4145-
/*.nb01 =*/ nb01,
4146-
/*.eps =*/ eps,
4147-
};
4148-
41494229
[encoder setComputePipelineState:pipeline];
4150-
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4151-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4152-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4230+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
4231+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4232+
[encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
4233+
[encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
4234+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
41534235

41544236
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
41554237

4156-
const int64_t nrows = ggml_nrows(src0);
4157-
4158-
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4238+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
41594239
} break;
41604240
case GGML_OP_L2_NORM:
41614241
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,26 +2194,39 @@ kernel void kernel_norm(
21942194
}
21952195
}
21962196

2197-
kernel void kernel_rms_norm(
2197+
// F == 1 : rms_norm (no fuse)
2198+
// F == 2 : rms_norm + mul
2199+
// F == 3 : rms_norm + mul + add
2200+
template <short F>
2201+
kernel void kernel_rms_norm_fuse_impl(
21982202
constant ggml_metal_kargs_rms_norm & args,
21992203
device const char * src0,
2204+
device const char * src1_0,
2205+
device const char * src1_1,
22002206
device char * dst,
22012207
threadgroup float * shmem_f32 [[threadgroup(0)]],
2202-
uint tgpig[[threadgroup_position_in_grid]],
2203-
ushort tpitg[[thread_position_in_threadgroup]],
2204-
ushort sgitg[[simdgroup_index_in_threadgroup]],
2205-
ushort tiisg[[thread_index_in_simdgroup]],
2206-
ushort ntg[[threads_per_threadgroup]]) {
2208+
uint3 tgpig[[threadgroup_position_in_grid]],
2209+
ushort3 tpitg[[thread_position_in_threadgroup]],
2210+
ushort sgitg[[simdgroup_index_in_threadgroup]],
2211+
ushort tiisg[[thread_index_in_simdgroup]],
2212+
ushort3 ntg[[threads_per_threadgroup]]) {
22072213
if (sgitg == 0) {
22082214
shmem_f32[tiisg] = 0.0f;
22092215
}
22102216

2211-
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
2217+
const int i01 = tgpig.x;
2218+
const int i02 = tgpig.y;
2219+
const int i03 = tgpig.z;
2220+
2221+
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
2222+
2223+
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
2224+
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
22122225

22132226
float sumf = 0.0f;
22142227

22152228
// parallel sum
2216-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2229+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
22172230
sumf += dot(x[i00], x[i00]);
22182231
}
22192232
sumf = simd_sum(sumf);
@@ -2232,12 +2245,24 @@ kernel void kernel_rms_norm(
22322245
const float mean = sumf/args.ne00;
22332246
const float scale = 1.0f/sqrt(mean + args.eps);
22342247

2235-
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
2236-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
2237-
y[i00] = x[i00] * scale;
2248+
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2249+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
2250+
if (F == 1) {
2251+
y[i00] = (x[i00]*scale);
2252+
} else if (F == 2) {
2253+
y[i00] = (x[i00]*scale)*f0[i00];
2254+
} else if (F == 3) {
2255+
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
2256+
}
22382257
}
22392258
}
22402259

2260+
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
2261+
2262+
template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
2263+
template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
2264+
template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
2265+
22412266
kernel void kernel_l2_norm(
22422267
constant ggml_metal_kargs_l2_norm & args,
22432268
device const char * src0,

0 commit comments

Comments
 (0)