@@ -226,6 +226,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
226
226
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
227
227
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
228
228
GGML_METAL_KERNEL_TYPE_RMS_NORM,
229
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
230
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
229
231
GGML_METAL_KERNEL_TYPE_L2_NORM,
230
232
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
231
233
GGML_METAL_KERNEL_TYPE_NORM,
@@ -1222,6 +1224,8 @@ @implementation GGMLMetalClass
1222
1224
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1223
1225
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1224
1226
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1227
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
1228
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
1225
1229
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1226
1230
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
1227
1231
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
@@ -2115,6 +2119,10 @@ static int ggml_metal_encode_node(
2115
2119
/* .o1 =*/ { offs_src1 },
2116
2120
};
2117
2121
2122
+ // c[0] = add(a, b[0])
2123
+ // c[1] = add(c[0], b[1])
2124
+ // c[2] = add(c[1], b[2])
2125
+ // ...
2118
2126
{
2119
2127
ops[0 ] = GGML_OP_ADD;
2120
2128
ops[1 ] = GGML_OP_ADD;
@@ -2133,6 +2141,11 @@ static int ggml_metal_encode_node(
2133
2141
break ;
2134
2142
}
2135
2143
2144
+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
2145
+ break ;
2146
+ }
2147
+
2148
+ // b[0] === b[1] === ...
2136
2149
if (!ggml_are_same_layout (nodes[n_fuse]->src [1 ], nodes[n_fuse + 1 ]->src [1 ])) {
2137
2150
break ;
2138
2151
}
@@ -4123,12 +4136,86 @@ static int ggml_metal_encode_node(
4123
4136
case GGML_OP_RMS_NORM:
4124
4137
{
4125
4138
GGML_ASSERT (ne00 % 4 == 0 );
4126
- GGML_ASSERT (ggml_is_contiguous_1 (src0));
4139
+ GGML_ASSERT (ggml_is_contiguous_rows (src0));
4127
4140
4128
4141
float eps;
4129
4142
memcpy (&eps, dst->op_params , sizeof (float ));
4130
4143
4131
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline ;
4144
+ ggml_metal_kargs_rms_norm args = {
4145
+ /* .ne00 =*/ ne00,
4146
+ /* .ne00_4 =*/ ne00/4 ,
4147
+ /* .nb1 =*/ nb1,
4148
+ /* .nb2 =*/ nb2,
4149
+ /* .nb3 =*/ nb3,
4150
+ /* .eps =*/ eps,
4151
+ /* .nef1 =*/ { ne01 },
4152
+ /* .nef2 =*/ { ne02 },
4153
+ /* .nef3 =*/ { ne03 },
4154
+ /* .nbf1 =*/ { nb01 },
4155
+ /* .nbf2 =*/ { nb02 },
4156
+ /* .nbf3 =*/ { nb03 },
4157
+ };
4158
+
4159
+ size_t offs_fuse[2 ] = { 0 , 0 };
4160
+ id <MTLBuffer > id_fuse[2 ] = { id_src0, id_src0 };
4161
+
4162
+ // d[0] = rms_norm(a)
4163
+ // d[1] = mul(d[0], b)
4164
+ // d[2] = add(d[1], c)
4165
+ {
4166
+ ops[0 ] = GGML_OP_RMS_NORM;
4167
+ ops[1 ] = GGML_OP_MUL;
4168
+ ops[2 ] = GGML_OP_ADD;
4169
+
4170
+ for (n_fuse = 0 ; n_fuse <= 1 ; ++n_fuse) {
4171
+ if (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
4172
+ break ;
4173
+ }
4174
+
4175
+ if (nodes[n_fuse] != nodes[n_fuse + 1 ]->src [0 ]) {
4176
+ break ;
4177
+ }
4178
+
4179
+ if (nodes[n_fuse + 1 ]->src [1 ]->ne [0 ] != node->ne [0 ]) {
4180
+ break ;
4181
+ }
4182
+
4183
+ if (!ggml_is_contiguous_rows (nodes[n_fuse + 1 ]->src [1 ])) {
4184
+ break ;
4185
+ }
4186
+
4187
+ if (nodes[n_fuse + 1 ]->type != GGML_TYPE_F32) {
4188
+ break ;
4189
+ }
4190
+
4191
+ id_fuse[n_fuse] = ggml_metal_get_buffer (nodes[n_fuse + 1 ]->src [1 ], &offs_fuse[n_fuse]);
4192
+
4193
+ args.nef1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [1 ];
4194
+ args.nef2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [2 ];
4195
+ args.nef3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->ne [3 ];
4196
+
4197
+ args.nbf1 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [1 ];
4198
+ args.nbf2 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [2 ];
4199
+ args.nbf3 [n_fuse + 1 ] = nodes[n_fuse + 1 ]->src [1 ]->nb [3 ];
4200
+ }
4201
+
4202
+ ++n_fuse;
4203
+ }
4204
+
4205
+ // GGML_LOG_INFO("%s: RRRRRRRRRRRRRRRRRRRRRRRRRRRRR n_fuse = %d\n", __func__, n_fuse);
4206
+
4207
+ if (n_fuse > 1 ) {
4208
+ id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
4209
+ }
4210
+
4211
+ id <MTLComputePipelineState > pipeline;
4212
+
4213
+ switch (n_fuse) {
4214
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline ; break ;
4215
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline ; break ;
4216
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline ; break ;
4217
+ default : GGML_ABORT (" unsupported n_fuse = %d \n " , n_fuse);
4218
+ }
4132
4219
4133
4220
int nth = 32 ; // SIMD width
4134
4221
@@ -4139,23 +4226,16 @@ static int ggml_metal_encode_node(
4139
4226
nth = MIN (nth, (int ) pipeline.maxTotalThreadsPerThreadgroup );
4140
4227
nth = MIN (nth, ne00/4 );
4141
4228
4142
- ggml_metal_kargs_rms_norm args = {
4143
- /* .ne00 =*/ ne00,
4144
- /* .ne00_4 =*/ ne00/4 ,
4145
- /* .nb01 =*/ nb01,
4146
- /* .eps =*/ eps,
4147
- };
4148
-
4149
4229
[encoder setComputePipelineState: pipeline];
4150
- [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4151
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4152
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
4230
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
4231
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
4232
+ [encoder setBuffer: id_fuse[0 ] offset: offs_fuse[0 ] atIndex: 2 ];
4233
+ [encoder setBuffer: id_fuse[1 ] offset: offs_fuse[1 ] atIndex: 3 ];
4234
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 4 ];
4153
4235
4154
4236
[encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
4155
4237
4156
- const int64_t nrows = ggml_nrows (src0);
4157
-
4158
- [encoder dispatchThreadgroups: MTLSizeMake (nrows, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4238
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
4159
4239
} break ;
4160
4240
case GGML_OP_L2_NORM:
4161
4241
{
0 commit comments