@@ -1768,10 +1768,8 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
1768
1768
float scale = suml2 ? sumlx/suml2 : 0.0f;
1769
1769
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
1770
1770
float best = scale * sumlx;
1771
+ float best_sumlx = sumlx, best_suml2 = suml2;
1771
1772
for (int is = -9; is <= 9; ++is) {
1772
- if (is == 0) {
1773
- continue;
1774
- }
1775
1773
iscale = -(nmax + 0.1f*is) / max;
1776
1774
sumlx = suml2 = 0;
1777
1775
for (int i = 0; i < n; ++i) {
@@ -1787,7 +1785,66 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
1787
1785
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
1788
1786
}
1789
1787
scale = sumlx/suml2; best = scale*sumlx;
1788
+ best_sumlx = sumlx; best_suml2 = suml2;
1789
+ }
1790
+ iscale = (nmax-1 + 0.1f*is) / max;
1791
+ sumlx = suml2 = 0;
1792
+ for (int i = 0; i < n; ++i) {
1793
+ int l = nearest_int(iscale * x[i]);
1794
+ l = MAX(-nmax, MIN(nmax-1, l));
1795
+ float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
1796
+ sumlx += w*x[i]*l;
1797
+ suml2 += w*l*l;
1790
1798
}
1799
+ if (suml2 > 0 && sumlx*sumlx > best*suml2) {
1800
+ for (int i = 0; i < n; ++i) {
1801
+ int l = nearest_int(iscale * x[i]);
1802
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
1803
+ }
1804
+ scale = sumlx/suml2; best = scale*sumlx;
1805
+ best_sumlx = sumlx; best_suml2 = suml2;
1806
+ }
1807
+ }
1808
+
1809
+ sumlx = best_sumlx; suml2 = best_suml2;
1810
+ for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
1811
+ float abs_gmax = 0, gmax = 0;
1812
+ int best_j = -1;
1813
+ for (int j = 0; j < n; ++j) {
1814
+ float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
1815
+ int l = L[j] - nmax;
1816
+ float g = scale * w * (x[j] - scale*l);
1817
+ if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
1818
+ float ag = fabsf(g);
1819
+ if (ag > abs_gmax) {
1820
+ abs_gmax = ag; gmax = g; best_j = j;
1821
+ }
1822
+ }
1823
+ }
1824
+ if (best_j < 0) break;
1825
+
1826
+ float new_sumlx = sumlx, new_suml2 = suml2;
1827
+ float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
1828
+ int l = L[best_j] - nmax;
1829
+ if (gmax > 0) {
1830
+ new_sumlx += w*x[best_j];
1831
+ new_suml2 += w*(2*l + 1);
1832
+ l += 1;
1833
+ } else {
1834
+ new_sumlx -= w*x[best_j];
1835
+ new_suml2 -= w*(2*l - 1);
1836
+ l -= 1;
1837
+ }
1838
+ if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
1839
+ sumlx = new_sumlx; suml2 = new_suml2;
1840
+ scale = sumlx/suml2; best = scale*sumlx;
1841
+ L[best_j] = l + nmax;
1842
+ GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
1843
+ }
1844
+ else {
1845
+ break;
1846
+ }
1847
+
1791
1848
}
1792
1849
return scale;
1793
1850
}
@@ -3254,8 +3311,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
3254
3311
const int64_t nb = n_per_row/QK4_0;
3255
3312
for (int ib = 0; ib < nb; ++ib) {
3256
3313
const float * xb = x + QK4_0 * ib;
3257
- const float * qw = quant_weights + QK4_0 * ib;
3258
- for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3314
+ if (quant_weights) {
3315
+ const float * qw = quant_weights + QK4_0 * ib;
3316
+ for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
3317
+ } else {
3318
+ for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
3319
+ }
3259
3320
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
3260
3321
y[ib].d = GGML_FP32_TO_FP16(d);
3261
3322
for (int j = 0; j < 16; ++j) {
@@ -14581,6 +14642,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
14581
14642
}
14582
14643
d = sumqx/sumq2;
14583
14644
float best = d*sumqx;
14645
+ float best_sumqx = sumqx, best_sumq2 = sumq2;
14584
14646
for (int itry = -ntry; itry <= ntry; ++itry) {
14585
14647
id = (itry + values[0])/max;
14586
14648
sumqx = sumq2 = 0;
@@ -14594,8 +14656,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
14594
14656
}
14595
14657
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
14596
14658
d = sumqx/sumq2; best = d * sumqx;
14659
+ best_sumqx = sumqx; best_sumq2 = sumq2;
14660
+ for (int j = 0; j < block_size; ++j) {
14661
+ float al = id*xb[j];
14662
+ Lb[j] = best_index_iq4nl(values, al);
14663
+ }
14664
+ }
14665
+ id = (itry + values[15])/max;
14666
+ sumqx = sumq2 = 0;
14667
+ for (int j = 0; j < block_size; ++j) {
14668
+ float al = id*xb[j];
14669
+ int l = best_index_iq4nl(values, al);
14670
+ float q = values[l];
14671
+ float w = weight[j];
14672
+ sumqx += w*q*xb[j];
14673
+ sumq2 += w*q*q;
14674
+ }
14675
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
14676
+ d = sumqx/sumq2; best = d * sumqx;
14677
+ best_sumqx = sumqx; best_sumq2 = sumq2;
14678
+ for (int j = 0; j < block_size; ++j) {
14679
+ float al = id*xb[j];
14680
+ Lb[j] = best_index_iq4nl(values, al);
14681
+ }
14597
14682
}
14598
14683
}
14684
+ sumqx = best_sumqx; sumq2 = best_sumq2;
14685
+ for (int iter = 0; iter < 32*block_size; ++iter) {
14686
+ float min_step = INFINITY;
14687
+ int best_j = -1; int dir = 0;
14688
+ for (int j = 0; j < block_size; ++j) {
14689
+ float w = weight[j];
14690
+ float g = d * w * (xb[j] - d*values[Lb[j]]);
14691
+ if (g > 0 && Lb[j] < 15) {
14692
+ float step = (values[Lb[j]+1] - values[Lb[j]])/g;
14693
+ if (step < min_step) {
14694
+ min_step = step; best_j = j; dir = 1;
14695
+ }
14696
+ }
14697
+ else if (g < 0 && Lb[j] > 0) {
14698
+ float step = (values[Lb[j]-1] - values[Lb[j]])/g;
14699
+ if (step < min_step) {
14700
+ min_step = step; best_j = j; dir = -1;
14701
+ }
14702
+ }
14703
+ }
14704
+ if (best_j < 0) break;
14705
+
14706
+ float new_sumqx = sumqx, new_sumq2 = sumq2;
14707
+ float w = weight[best_j];
14708
+ new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
14709
+ new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
14710
+ if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
14711
+ sumqx = new_sumqx; sumq2 = new_sumq2;
14712
+ d = sumqx/sumq2; best = d*sumqx;
14713
+ Lb[best_j] += dir;
14714
+ }
14715
+ else {
14716
+ break;
14717
+ }
14718
+ }
14719
+
14599
14720
scales[ib] = d;
14600
14721
float abs_d = fabsf(d);
14601
14722
if (abs_d > amax_scale) {
0 commit comments