Skip to content

Commit d9fa3c1

Browse files
committed
Add gradient-based make_qx_quants and iq4nl
From <ikawrakow/ik_llama.cpp#295>
1 parent f479d42 commit d9fa3c1

File tree

2 files changed

+401
-0
lines changed

2 files changed

+401
-0
lines changed

rounding-impl.c

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,132 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
173173
return scale;
174174
}
175175

176+
static float make_qxg_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
177+
const float * restrict qw) {
178+
float max = 0;
179+
float amax = 0;
180+
for (int i = 0; i < n; ++i) {
181+
float ax = fabsf(x[i]);
182+
if (ax > amax) { amax = ax; max = x[i]; }
183+
}
184+
if (amax < GROUP_MAX_EPS) { // all zero
185+
for (int i = 0; i < n; ++i) {
186+
L[i] = 0;
187+
}
188+
return 0.f;
189+
}
190+
float iscale = -nmax / max;
191+
if (rmse_type == 0) {
192+
for (int i = 0; i < n; ++i) {
193+
int l = nearest_int(iscale * x[i]);
194+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
195+
}
196+
return 1/iscale;
197+
}
198+
bool return_early = false;
199+
if (rmse_type < 0) {
200+
rmse_type = -rmse_type;
201+
return_early = true;
202+
}
203+
float sumlx = 0;
204+
float suml2 = 0;
205+
#ifdef HAVE_BUGGY_APPLE_LINKER
206+
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
207+
for (volatile int i = 0; i < n; ++i) {
208+
#else
209+
for (int i = 0; i < n; ++i) {
210+
#endif
211+
int l = nearest_int(iscale * x[i]);
212+
l = MAX(-nmax, MIN(nmax-1, l));
213+
L[i] = l + nmax;
214+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
215+
sumlx += w*x[i]*l;
216+
suml2 += w*l*l;
217+
}
218+
float scale = suml2 ? sumlx/suml2 : 0.0f;
219+
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
220+
float best = scale * sumlx;
221+
float best_sumlx = sumlx, best_suml2 = suml2;
222+
for (int is = -9; is <= 9; ++is) {
223+
iscale = -(nmax + 0.1f*is) / max;
224+
sumlx = suml2 = 0;
225+
for (int i = 0; i < n; ++i) {
226+
int l = nearest_int(iscale * x[i]);
227+
l = MAX(-nmax, MIN(nmax-1, l));
228+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
229+
sumlx += w*x[i]*l;
230+
suml2 += w*l*l;
231+
}
232+
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
233+
for (int i = 0; i < n; ++i) {
234+
int l = nearest_int(iscale * x[i]);
235+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
236+
}
237+
scale = sumlx/suml2; best = scale*sumlx;
238+
best_sumlx = sumlx; best_suml2 = suml2;
239+
}
240+
iscale = (nmax-1 + 0.1f*is) / max;
241+
sumlx = suml2 = 0;
242+
for (int i = 0; i < n; ++i) {
243+
int l = nearest_int(iscale * x[i]);
244+
l = MAX(-nmax, MIN(nmax-1, l));
245+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
246+
sumlx += w*x[i]*l;
247+
suml2 += w*l*l;
248+
}
249+
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
250+
for (int i = 0; i < n; ++i) {
251+
int l = nearest_int(iscale * x[i]);
252+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
253+
}
254+
scale = sumlx/suml2; best = scale*sumlx;
255+
best_sumlx = sumlx; best_suml2 = suml2;
256+
}
257+
}
258+
259+
sumlx = best_sumlx; suml2 = best_suml2;
260+
for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
261+
float abs_gmax = 0, gmax = 0;
262+
int best_j = -1;
263+
for (int j = 0; j < n; ++j) {
264+
float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
265+
int l = L[j] - nmax;
266+
float g = scale * w * (x[j] - scale*l);
267+
if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
268+
float ag = fabsf(g);
269+
if (ag > abs_gmax) {
270+
abs_gmax = ag; gmax = g; best_j = j;
271+
}
272+
}
273+
}
274+
if (best_j < 0) break;
275+
276+
float new_sumlx = sumlx, new_suml2 = suml2;
277+
float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
278+
int l = L[best_j] - nmax;
279+
if (gmax > 0) {
280+
new_sumlx += w*x[best_j];
281+
new_suml2 += w*(2*l + 1);
282+
l += 1;
283+
} else {
284+
new_sumlx -= w*x[best_j];
285+
new_suml2 -= w*(2*l - 1);
286+
l -= 1;
287+
}
288+
if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
289+
sumlx = new_sumlx; suml2 = new_suml2;
290+
scale = sumlx/suml2; best = scale*sumlx;
291+
L[best_j] = l + nmax;
292+
GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
293+
}
294+
else {
295+
break;
296+
}
297+
298+
}
299+
return scale;
300+
}
301+
176302
static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
177303
float max = 0;
178304
float amax = 0;
@@ -634,6 +760,194 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
634760
}
635761
}
636762

763+
static const int8_t iq4nl_index[241] = {
764+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
765+
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
766+
3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
767+
5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8,
768+
8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
769+
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
770+
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
771+
14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
772+
};
773+
static inline int best_index_iq4nl(const int8_t * values, float x) {
774+
int ix = (int)x - values[0];
775+
if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15;
776+
ix = iq4nl_index[ix];
777+
return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15;
778+
}
779+
780+
static void quantize_row_iq4_nl_g_impl(const int super_block_size, const int block_size, const float * restrict x,
781+
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
782+
float * scales, float * weight, uint8_t * L,
783+
const int8_t * values,
784+
const float * quant_weights,
785+
const int ntry) {
786+
787+
float sigma2 = 0;
788+
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
789+
sigma2 *= 2.f/super_block_size;
790+
791+
memset(q4, 0, super_block_size/2);
792+
dh[0] = GGML_FP32_TO_FP16(0.f);
793+
794+
float max_scale = 0, amax_scale = 0;
795+
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
796+
const float * xb = x + ib*block_size;
797+
uint8_t * Lb = L + ib*block_size;
798+
if (quant_weights) {
799+
const float * qw = quant_weights + ib*block_size;
800+
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
801+
} else {
802+
for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
803+
}
804+
float amax = 0, max = 0;
805+
for (int j = 0; j < block_size; ++j) {
806+
float ax = fabsf(xb[j]);
807+
if (ax > amax) {
808+
amax = ax; max = xb[j];
809+
}
810+
}
811+
if (amax < GROUP_MAX_EPS) {
812+
scales[ib] = 0;
813+
continue;
814+
}
815+
float d = ntry > 0 ? -max/values[0] : max/values[0];
816+
float id = 1/d;
817+
float sumqx = 0, sumq2 = 0;
818+
for (int j = 0; j < block_size; ++j) {
819+
float al = id*xb[j];
820+
int l = best_index_iq4nl(values, al);
821+
Lb[j] = l;
822+
float q = values[l];
823+
float w = weight[j];
824+
sumqx += w*q*xb[j];
825+
sumq2 += w*q*q;
826+
}
827+
d = sumqx/sumq2;
828+
float best = d*sumqx;
829+
float best_sumqx = sumqx, best_sumq2 = sumq2;
830+
for (int itry = -ntry; itry <= ntry; ++itry) {
831+
id = (itry + values[0])/max;
832+
sumqx = sumq2 = 0;
833+
for (int j = 0; j < block_size; ++j) {
834+
float al = id*xb[j];
835+
int l = best_index_iq4nl(values, al);
836+
float q = values[l];
837+
float w = weight[j];
838+
sumqx += w*q*xb[j];
839+
sumq2 += w*q*q;
840+
}
841+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
842+
d = sumqx/sumq2; best = d * sumqx;
843+
best_sumqx = sumqx; best_sumq2 = sumq2;
844+
for (int j = 0; j < block_size; ++j) {
845+
float al = id*xb[j];
846+
Lb[j] = best_index_iq4nl(values, al);
847+
}
848+
}
849+
id = (itry + values[15])/max;
850+
sumqx = sumq2 = 0;
851+
for (int j = 0; j < block_size; ++j) {
852+
float al = id*xb[j];
853+
int l = best_index_iq4nl(values, al);
854+
float q = values[l];
855+
float w = weight[j];
856+
sumqx += w*q*xb[j];
857+
sumq2 += w*q*q;
858+
}
859+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
860+
d = sumqx/sumq2; best = d * sumqx;
861+
best_sumqx = sumqx; best_sumq2 = sumq2;
862+
for (int j = 0; j < block_size; ++j) {
863+
float al = id*xb[j];
864+
Lb[j] = best_index_iq4nl(values, al);
865+
}
866+
}
867+
}
868+
sumqx = best_sumqx; sumq2 = best_sumq2;
869+
for (int iter = 0; iter < 32*block_size; ++iter) {
870+
float min_step = INFINITY;
871+
int best_j = -1; int dir = 0;
872+
for (int j = 0; j < block_size; ++j) {
873+
float w = weight[j];
874+
float g = d * w * (xb[j] - d*values[Lb[j]]);
875+
if (g > 0 && Lb[j] < 15) {
876+
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
877+
if (step < min_step) {
878+
min_step = step; best_j = j; dir = 1;
879+
}
880+
}
881+
else if (g < 0 && Lb[j] > 0) {
882+
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
883+
if (step < min_step) {
884+
min_step = step; best_j = j; dir = -1;
885+
}
886+
}
887+
}
888+
if (best_j < 0) break;
889+
890+
float new_sumqx = sumqx, new_sumq2 = sumq2;
891+
float w = weight[best_j];
892+
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
893+
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
894+
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
895+
sumqx = new_sumqx; sumq2 = new_sumq2;
896+
d = sumqx/sumq2; best = d*sumqx;
897+
Lb[best_j] += dir;
898+
}
899+
else {
900+
break;
901+
}
902+
}
903+
904+
scales[ib] = d;
905+
float abs_d = fabsf(d);
906+
if (abs_d > amax_scale) {
907+
amax_scale = abs_d; max_scale = d;
908+
}
909+
}
910+
911+
if (super_block_size/block_size > 1) {
912+
int nb = super_block_size/block_size;
913+
memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
914+
float d = -max_scale/32;
915+
dh[0] = GGML_FP32_TO_FP16(d);
916+
float id = d ? 1/d : 0.f;
917+
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
918+
int l = nearest_int(id*scales[ib]);
919+
l = MAX(-32, MIN(31, l));
920+
float dl = d * l;
921+
float idl = dl ? 1/dl : 0.f;
922+
uint8_t * Lb = L + ib*block_size;
923+
const float * xb = x + ib*block_size;
924+
for (int j = 0; j < block_size; ++j) {
925+
Lb[j] = best_index_iq4nl(values, idl*xb[j]);
926+
}
927+
l += 32;
928+
uint8_t l_l = l & 0xf;
929+
uint8_t l_h = l >> 4;
930+
if (ib%2 == 0) scales_l[ib/2] = l_l;
931+
else scales_l[ib/2] |= (l_l << 4);
932+
scales_h[ib/8] |= (l_h << 2*(ib%8));
933+
}
934+
} else {
935+
dh[0] = GGML_FP32_TO_FP16(scales[0]);
936+
if (ntry > 0) {
937+
float id = scales[0] ? 1/scales[0] : 0;
938+
for (int j = 0; j < super_block_size; ++j) {
939+
L[j] = best_index_iq4nl(values, id*x[j]);
940+
}
941+
}
942+
}
943+
944+
for (int i = 0; i < super_block_size/32; ++i) {
945+
for (int j = 0; j < 16; ++j) {
946+
q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
947+
}
948+
}
949+
}
950+
637951
// ---- Custom experiments ----
638952

639953
struct fraction {
@@ -2636,6 +2950,16 @@ void anyrize_qx(const float * x, const float * w, float * v, int ne0, int ne1, i
26362950
}
26372951
}
26382952

2953+
void anyrize_qxg(const float * x, const float * w, float * v, int ne0, int ne1, int nmax) {
2954+
int8_t L[ne0];
2955+
for (int i = 0; i < ne1; ++i) {
2956+
float scale = make_qxg_quants(ne0, nmax, x + ne0*i, L, 1, w ? w + i*ne0 : NULL);
2957+
for (int j = 0; j < ne0; ++j) {
2958+
v[i*ne0 + j] = (L[j] - nmax) * scale;
2959+
}
2960+
}
2961+
}
2962+
26392963
void anyrize_qkxs(const float * x, const float * w, float * v, int ne0, int ne1, int nmin, int nmax, bool signed_scale) {
26402964
struct fraction Faux[ne0 * MAX(abs(nmin), abs(nmax))];
26412965
int8_t L[ne0];
@@ -2832,6 +3156,23 @@ void anyrize_iq4nl(const float * x, const float * w, float * v, int ne0, int ne1
28323156
}
28333157
}
28343158

3159+
void anyrize_iq4nl_g(const float * x, const float * w, float * v, int ne0, int ne1) {
3160+
uint8_t L[ne0];
3161+
uint8_t Laux[ne0];
3162+
ggml_fp16_t unused_dh;
3163+
uint8_t unused_q4[ne0];
3164+
uint16_t unused_h;
3165+
uint8_t * unused_l = NULL;
3166+
float weight[ne0];
3167+
for (int i = 0; i < ne1; ++i) {
3168+
float scale = 0.0f;
3169+
quantize_row_iq4_nl_g_impl(ne0, ne0, x + i*ne0, &unused_dh, unused_q4, &unused_h, unused_l, &scale, weight, L, kvalues_iq4nl, w ? w + i*ne0 : NULL, 7);
3170+
for (int j = 0; j < ne0; ++j) {
3171+
v[i*ne0 + j] = kvalues_iq4nl[L[j]] * scale;
3172+
}
3173+
}
3174+
}
3175+
28353176
void anyrize_qkxs_iq4nl(const float * x, const float * w, float * v, int ne0, int ne1) {
28363177
uint8_t L[ne0];
28373178
uint8_t Laux[ne0];

0 commit comments

Comments
 (0)