Skip to content

Commit be32443

Browse files
committed
ggml : fix q5_0 histogram stats
1 parent 9fd0e45 commit be32443

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

ggml.c

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,7 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
13271327

13281328
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
13291329

1330+
// get the 5-th bit and store it in qh at the right position
13301331
y[i].qh |= ((vi0 & 0x10) >> 4) << (l + 0);
13311332
y[i].qh |= ((vi1 & 0x10) >> 4) << (l + 1);
13321333
}
@@ -1624,7 +1625,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
16241625
const uint8x8_t v8 = vld1_u8(pp + l/2);
16251626

16261627
// Expand 4-bit qs to 8-bit bytes
1627-
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1628+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
16281629
const uint8x8_t v1 = vshr_n_u8(v8, 4);
16291630

16301631
// Convert to signed 8-bit integers
@@ -1674,7 +1675,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
16741675
for (int l = 0; l < QK4_0; l += 2) {
16751676
const uint8_t vi = pp[l/2];
16761677

1677-
const int8_t vi0 = vi & 0xf;
1678+
const int8_t vi0 = vi & 0x0F;
16781679
const int8_t vi1 = vi >> 4;
16791680

16801681
const float v0 = (vi0 - 8)*d;
@@ -1740,7 +1741,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
17401741
const uint8x8_t v8 = vld1_u8(pp + l/2);
17411742

17421743
// Expand 4-bit qs to 8-bit bytes
1743-
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1744+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
17441745
const uint8x8_t v1 = vshr_n_u8(v8, 4);
17451746

17461747
// Interleave and combine
@@ -1782,7 +1783,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
17821783
for (int l = 0; l < QK4_1; l += 2) {
17831784
const uint8_t vi = pp[l/2];
17841785

1785-
const int8_t vi0 = vi & 0xf;
1786+
const int8_t vi0 = vi & 0x0F;
17861787
const int8_t vi1 = vi >> 4;
17871788

17881789
const float v0 = vi0*d + m;
@@ -1812,7 +1813,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
18121813
for (int l = 0; l < QK4_2; l += 2) {
18131814
const uint8_t vi = pp[l/2];
18141815

1815-
const int8_t vi0 = vi & 0xf;
1816+
const int8_t vi0 = vi & 0x0F;
18161817
const int8_t vi1 = vi >> 4;
18171818

18181819
const float v0 = (vi0 - 8)*d;
@@ -1842,7 +1843,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
18421843
for (int l = 0; l < QK4_3; l += 2) {
18431844
const uint8_t vi = pp[l/2];
18441845

1845-
const int8_t vi0 = vi & 0xf;
1846+
const int8_t vi0 = vi & 0x0F;
18461847
const int8_t vi1 = vi >> 4;
18471848

18481849
const float v0 = vi0*d + m;
@@ -1874,11 +1875,12 @@ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, in
18741875
for (int l = 0; l < QK5_0; l += 2) {
18751876
const uint8_t vi = pp[l/2];
18761877

1877-
const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1878-
const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1878+
// extract the 5-th bit from qh
1879+
const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1880+
const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
18791881

1880-
const int8_t vi0 = (vi & 0xf) | vh0;
1881-
const int8_t vi1 = (vi >> 4) | vh1;
1882+
const uint8_t vi0 = (vi & 0x0F) | vh0;
1883+
const uint8_t vi1 = (vi >> 4) | vh1;
18821884

18831885
const float v0 = vi0*d + m;
18841886
const float v1 = vi1*d + m;
@@ -2593,7 +2595,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25932595
const block_q8_0 * restrict y0 = &y[i + 0];
25942596
const block_q8_0 * restrict y1 = &y[i + 1];
25952597

2596-
const uint8x16_t m4b = vdupq_n_u8(0xf);
2598+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
25972599
const int8x16_t s8b = vdupq_n_s8(0x8);
25982600

25992601
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2729,8 +2731,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
27292731
for (int j = 0; j < QK8_0/2; j++) {
27302732
const uint8_t v0 = p0[j];
27312733

2732-
const int i0 = (int8_t) (v0 & 0xf) - 8;
2733-
const int i1 = (int8_t) (v0 >> 4) - 8;
2734+
const int i0 = (int8_t) (v0 & 0x0F) - 8;
2735+
const int i1 = (int8_t) (v0 >> 4) - 8;
27342736

27352737
const int i2 = p1[2*j + 0];
27362738
const int i3 = p1[2*j + 1];
@@ -2767,7 +2769,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
27672769

27682770
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
27692771

2770-
const uint8x16_t m4b = vdupq_n_u8(0xf);
2772+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
27712773

27722774
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
27732775
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2864,8 +2866,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
28642866
for (int j = 0; j < QK8_1/2; j++) {
28652867
const uint8_t v0 = p0[j];
28662868

2867-
const float f0 = d0*(v0 & 0xf) + m0;
2868-
const float f1 = d0*(v0 >> 4) + m0;
2869+
const float f0 = d0*(v0 & 0x0F) + m0;
2870+
const float f1 = d0*(v0 >> 4) + m0;
28692871

28702872
const float f2 = d1*p1[2*j + 0];
28712873
const float f3 = d1*p1[2*j + 1];
@@ -2900,7 +2902,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
29002902
const block_q8_0 * restrict y0 = &y[i + 0];
29012903
const block_q8_0 * restrict y1 = &y[i + 1];
29022904

2903-
const uint8x16_t m4b = vdupq_n_u8(0xf);
2905+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
29042906
const int8x16_t s8b = vdupq_n_s8(0x8);
29052907

29062908
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
@@ -3011,11 +3013,11 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
30113013
const uint8_t v0 = x0[j];
30123014
const uint8_t v1 = x1[j];
30133015

3014-
const int i0_0 = (int8_t) (v0 & 0xf) - 8;
3015-
const int i1_0 = (int8_t) (v0 >> 4) - 8;
3016+
const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3017+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
30163018

3017-
const int i0_1 = (int8_t) (v1 & 0xf) - 8;
3018-
const int i1_1 = (int8_t) (v1 >> 4) - 8;
3019+
const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3020+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
30193021

30203022
const int i2_0 = y0[2*j + 0];
30213023
const int i3_0 = y0[2*j + 1];
@@ -3063,7 +3065,7 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
30633065
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
30643066

30653067
// 4-bit -> 8-bit
3066-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
3068+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
30673069
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
30683070

30693071
// interleave
@@ -3142,10 +3144,10 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
31423144
const uint8_t v0 = x0[j];
31433145
const uint8_t v1 = x1[j];
31443146

3145-
const int x0_0 = v0 & 0xf;
3147+
const int x0_0 = v0 & 0x0F;
31463148
const int x1_0 = v0 >> 4;
31473149

3148-
const int x0_1 = v1 & 0xf;
3150+
const int x0_1 = v1 & 0x0F;
31493151
const int x1_1 = v1 >> 4;
31503152

31513153
const int y0_0 = y0[2*j + 0];
@@ -3195,7 +3197,7 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
31953197
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
31963198

31973199
// 4-bit -> 8-bit
3198-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
3200+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
31993201
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
32003202

32013203
// interleave
@@ -3274,10 +3276,10 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
32743276
const uint8_t v0 = x0[j];
32753277
const uint8_t v1 = x1[j];
32763278

3277-
const int x0_0 = v0 & 0xf;
3279+
const int x0_0 = v0 & 0x0F;
32783280
const int x1_0 = v0 >> 4;
32793281

3280-
const int x0_1 = v1 & 0xf;
3282+
const int x0_1 = v1 & 0x0F;
32813283
const int x1_1 = v1 >> 4;
32823284

32833285
const int y0_0 = y0[2*j + 0];
@@ -12500,7 +12502,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
1250012502

1250112503
for (int i = 0; i < nb; i++) {
1250212504
for (int l = 0; l < QK4_0; l += 2) {
12503-
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12505+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
1250412506
const uint8_t vi1 = y[i].qs[l/2] >> 4;
1250512507

1250612508
hist[vi0]++;
@@ -12523,7 +12525,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
1252312525

1252412526
for (int i = 0; i < nb; i++) {
1252512527
for (int l = 0; l < QK4_1; l += 2) {
12526-
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12528+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
1252712529
const uint8_t vi1 = y[i].qs[l/2] >> 4;
1252812530

1252912531
hist[vi0]++;
@@ -12546,7 +12548,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
1254612548

1254712549
for (int i = 0; i < nb; i++) {
1254812550
for (int l = 0; l < QK4_2; l += 2) {
12549-
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12551+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
1255012552
const uint8_t vi1 = y[i].qs[l/2] >> 4;
1255112553

1255212554
hist[vi0]++;
@@ -12569,7 +12571,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
1256912571

1257012572
for (int i = 0; i < nb; i++) {
1257112573
for (int l = 0; l < QK4_3; l += 2) {
12572-
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12574+
const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
1257312575
const uint8_t vi1 = y[i].qs[l/2] >> 4;
1257412576

1257512577
hist[vi0]++;
@@ -12590,11 +12592,14 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
1259012592

1259112593
quantize_row_q5_0_reference(src + j, y, k);
1259212594

12593-
// TODO: this is wrong
1259412595
for (int i = 0; i < nb; i++) {
1259512596
for (int l = 0; l < QK5_0; l += 2) {
12596-
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12597-
const uint8_t vi1 = y[i].qs[l/2] >> 4;
12597+
const uint8_t vh0 = ((y[i].qh & (1 << (l + 0))) >> (l + 0)) << 4;
12598+
const uint8_t vh1 = ((y[i].qh & (1 << (l + 1))) >> (l + 1)) << 4;
12599+
12600+
// cast to 16 bins
12601+
const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12602+
const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
1259812603

1259912604
hist[vi0]++;
1260012605
hist[vi1]++;

0 commit comments

Comments
 (0)