Skip to content

Commit 69f7726

Browse files
committed
ggml-quants : allow using ARM dot product instructions for TQ1_0
1 parent 895004f commit 69f7726

File tree

1 file changed

+108
-1
lines changed

1 file changed

+108
-1
lines changed

ggml/src/ggml-quants.c

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5667,7 +5667,114 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
56675667

56685668
const int nb = n / QK_K;
56695669

5670-
#if defined __ARM_NEON
5670+
#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
5671+
float sumf = 0.0f;
5672+
5673+
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
5674+
5675+
const uint8x16_t shift = vld1q_u8(k_shift);
5676+
5677+
for (int i = 0; i < nb; ++i) {
5678+
int32x4_t sumi0 = vdupq_n_s32(0);
5679+
int32x4_t sumi1 = vdupq_n_s32(0);
5680+
5681+
// first 32 bytes of qs: each byte packs 5 ternary elements (extracted via multiplies by 1, 3, 9, 27, 81)
5682+
{
5683+
uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
5684+
uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
5685+
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
5686+
uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
5687+
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
5688+
uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
5689+
uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
5690+
uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
5691+
uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
5692+
uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
5693+
5694+
// multiply by 3 and keep the 2 bits above 8 bits
5695+
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5696+
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5697+
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5698+
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5699+
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5700+
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5701+
int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
5702+
int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
5703+
int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
5704+
int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
5705+
5706+
const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
5707+
const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
5708+
const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
5709+
const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
5710+
const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
5711+
const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
5712+
const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
5713+
const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
5714+
const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
5715+
const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
5716+
5717+
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5718+
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5719+
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5720+
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5721+
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5722+
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5723+
sumi0 = vdotq_s32(sumi0, sqx6, qy6);
5724+
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
5725+
sumi0 = vdotq_s32(sumi0, sqx8, qy8);
5726+
sumi1 = vdotq_s32(sumi1, sqx9, qy9);
5727+
}
5728+
5729+
// last 16 bytes of the 5-element packing, plus the 4 qh bytes which each pack 4 elements
5730+
{
5731+
uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
5732+
uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
5733+
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
5734+
uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
5735+
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
5736+
uint32_t qh;
5737+
memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
5738+
uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
5739+
qx5 = vmulq_u8(qx5, shift);
5740+
5741+
// multiply by 3 and keep the 2 bits above 8 bits
5742+
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
5743+
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
5744+
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
5745+
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
5746+
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
5747+
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
5748+
5749+
const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
5750+
const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
5751+
const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
5752+
const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
5753+
const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
5754+
const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
5755+
5756+
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
5757+
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
5758+
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
5759+
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
5760+
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
5761+
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
5762+
}
5763+
5764+
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
5765+
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
5766+
5767+
sumi0 = vaddq_s32(sumi0, sumi1);
5768+
sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
5769+
5770+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5771+
5772+
sumf += d * (float) vaddvq_s32(sumi0);
5773+
}
5774+
5775+
*s = sumf;
5776+
5777+
#elif defined __ARM_NEON
56715778
float sumf = 0.0f;
56725779

56735780
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};

0 commit comments

Comments
 (0)