@@ -232,6 +232,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
232
232
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
233
233
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
234
234
GGML_METAL_KERNEL_TYPE_RMS_NORM,
235
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
236
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
235
237
GGML_METAL_KERNEL_TYPE_L2_NORM,
236
238
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
237
239
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1234,6 +1236,8 @@ @implementation GGMLMetalClass
1234
1236
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1235
1237
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1236
1238
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1239
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1240
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1237
1241
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1238
1242
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1239
1243
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -2133,6 +2137,10 @@ static int ggml_metal_encode_node(
2133
2137
/* .o1 =*/ { offs_src1 },
2134
2138
};
2135
2139
2140
+ // c[0] = add(a, b[0])
2141
+ // c[1] = add(c[0], b[1])
2142
+ // c[2] = add(c[1], b[2])
2143
+ // ...
2136
2144
{
2137
2145
ops[0 ] = GGML_OP_ADD;
2138
2146
ops[1 ] = GGML_OP_ADD;
@@ -2151,6 +2159,11 @@ static int ggml_metal_encode_node(
2151
2159
break ;
2152
2160
}
2153
2161
2162
+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
2163
+ break ;
2164
+ }
2165
+
2166
+ // b[0] === b[1] === ...
2154
2167
if (!ggml_are_same_layout (nodes[n_fuse]->src [1 ], nodes[n_fuse + 1 ]->src [1 ])) {
2155
2168
break ;
2156
2169
}
@@ -4213,12 +4226,86 @@ static int ggml_metal_encode_node(
4213
4226
case GGML_OP_RMS_NORM:
4214
4227
{
4215
4228
GGML_ASSERT (ne00 % 4 == 0 );
4216
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
4229
+ GGML_ASSERT (ggml_is_contiguous_rows (src0));
4217
4230
4218
4231
float eps;
4219
4232
memcpy (&eps, dst->op_params , sizeof (float ));
4220
4233
4221
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline ;
4234
+ ggml_metal_kargs_rms_norm args = {
4235
+ /* .ne00 =*/ ne00,
4236
+ /* .ne00_4 =*/ ne00/4 ,
4237
+ /* .nb1 =*/ nb1,
4238
+ /* .nb2 =*/ nb2,
4239
+ /* .nb3 =*/ nb3,
4240
+ /* .eps =*/ eps,
4241
+ /* .nef1 =*/ { ne01 },
4242
+ /* .nef2 =*/ { ne02 },
4243
+ /* .nef3 =*/ { ne03 },
4244
+ /* .nbf1 =*/ { nb01 },
4245
+ /* .nbf2 =*/ { nb02 },
4246
+ /* .nbf3 =*/ { nb03 },
4247
+ };
4248
+
4249
+ size_t offs_fuse[2 ] = { 0 , 0 };
4250
+ id <MTLBuffer > id_fuse[2 ] = { id_src0, id_src0 };
4251
+
4252
+ // d[0] = rms_norm(a)
4253
+ // d[1] = mul(d[0], b)
4254
+ // d[2] = add(d[1], c)
4255
+ {
4256
+ ops[0 ] = GGML_OP_RMS_NORM;
4257
+ ops[1 ] = GGML_OP_MUL;
4258
+ ops[2 ] = GGML_OP_ADD;
4259
+
4260
+ for (n_fuse = 0 ; n_fuse <= 1 ; ++n_fuse) {
4261
+ if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
4262
+ break ;
4263
+ }
4264
+
4265
+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
4266
+ break ;
4267
+ }
4268
+
4269
+ if (nodes[n_fuse + 1 ]->src [1 ]->ne [0 ] != node->ne [0 ]) {
4270
+ break ;
4271
+ }
4272
+
4273
+ if (!ggml_is_contiguous_rows (nodes[n_fuse + 1 ]->src [1 ])) {
4274
+ break ;
4275
+ }
4276
+
4277
+ if (nodes[n_fuse + 1 ]->type != GGML_TYPE_F32) {
4278
+ break ;
4279
+ }
4280
+
4281
+ id_fuse[n_fuse] = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse[n_fuse]);
4282
+
4283
+ args.nef1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [1 ];
4284
+ args.nef2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [2 ];
4285
+ args.nef3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [3 ];
4286
+
4287
+ args.nbf1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [1 ];
4288
+ args.nbf2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [2 ];
4289
+ args.nbf3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [3 ];
4290
+ }
4291
+
4292
+ ++n_fuse;
4293
+ }
4294
+
4295
+ // GGML_LOG_INFO("%s: n_fuse = %d\n", __func__, n_fuse);
4296
+
4297
+ if (n_fuse > 1 ) {
4298
+ id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
4299
+ }
4300
+
4301
+ id <MTLComputePipelineState > pipeline;
4302
+
4303
+ switch (n_fuse) {
4304
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline ; break ;
4305
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline ; break ;
4306
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline ; break ;
4307
+ default : GGML_ABORT (" unsupported n_fuse = %d \n " , n_fuse);
4308
+ }
4222
4309
4223
4310
int nth = 32 ; // SIMD width
4224
4311
@@ -4229,23 +4316,16 @@ static int ggml_metal_encode_node(
4229
4316
nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
4230
4317
nth = MIN (nth, ne00/4 );
4231
4318
4232
- ggml_metal_kargs_rms_norm args = {
4233
- /* .ne00 =*/ ne00,
4234
- /* .ne00_4 =*/ ne00/4 ,
4235
- /* .nb01 =*/ nb01,
4236
- /* .eps =*/ eps,
4237
- };
4238
-
4239
4319
[encoder setComputePipelineState: pipeline];
4240
- [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4241
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4242
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
4320
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4321
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4322
+ [encoder setBuffer: id_fuse[0 ] offset: offs_fuse[0 ] atIndex: 2 ];
4323
+ [encoder setBuffer: id_fuse[1 ] offset: offs_fuse[1 ] atIndex: 3 ];
4324
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
4243
4325
4244
4326
[encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
4245
4327
4246
- const int64_t nrows = ggml_nrows (src0);
4247
-
4248
- [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4328
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4249
4329
} break ;
4250
4330
case GGML_OP_L2_NORM:
4251
4331
{
0 commit comments