@@ -5667,7 +5667,114 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK_K;
 
-#if defined __ARM_NEON
+#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+    float sumf = 0.0f;
+
+    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    const uint8x16_t shift = vld1q_u8(k_shift);
+
+    for (int i = 0; i < nb; ++i) {
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+
+        // first 32 bytes of 5 elements
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
+            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
+            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
+            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
+            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
+            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
+            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
+
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
+            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
+        }
+
+        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
+            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
+            qx5 = vmulq_u8(qx5, shift);
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
+
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+    }
+
+    *s = sumf;
+
+#elif defined __ARM_NEON
     float sumf = 0.0f;
 
     uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
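For readers unfamiliar with the TQ1_0 layout, here is a minimal scalar sketch of the trick that the vmulq_u8 / vhaddq_u8 / vshrq_n_u8 sequence above vectorizes. It assumes the base-3 packing used by the TQ1_0 quantization code (five trits per byte, rescaled into [0, 256) with a ceiling division); the helper names pack5, trit, and trit_u8 are illustrative, not part of this commit.

#include <assert.h>
#include <stdint.h>

// Pack five base-3 digits (each 0, 1, or 2) into one byte: the base-3
// value in [0, 243) is rescaled into [0, 256) with a ceiling division,
// which is what makes the high-byte extraction below exact.
static uint8_t pack5(const uint8_t t[5]) {
    uint16_t v = 0;
    for (int k = 0; k < 5; ++k) v = v * 3 + t[k];
    return (uint8_t)((v * 256 + 242) / 243); // ceil(v * 256 / 243)
}

// Extract digit k: multiplying by 3^k (mod 256) rotates digit k to the
// top of the byte, and the high byte of that result times 3 is the digit.
static uint8_t trit(uint8_t q, int k) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    uint8_t r = (uint8_t)(q * pow3[k]);
    return (uint8_t)(((uint16_t)r * 3) >> 8);
}

// The same extraction with 8-bit operations only, mirroring the NEON code:
// (r + (r >> 1)) >> 7 == (3 * r) >> 8 for every byte r (they could only
// differ if a multiple of 256 equaled the odd number 6*(r >> 1) + 3).
// vhaddq_u8 supplies the halving add, vshrq_n_u8 the two right shifts.
static uint8_t trit_u8(uint8_t q, int k) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    uint8_t r = (uint8_t)(q * pow3[k]);
    return (uint8_t)((uint8_t)((r + (r >> 1)) >> 1) >> 6);
}

int main(void) {
    const uint8_t t[5] = {2, 0, 1, 2, 1};
    const uint8_t q = pack5(t); // 188 for this example
    for (int k = 0; k < 5; ++k) {
        assert(trit(q, k) == t[k]);
        assert(trit_u8(q, k) == t[k]);
    }
    return 0;
}

This is also why k_shift holds {1, 3, 9, 27} in groups of four: qx5 broadcasts the same 4 qh bytes into all 16 lanes, and the per-lane multiply by k_shift selects a different trit from each of the four copies.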
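One further editor's note on the tail of the loop (again not part of the commit): TQ1_0 stores each trit with a +1 offset, u_j = t_j + 1 in {0, 1, 2}, so the integer accumulator built above computes

    sum_j u_j * y_j  =  sum_j t_j * y_j  +  sum_j y_j

Q8_K precomputes the 16-element partial sums of y in y[i].bsums, so vpaddlq_s16(vaddq_s16(ysum0, ysum1)) reconstructs sum_j y_j for the block, and the vsubq_s32 removes the offset term, leaving the signed dot product to be scaled by d.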