Skip to content

Commit ba7ffbb

Browse files
ikawrakow (Iwan Kawrakow)
and co-author authored
metal : Q3_K speedup (#2995)
* Slightly faster Q3_K and Q5_K on metal * Another Q3_K speedup on metal Combined with previous commit, we are now +9.6% for TG. PP is not affected as this happens via the matrix multiplication templates. * Slowly progressing on Q3_K on metal We are now 13% faster than master * Another small improvement for Q3_K on metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent e64f5b5 commit ba7ffbb

File tree

1 file changed

+87
-44
lines changed

1 file changed

+87
-44
lines changed

ggml-metal.metal

Lines changed: 87 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
11231123
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
11241124
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
11251125

1126-
float yl[16];
1126+
float yl[32];
11271127

1128-
const uint16_t kmask1 = 0x0303;
1128+
const uint16_t kmask1 = 0x3030;
11291129
const uint16_t kmask2 = 0x0f0f;
11301130

1131-
const int tid = tiisg/2;
1132-
const int ix = tiisg%2;
1133-
const int ip = tid/8; // 0 or 1
1134-
const int il = tid/2 - 4*ip; // 0...3
1131+
const int tid = tiisg/4;
1132+
const int ix = tiisg%4;
1133+
const int ip = tid/4; // 0 or 1
1134+
const int il = 2*((tid%4)/2); // 0 or 2
11351135
const int ir = tid%2;
11361136
const int n = 8;
11371137
const int l0 = n*ir;
11381138

1139-
const uint16_t m1 = 1 << (4*ip + il);
1140-
const uint16_t m2 = m1 << 8;
1139+
// One would think that the Metal compiler would figure out that ip and il can only have
1140+
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
1141+
// with these two tales.
1142+
//
1143+
// Possible masks for the high bit
1144+
const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
1145+
{0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
1146+
{0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
1147+
{0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
1148+
1149+
// Possible masks for the low 2 bits
1150+
const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
1151+
1152+
const ushort4 hm = mm[2*ip + il/2];
11411153

11421154
const int shift = 2*il;
1143-
const uint16_t qm1 = 0x0003 << shift;
1144-
const uint16_t qm2 = 0x0300 << shift;
1145-
const int32_t v1 = 4 << shift;
1146-
const int32_t v2 = 1024 << shift;
1155+
const float v1 = il == 0 ? 4.f : 64.f;
1156+
const float v2 = 4.f * v1;
11471157

11481158
const uint16_t s_shift1 = 4*ip;
1149-
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1150-
const int ik = 4 + (il%2);
1159+
const uint16_t s_shift2 = s_shift1 + il;
11511160

11521161
const int q_offset = 32*ip + l0;
11531162
const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
11561165

11571166
device const float * y1 = yy + ix*QK_K + y_offset;
11581167

1159-
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1160-
for (int i = ix; i < nb; i += 2) {
1168+
uint32_t scales32, aux32;
1169+
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
1170+
thread const int8_t * scales = (thread const int8_t *)&scales32;
1171+
1172+
float sumf1[2] = {0.f};
1173+
float sumf2[2] = {0.f};
1174+
for (int i = ix; i < nb; i += 4) {
11611175

11621176
for (int l = 0; l < 8; ++l) {
1163-
yl[l+0] = y1[l+ 0];
1164-
yl[l+8] = y1[l+16];
1177+
yl[l+ 0] = y1[l+ 0];
1178+
yl[l+ 8] = y1[l+16];
1179+
yl[l+16] = y1[l+32];
1180+
yl[l+24] = y1[l+48];
11651181
}
11661182

11671183
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
11721188
for (int row = 0; row < 2; ++row) {
11731189

11741190
const float d_all = (float)dh[0];
1175-
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
11761191

1177-
float s1 = 0, s2 = 0;
1192+
scales16[0] = a[4];
1193+
scales16[1] = a[5];
1194+
aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
1195+
scales16[0] = a[il+0];
1196+
scales16[1] = a[il+1];
1197+
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
1198+
1199+
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
11781200
for (int l = 0; l < n; l += 2) {
1179-
const uint16_t qs = q[l/2];
1180-
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1181-
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1201+
const int32_t qs = q[l/2];
1202+
s1 += yl[l+0] * (qs & qm[il/2][0]);
1203+
s2 += yl[l+1] * (qs & qm[il/2][1]);
1204+
s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
1205+
s4 += yl[l+16] * (qs & qm[il/2][2]);
1206+
s5 += yl[l+17] * (qs & qm[il/2][3]);
1207+
s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
11821208
}
1183-
float d = d_all * (s1 + 1.f/256.f * s2);
1184-
sumf1[row] += d * scales[0];
1185-
sumf2[row] += d;
1209+
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1210+
float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1211+
sumf1[row] += d1 * (scales[0] - 32);
1212+
sumf2[row] += d2 * (scales[2] - 32);
11861213

1187-
s1 = s2 = 0;
1214+
s1 = s2 = s3 = s4 = s5 = s6 = 0;
11881215
for (int l = 0; l < n; l += 2) {
1189-
const uint16_t qs = q[l/2+8];
1190-
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1191-
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1216+
const int32_t qs = q[l/2+8];
1217+
s1 += yl[l+8] * (qs & qm[il/2][0]);
1218+
s2 += yl[l+9] * (qs & qm[il/2][1]);
1219+
s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
1220+
s4 += yl[l+24] * (qs & qm[il/2][2]);
1221+
s5 += yl[l+25] * (qs & qm[il/2][3]);
1222+
s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
11921223
}
1193-
d = d_all * (s1 + 1.f/256.f * s2);
1194-
sumf1[row] += d * scales[1];
1195-
sumf2[row] += d;
1224+
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
1225+
d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
1226+
sumf1[row] += d1 * (scales[1] - 32);
1227+
sumf2[row] += d2 * (scales[3] - 32);
11961228

11971229
q += step;
11981230
h += step;
@@ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
12011233

12021234
}
12031235

1204-
y1 += 2 * QK_K;
1236+
y1 += 4 * QK_K;
12051237

12061238
}
12071239

12081240
for (int row = 0; row < 2; ++row) {
1209-
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1210-
const float tot = simd_sum(sumf);
1211-
if (tiisg == 0) {
1212-
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1241+
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
1242+
sumf1[row] = simd_sum(sumf);
1243+
}
1244+
if (tiisg == 0) {
1245+
for (int row = 0; row < 2; ++row) {
1246+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
12131247
}
12141248
}
1249+
12151250
}
12161251
#else
12171252
kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
15641599
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
15651600
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
15661601

1567-
float4 acc = {0.f, 0.f, 0.f, 0.f};
1602+
float4 acc1 = {0.f};
1603+
float4 acc2 = {0.f};
15681604
for (int l = 0; l < n; ++l) {
15691605
uint8_t h = qh[l];
1570-
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1571-
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1572-
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1573-
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1606+
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
1607+
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
1608+
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
1609+
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
1610+
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
1611+
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
1612+
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
1613+
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
15741614
}
15751615
const float dall = dh[0];
15761616
const float dmin = dh[1];
1577-
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1617+
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
1618+
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
1619+
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
1620+
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
15781621
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
15791622

15801623
q1 += step;

0 commit comments

Comments (0)