@@ -154,20 +154,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
154
154
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155
155
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156
156
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
157
- GGML_METAL_KERNEL_TYPE_ADD_ROW ,
158
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ,
159
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ,
160
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ,
161
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ,
162
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ,
163
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ,
164
- GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ,
157
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ,
158
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ,
159
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ,
160
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ,
161
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ,
162
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ,
163
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ,
164
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ,
165
165
GGML_METAL_KERNEL_TYPE_SUB,
166
- GGML_METAL_KERNEL_TYPE_SUB_ROW ,
166
+ GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ,
167
167
GGML_METAL_KERNEL_TYPE_MUL,
168
- GGML_METAL_KERNEL_TYPE_MUL_ROW ,
168
+ GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ,
169
169
GGML_METAL_KERNEL_TYPE_DIV,
170
- GGML_METAL_KERNEL_TYPE_DIV_ROW ,
170
+ GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ,
171
171
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
172
172
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
173
173
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -1150,20 +1150,20 @@ @implementation GGMLMetalClass
1150
1150
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1151
1151
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true );
1152
1152
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1153
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW , add_row, true );
1154
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 , add_row_fuse_2, true );
1155
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 , add_row_fuse_3, true );
1156
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 , add_row_fuse_4, true );
1157
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 , add_row_fuse_5, true );
1158
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 , add_row_fuse_6, true );
1159
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 , add_row_fuse_7, true );
1160
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 , add_row_fuse_8, true );
1153
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 , add_row_c4, true );
1154
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 , add_row_c4_fuse_2, true );
1155
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 , add_row_c4_fuse_3, true );
1156
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 , add_row_c4_fuse_4, true );
1157
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 , add_row_c4_fuse_5, true );
1158
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 , add_row_c4_fuse_6, true );
1159
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 , add_row_c4_fuse_7, true );
1160
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 , add_row_c4_fuse_8, true );
1161
1161
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1162
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW , sub_row, true );
1162
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 , sub_row_c4, true );
1163
1163
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
1164
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW , mul_row, true );
1164
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 , mul_row_c4, true );
1165
1165
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
1166
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW , div_row, true );
1166
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 , div_row_c4, true );
1167
1167
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
1168
1168
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
1169
1169
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
@@ -2149,6 +2149,8 @@ static int ggml_metal_encode_node(
2149
2149
++n_fuse;
2150
2150
}
2151
2151
2152
+ // GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
2153
+
2152
2154
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2153
2155
GGML_ASSERT (ggml_is_contiguous (src0));
2154
2156
@@ -2159,20 +2161,20 @@ static int ggml_metal_encode_node(
2159
2161
case GGML_OP_ADD:
2160
2162
{
2161
2163
switch (n_fuse) {
2162
- case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline ; break ;
2163
- case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2 ].pipeline ; break ;
2164
- case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3 ].pipeline ; break ;
2165
- case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4 ].pipeline ; break ;
2166
- case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5 ].pipeline ; break ;
2167
- case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6 ].pipeline ; break ;
2168
- case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7 ].pipeline ; break ;
2169
- case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8 ].pipeline ; break ;
2164
+ case 1 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline ; break ;
2165
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2 ].pipeline ; break ;
2166
+ case 3 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3 ].pipeline ; break ;
2167
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4 ].pipeline ; break ;
2168
+ case 5 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5 ].pipeline ; break ;
2169
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6 ].pipeline ; break ;
2170
+ case 7 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7 ].pipeline ; break ;
2171
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8 ].pipeline ; break ;
2170
2172
default : GGML_ABORT (" fatal error" );
2171
2173
}
2172
2174
} break ;
2173
- case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW ].pipeline ; break ;
2174
- case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW ].pipeline ; break ;
2175
- case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW ].pipeline ; break ;
2175
+ case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW_C4 ].pipeline ; break ;
2176
+ case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW_C4 ].pipeline ; break ;
2177
+ case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW_C4 ].pipeline ; break ;
2176
2178
default : GGML_ABORT (" fatal error" );
2177
2179
}
2178
2180
@@ -2207,11 +2209,7 @@ static int ggml_metal_encode_node(
2207
2209
[encoder setComputePipelineState: pipeline];
2208
2210
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2209
2211
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2210
- if (dst->op == GGML_OP_ADD) {
2211
- [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2212
- } else {
2213
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2214
- }
2212
+ [encoder setBuffer: id_src1 offset: 0 atIndex: 2 ];
2215
2213
[encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2216
2214
2217
2215
if (bcast_row) {
0 commit comments