Skip to content

Quantization improvements #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Mar 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 126 additions & 5 deletions ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
@@ -1768,10 +1768,8 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
float scale = suml2 ? sumlx/suml2 : 0.0f;
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
float best = scale * sumlx;
float best_sumlx = sumlx, best_suml2 = suml2;
for (int is = -9; is <= 9; ++is) {
if (is == 0) {
continue;
}
iscale = -(nmax + 0.1f*is) / max;
sumlx = suml2 = 0;
for (int i = 0; i < n; ++i) {
@@ -1787,7 +1785,66 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
}
scale = sumlx/suml2; best = scale*sumlx;
best_sumlx = sumlx; best_suml2 = suml2;
}
iscale = (nmax-1 + 0.1f*is) / max;
sumlx = suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MAX(-nmax, MIN(nmax-1, l));
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
}
scale = sumlx/suml2; best = scale*sumlx;
best_sumlx = sumlx; best_suml2 = suml2;
}
}

sumlx = best_sumlx; suml2 = best_suml2;
for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
float abs_gmax = 0, gmax = 0;
int best_j = -1;
for (int j = 0; j < n; ++j) {
float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
int l = L[j] - nmax;
float g = scale * w * (x[j] - scale*l);
if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
float ag = fabsf(g);
if (ag > abs_gmax) {
abs_gmax = ag; gmax = g; best_j = j;
}
}
}
if (best_j < 0) break;

float new_sumlx = sumlx, new_suml2 = suml2;
float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
int l = L[best_j] - nmax;
if (gmax > 0) {
new_sumlx += w*x[best_j];
new_suml2 += w*(2*l + 1);
l += 1;
} else {
new_sumlx -= w*x[best_j];
new_suml2 -= w*(2*l - 1);
l -= 1;
}
if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
sumlx = new_sumlx; suml2 = new_suml2;
scale = sumlx/suml2; best = scale*sumlx;
L[best_j] = l + nmax;
GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
}
else {
break;
}

}
return scale;
}
@@ -3254,8 +3311,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
const int64_t nb = n_per_row/QK4_0;
for (int ib = 0; ib < nb; ++ib) {
const float * xb = x + QK4_0 * ib;
const float * qw = quant_weights + QK4_0 * ib;
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
if (quant_weights) {
const float * qw = quant_weights + QK4_0 * ib;
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
} else {
for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
}
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
y[ib].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 16; ++j) {
@@ -14581,6 +14642,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
d = sumqx/sumq2;
float best = d*sumqx;
float best_sumqx = sumqx, best_sumq2 = sumq2;
for (int itry = -ntry; itry <= ntry; ++itry) {
id = (itry + values[0])/max;
sumqx = sumq2 = 0;
@@ -14594,8 +14656,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
d = sumqx/sumq2; best = d * sumqx;
best_sumqx = sumqx; best_sumq2 = sumq2;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
Lb[j] = best_index_iq4nl(values, al);
}
}
id = (itry + values[15])/max;
sumqx = sumq2 = 0;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
int l = best_index_iq4nl(values, al);
float q = values[l];
float w = weight[j];
sumqx += w*q*xb[j];
sumq2 += w*q*q;
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
d = sumqx/sumq2; best = d * sumqx;
best_sumqx = sumqx; best_sumq2 = sumq2;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
Lb[j] = best_index_iq4nl(values, al);
}
}
}
sumqx = best_sumqx; sumq2 = best_sumq2;
for (int iter = 0; iter < 32*block_size; ++iter) {
float min_step = INFINITY;
int best_j = -1; int dir = 0;
for (int j = 0; j < block_size; ++j) {
float w = weight[j];
float g = d * w * (xb[j] - d*values[Lb[j]]);
if (g > 0 && Lb[j] < 15) {
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
if (step < min_step) {
min_step = step; best_j = j; dir = 1;
}
}
else if (g < 0 && Lb[j] > 0) {
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
if (step < min_step) {
min_step = step; best_j = j; dir = -1;
}
}
}
if (best_j < 0) break;

float new_sumqx = sumqx, new_sumq2 = sumq2;
float w = weight[best_j];
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
sumqx = new_sumqx; sumq2 = new_sumq2;
d = sumqx/sumq2; best = d*sumqx;
Lb[best_j] += dir;
}
else {
break;
}
}

scales[ib] = d;
float abs_d = fabsf(d);
if (abs_d > amax_scale) {
Expand Down