@@ -154,20 +154,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
154
154
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155
155
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156
156
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
157
- GGML_METAL_KERNEL_TYPE_ADD_ROW ,
158
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ,
159
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ,
160
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ,
161
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ,
162
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ,
163
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ,
164
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ,
157
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ,
158
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ,
159
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ,
160
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ,
161
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ,
162
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ,
163
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ,
164
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ,
165
165
GGML_METAL_KERNEL_TYPE_SUB,
166
- GGML_METAL_KERNEL_TYPE_SUB_ROW ,
166
+ GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ,
167
167
GGML_METAL_KERNEL_TYPE_MUL,
168
- GGML_METAL_KERNEL_TYPE_MUL_ROW ,
168
+ GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ,
169
169
GGML_METAL_KERNEL_TYPE_DIV,
170
- GGML_METAL_KERNEL_TYPE_DIV_ROW ,
170
+ GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ,
171
171
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
172
172
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
173
173
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -1156,20 +1156,20 @@ @implementation GGMLMetalClass
1156
1156
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1157
1157
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1158
1158
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1159
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW , add_row, true );
1160
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 , add_row_fuse_2, true );
1161
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 , add_row_fuse_3, true );
1162
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 , add_row_fuse_4, true );
1163
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 , add_row_fuse_5, true );
1164
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 , add_row_fuse_6, true );
1165
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 , add_row_fuse_7, true );
1166
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 , add_row_fuse_8, true );
1159
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 , add_row_c4, true );
1160
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 , add_row_c4_fuse_2, true );
1161
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 , add_row_c4_fuse_3, true );
1162
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 , add_row_c4_fuse_4, true );
1163
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 , add_row_c4_fuse_5, true );
1164
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 , add_row_c4_fuse_6, true );
1165
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 , add_row_c4_fuse_7, true );
1166
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 , add_row_c4_fuse_8, true );
1167
1167
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1168
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW , sub_row, true );
1168
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 , sub_row_c4, true );
1169
1169
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
1170
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW , mul_row, true );
1170
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 , mul_row_c4, true );
1171
1171
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
1172
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW , div_row, true );
1172
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 , div_row_c4, true );
1173
1173
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
1174
1174
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
1175
1175
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
@@ -2167,6 +2167,8 @@ static int ggml_metal_encode_node(
2167
2167
++n_fuse;
2168
2168
}
2169
2169
2170
+ // GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
2171
+
2170
2172
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2171
2173
GGML_ASSERT (ggml_is_contiguous (src0));
2172
2174
@@ -2177,20 +2179,20 @@ static int ggml_metal_encode_node(
2177
2179
case GGML_OP_ADD:
2178
2180
{
2179
2181
switch (n_fuse) {
2180
- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline ; break ;
2181
- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ].pipeline ; break ;
2182
- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ].pipeline ; break ;
2183
- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ].pipeline ; break ;
2184
- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ].pipeline ; break ;
2185
- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ].pipeline ; break ;
2186
- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ].pipeline ; break ;
2187
- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ].pipeline ; break ;
2182
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline ; break ;
2183
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ].pipeline ; break ;
2184
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ].pipeline ; break ;
2185
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ].pipeline ; break ;
2186
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ].pipeline ; break ;
2187
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ].pipeline ; break ;
2188
+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ].pipeline ; break ;
2189
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ].pipeline ; break ;
2188
2190
default : GGML_ABORT (" fatal error" );
2189
2191
}
2190
2192
} break ;
2191
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW ].pipeline ; break ;
2192
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW ].pipeline ; break ;
2193
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW ].pipeline ; break ;
2193
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ].pipeline ; break ;
2194
+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ].pipeline ; break ;
2195
+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ].pipeline ; break ;
2194
2196
default : GGML_ABORT (" fatal error" );
2195
2197
}
2196
2198
@@ -2225,11 +2227,7 @@ static int ggml_metal_encode_node(
2225
2227
[encoder setComputePipelineState: pipeline];
2226
2228
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2227
2229
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2228
- if (dst->op == GGML_OP_ADD) {
2229
- [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2230
- } else {
2231
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2232
- }
2230
+ [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2233
2231
[encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2234
2232
2235
2233
if (bcast_row) {
0 commit comments