@@ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float    v1 = il == 0 ? 4.f : 64.f;
+    const float    v2 = 4.f*v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
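
The lookup tables above are easy to sanity-check on the host. The short C program below (an illustrative sketch, not part of the commit; the table values are copied verbatim from the diff) asserts that each `mm` row reproduces the high-bit masks the old code derived as `1 << (4*ip + il)` and `m1 << 8`, plus the shifted-by-one lanes for the neighboring sub-block `il+1`, and that each `qm` row is the old `qm1`/`qm2` pair together with the same pair shifted for that neighboring sub-block:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Mask tables copied from the diff (ushort4/int4 flattened to C arrays). */
static const uint16_t mm[4][4] = {
    {0x0001, 0x0100, 0x0002, 0x0200},  /* ip = 0, il = 0 */
    {0x0004, 0x0400, 0x0008, 0x0800},  /* ip = 0, il = 2 */
    {0x0010, 0x1000, 0x0020, 0x2000},  /* ip = 1, il = 0 */
    {0x0040, 0x4000, 0x0080, 0x8000},  /* ip = 1, il = 2 */
};
static const uint32_t qm[2][4] = {
    {0x0003, 0x0300, 0x000c, 0x0c00},  /* il = 0 */
    {0x0030, 0x3000, 0x00c0, 0xc000},  /* il = 2 */
};

int main(void) {
    for (int ip = 0; ip < 2; ++ip) {
        for (int il = 0; il < 4; il += 2) {
            /* masks the old kernel computed with data-dependent shifts */
            const uint16_t m1 = (uint16_t)(1 << (4*ip + il));
            const uint16_t m2 = (uint16_t)(m1 << 8);
            const uint16_t *hm = mm[2*ip + il/2];  /* same indexing as the kernel */
            assert(hm[0] == m1 && hm[1] == m2);            /* lanes for sub-block il   */
            assert(hm[2] == m1 << 1 && hm[3] == m2 << 1);  /* lanes for sub-block il+1 */
        }
    }
    for (int il = 0; il < 4; il += 2) {
        const int shift = 2*il;                               /* as in the kernel */
        assert(qm[il/2][0] == (uint32_t)(0x0003 << shift));   /* old qm1 */
        assert(qm[il/2][1] == (uint32_t)(0x0300 << shift));   /* old qm2 */
        assert(qm[il/2][2] == (uint32_t)(0x0003 << (shift+2)));
        assert(qm[il/2][3] == (uint32_t)(0x0300 << (shift+2)));
    }
    puts("mask tables match the shift-derived masks");
    return 0;
}
```

With `il` restricted to {0, 2}, `2*ip + il/2` and `il/2` turn the `(ip, il)` state into a table index, so every mask becomes a constant load instead of a variable shift.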
@@ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
-            yl[l+0] = y1[l+ 0];
-            yl[l+8] = y1[l+16];
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q += step;
             h += step;
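
The replaced `char2 scales = as_type<char2>(...)` line extracted two 6-bit scales per pass; the new code builds four at once. In `block_q3_K`, the 12 scale bytes (viewed here as six `uint16_t` words `a[0..5]`) hold the low 4 bits of the sixteen 6-bit sub-block scales in `a[0..3]` and the high 2 bits in `a[4..5]`, stored with a +32 bias. The following host-side C sketch (not from the commit; the sample words are made up, and `memcpy` stands in for the kernel's `thread`-pointer aliasing of `scales32`) replays the packing for one `(ip, il)` state:

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Made-up stand-ins for the six uint16 words covering block_q3_K.scales. */
static const uint16_t a[6] = {0x8421, 0xc63f, 0x1b2a, 0x7e5d, 0x39c6, 0xa5f0};

int main(void) {
    const int ip = 0, il = 2;            /* one of the four (ip, il) states */
    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + il;

    uint32_t scales32, aux32;
    uint16_t s16[2];

    /* Step 1: the high 2 bits of four scales live in a[4..5]; move them to
     * bit positions 4..5 of each byte of aux32. */
    s16[0] = a[4]; s16[1] = a[5];
    memcpy(&scales32, s16, 4);
    aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030u;

    /* Step 2: the low 4 bits live in a[il..il+1]; keep the low nibble of
     * each byte and merge in the high bits. */
    s16[0] = a[il+0]; s16[1] = a[il+1];
    memcpy(&scales32, s16, 4);
    scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0fu) | aux32;

    /* The four bytes of scales32 are now four 6-bit scales, still carrying
     * the +32 bias that the kernel removes with (scales[k] - 32). */
    int8_t sc[4];
    memcpy(sc, &scales32, 4);
    for (int k = 0; k < 4; ++k) printf("scale[%d] = %d\n", k, sc[k] - 32);
    return 0;
}
```

Because the bias is now removed per scale (`scales[k] - 32`) inside the loop, the old final correction `sumf1 - 32.f*sumf2` disappears; `sumf2` instead carries the `il+1` sub-block contribution, folded in with the `0.25f` factor in the hunk below.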
@@ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
+
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
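
The q5_K change follows the same pattern: the fused per-element term `yl * (q + (h & hm ? 16 : 0))` is split into a plain 4-bit product (`acc1`) and a high-bit sum (`acc2`) that is scaled by 16 once per sub-block, so the ternary selects between two floats instead of forcing an integer add into every multiply. A minimal C check (illustrative only; it assumes integer-valued activations so both forms are exact in float arithmetic) confirms the split accumulation matches the fused one:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    srand(42);
    float fused = 0.f, acc1 = 0.f, acc2 = 0.f;
    for (int l = 0; l < 1000; ++l) {
        float yl   = (float)(rand() % 201 - 100); /* integer-valued activation */
        uint8_t q  = (uint8_t)(rand() % 16);      /* low 4 bits of a q5_K weight */
        int hbit   = rand() % 2;                  /* the fifth (high) bit */

        fused += yl * (float)(q + (hbit ? 16 : 0)); /* old, fused form */
        acc1  += yl * (float)q;                     /* new: 4-bit part */
        acc2  += hbit ? yl : 0.f;                   /* new: high-bit part */
    }
    /* Exact equality holds here: every intermediate is an integer well
     * below 2^24, so no float rounding occurs in either accumulation. */
    assert(fused == acc1 + 16.f * acc2);
    puts("fused and split q5_K accumulation agree");
    return 0;
}
```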