Skip to content

Commit 303f2ae

Browse files
committed
metal : more consistent binary kernels
1 parent 48f1061 commit 303f2ae

File tree

2 files changed

+137
-70
lines changed

2 files changed

+137
-70
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,20 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
154154
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
155155
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
156156
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
157-
GGML_METAL_KERNEL_TYPE_ADD_ROW,
158-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
159-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3,
160-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
161-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5,
162-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
163-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7,
164-
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
157+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
158+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
159+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
160+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
161+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
162+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
163+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
164+
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
165165
GGML_METAL_KERNEL_TYPE_SUB,
166-
GGML_METAL_KERNEL_TYPE_SUB_ROW,
166+
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
167167
GGML_METAL_KERNEL_TYPE_MUL,
168-
GGML_METAL_KERNEL_TYPE_MUL_ROW,
168+
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
169169
GGML_METAL_KERNEL_TYPE_DIV,
170-
GGML_METAL_KERNEL_TYPE_DIV_ROW,
170+
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
171171
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
172172
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
173173
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -1156,20 +1156,20 @@ @implementation GGMLMetalClass
11561156
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
11571157
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
11581158
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
1159-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
1160-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true);
1161-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3, add_row_fuse_3, true);
1162-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true);
1163-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5, add_row_fuse_5, true);
1164-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true);
1165-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7, add_row_fuse_7, true);
1166-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true);
1159+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
1160+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
1161+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
1162+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
1163+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
1164+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
1165+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
1166+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
11671167
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
1168-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
1168+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
11691169
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
1170-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
1170+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
11711171
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
1172-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
1172+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
11731173
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
11741174
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
11751175
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
@@ -2167,6 +2167,8 @@ static int ggml_metal_encode_node(
21672167
++n_fuse;
21682168
}
21692169

2170+
//GGML_LOG_INFO("%s: XXXXXXXXXXXXXXXXXXX n_fuse = %d\n", __func__, n_fuse);
2171+
21702172
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
21712173
GGML_ASSERT(ggml_is_contiguous(src0));
21722174

@@ -2177,20 +2179,20 @@ static int ggml_metal_encode_node(
21772179
case GGML_OP_ADD:
21782180
{
21792181
switch (n_fuse) {
2180-
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW ].pipeline; break;
2181-
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline; break;
2182-
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_3].pipeline; break;
2183-
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline; break;
2184-
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_5].pipeline; break;
2185-
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline; break;
2186-
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_7].pipeline; break;
2187-
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline; break;
2182+
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
2183+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
2184+
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
2185+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
2186+
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
2187+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
2188+
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
2189+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
21882190
default: GGML_ABORT("fatal error");
21892191
}
21902192
} break;
2191-
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
2192-
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
2193-
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
2193+
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
2194+
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
2195+
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
21942196
default: GGML_ABORT("fatal error");
21952197
}
21962198

@@ -2225,11 +2227,7 @@ static int ggml_metal_encode_node(
22252227
[encoder setComputePipelineState:pipeline];
22262228
[encoder setBytes:&args length:sizeof(args) atIndex:0];
22272229
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2228-
if (dst->op == GGML_OP_ADD) {
2229-
[encoder setBuffer:id_src1 offset:0 atIndex:2];
2230-
} else {
2231-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2232-
}
2230+
[encoder setBuffer:id_src1 offset:0 atIndex:2];
22332231
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
22342232

22352233
if (bcast_row) {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 101 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ kernel void kernel_sub(
899899
const int i11 = i01%args.ne11;
900900

901901
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
902-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
902+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
903903
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
904904

905905
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -924,9 +924,9 @@ kernel void kernel_mul(
924924
const int i12 = i02%args.ne12;
925925
const int i11 = i01%args.ne11;
926926

927-
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
928-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
929-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
927+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
928+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
929+
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
930930

931931
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
932932
const int i10 = i0%args.ne10;
@@ -950,9 +950,9 @@ kernel void kernel_div(
950950
const int i12 = i02%args.ne12;
951951
const int i11 = i01%args.ne11;
952952

953-
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
954-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
955-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
953+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
954+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
955+
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
956956

957957
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
958958
const int i10 = i0%args.ne10;
@@ -995,7 +995,7 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
995995
// assumption: src1 is a row
996996
// broadcast src1 into src0
997997
template <short F>
998-
kernel void kernel_add_row_fuse_impl(
998+
kernel void kernel_add_row_c4_fuse_impl(
999999
constant ggml_metal_kargs_bin & args,
10001000
device const char * src0,
10011001
device const char * src1,
@@ -1023,47 +1023,116 @@ kernel void kernel_add_row_fuse_impl(
10231023
dst_row[tpig] = res;
10241024
}
10251025

1026-
typedef decltype(kernel_add_row_fuse_impl<2>) kernel_add_row_fuse_t;
1026+
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
10271027

1028-
template [[host_name("kernel_add_row")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<1>;
1029-
template [[host_name("kernel_add_row_fuse_2")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<2>;
1030-
template [[host_name("kernel_add_row_fuse_3")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<3>;
1031-
template [[host_name("kernel_add_row_fuse_4")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<4>;
1032-
template [[host_name("kernel_add_row_fuse_5")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<5>;
1033-
template [[host_name("kernel_add_row_fuse_6")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<6>;
1034-
template [[host_name("kernel_add_row_fuse_7")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<7>;
1035-
template [[host_name("kernel_add_row_fuse_8")]] kernel kernel_add_row_fuse_t kernel_add_row_fuse_impl<8>;
1028+
template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
1029+
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
1030+
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
1031+
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
1032+
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
1033+
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
1034+
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
1035+
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
10361036

1037-
kernel void kernel_sub_row(
1037+
template <short F>
1038+
kernel void kernel_sub_row_c4_fuse_impl(
10381039
constant ggml_metal_kargs_bin & args,
1039-
device const float4 * src0,
1040-
device const float4 * src1,
1041-
device float4 * dst,
1040+
device const char * src0,
1041+
device const char * src1,
1042+
device char * dst,
10421043
uint tpig[[thread_position_in_grid]]) {
1044+
10431045
const uint nb = args.ne00/4;
1044-
dst[tpig] = src0[tpig] - src1[tpig % nb];
1046+
const uint i = tpig % nb;
1047+
1048+
device const float4 * src0_row = (device const float4 *) (src0);
1049+
device float4 * dst_row = (device float4 *) (dst);
1050+
1051+
device const float4 * src1_row[F];
1052+
for (short j = 0; j < F; ++j) {
1053+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1054+
}
1055+
1056+
float4 res = src0_row[tpig];
1057+
1058+
#pragma unroll(F)
1059+
for (short j = 0; j < F; ++j) {
1060+
res -= src1_row[j][i];
1061+
}
1062+
1063+
dst_row[tpig] = res;
10451064
}
10461065

1047-
kernel void kernel_mul_row(
1066+
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
1067+
1068+
template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
1069+
1070+
template <short F>
1071+
kernel void kernel_mul_row_c4_fuse_impl(
10481072
constant ggml_metal_kargs_bin & args,
1049-
device const float4 * src0,
1050-
device const float4 * src1,
1051-
device float4 * dst,
1073+
device const char * src0,
1074+
device const char * src1,
1075+
device char * dst,
10521076
uint tpig[[thread_position_in_grid]]) {
1077+
10531078
const uint nb = args.ne00/4;
1054-
dst[tpig] = src0[tpig] * src1[tpig % nb];
1079+
const uint i = tpig % nb;
1080+
1081+
device const float4 * src0_row = (device const float4 *) (src0);
1082+
device float4 * dst_row = (device float4 *) (dst);
1083+
1084+
device const float4 * src1_row[F];
1085+
for (short j = 0; j < F; ++j) {
1086+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1087+
}
1088+
1089+
float4 res = src0_row[tpig];
1090+
1091+
#pragma unroll(F)
1092+
for (short j = 0; j < F; ++j) {
1093+
res *= src1_row[j][i];
1094+
}
1095+
1096+
dst_row[tpig] = res;
10551097
}
10561098

1057-
kernel void kernel_div_row(
1099+
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
1100+
1101+
template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
1102+
1103+
template <short F>
1104+
kernel void kernel_div_row_c4_fuse_impl(
10581105
constant ggml_metal_kargs_bin & args,
1059-
device const float4 * src0,
1060-
device const float4 * src1,
1061-
device float4 * dst,
1106+
device const char * src0,
1107+
device const char * src1,
1108+
device char * dst,
10621109
uint tpig[[thread_position_in_grid]]) {
1110+
10631111
const uint nb = args.ne00/4;
1064-
dst[tpig] = src0[tpig] / src1[tpig % nb];
1112+
const uint i = tpig % nb;
1113+
1114+
device const float4 * src0_row = (device const float4 *) (src0);
1115+
device float4 * dst_row = (device float4 *) (dst);
1116+
1117+
device const float4 * src1_row[F];
1118+
for (short j = 0; j < F; ++j) {
1119+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
1120+
}
1121+
1122+
float4 res = src0_row[tpig];
1123+
1124+
#pragma unroll(F)
1125+
for (short j = 0; j < F; ++j) {
1126+
res /= src1_row[j][i];
1127+
}
1128+
1129+
dst_row[tpig] = res;
10651130
}
10661131

1132+
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
1133+
1134+
template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
1135+
10671136
kernel void kernel_scale(
10681137
device const float * src0,
10691138
device float * dst,

0 commit comments

Comments
 (0)