Skip to content

Commit a5b1943

Browse files
committed
ggml-quants : fix some edge cases in make_qkxh_nl_quants
1 parent 8b8b88f commit a5b1943

File tree

1 file changed

+28
-38
lines changed

1 file changed

+28
-38
lines changed

ggml/src/ggml-quants.c

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,10 +1149,11 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
11491149
amax = ax;
11501150
amax_i = i;
11511151
}
1152-
Laux[i] = k_heap->mid_k;
11531152
sumlx += w * x[i] * kmin;
11541153
suml2 += w * kmin * kmin;
11551154
}
1155+
memset(Laux, k_heap->mid_k, n);
1156+
memset(L, k_heap->mid_k, n);
11561157

11571158
const bool neg_scale = signed_scale && fast ? (x[amax_i] < 0.0f) != (k_heap->kmax < 0) : false;
11581159

@@ -1163,57 +1164,49 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
11631164
float best_suml2;
11641165
if (suml2 != 0.0f) {
11651166
best = sumlx * sumlx;
1166-
best_sumlx = neg_scale ? -sumlx : sumlx;
1167-
best_suml2 = suml2 != 0.0f ? suml2 : 1.0f;
1167+
best_sumlx = sumlx; // can't change the sign of kmin
1168+
best_suml2 = suml2;
11681169
} else {
11691170
best = 0.0f;
11701171
best_sumlx = 0.0f;
11711172
best_suml2 = 1.0f;
11721173
}
1173-
{
1174-
float sumlx_p = neg_scale ? -sumlx : sumlx;
1175-
float suml2_p = suml2;
1176-
int best_p_i = -2; // not consecutive with 0..n_frac
1177-
int i = 0;
1178-
while (k_heap->n > 0) {
1179-
struct fraction frac = k_heap_pop(k_heap);
1180-
const int ii = frac.i;
1181-
const float w = weights ? weights[ii] : x[ii] * x[ii];
1182-
sumlx_p += w * frac.numer;
1183-
suml2_p += w * frac.denom;
1184-
const float current = sumlx_p * sumlx_p;
1185-
Laux[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
1186-
if (suml2_p > 0.0f && current * best_suml2 > best * suml2_p) {
1187-
best = current;
1188-
best_sumlx = neg_scale ? -sumlx_p : sumlx_p;
1189-
best_suml2 = suml2_p;
1190-
if (i == best_p_i + 1) {
1191-
// reduce copies for consecutive bests
1192-
L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
1193-
} else {
1194-
for (int j = 0; j < n; ++j) {
1195-
L[j] = Laux[j];
1196-
}
1197-
}
1198-
best_p_i = i;
1174+
float sumlx_p = neg_scale ? -sumlx : sumlx;
1175+
float suml2_p = suml2;
1176+
int best_p_i = -1; // consecutive with 0..n_frac
1177+
for (int i = 0; k_heap->n > 0; ++i) {
1178+
struct fraction frac = k_heap_pop(k_heap);
1179+
const int ii = frac.i;
1180+
const float w = weights ? weights[ii] : x[ii] * x[ii];
1181+
sumlx_p += w * frac.numer;
1182+
suml2_p += w * frac.denom;
1183+
const float current = sumlx_p * sumlx_p;
1184+
Laux[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
1185+
if (suml2_p > 0.0f && current * best_suml2 > best * suml2_p) {
1186+
best = current;
1187+
best_sumlx = neg_scale ? -sumlx_p : sumlx_p;
1188+
best_suml2 = suml2_p;
1189+
if (i == best_p_i + 1) {
1190+
// reduce copies for consecutive bests
1191+
L[ii] += (x[ii] < 0.0f) != neg_scale ? -1 : 1;
1192+
} else {
1193+
memcpy(L, Laux, n);
11991194
}
1195+
best_p_i = i;
12001196
}
12011197
}
12021198

12031199
// Non-linear mappings are usually not symmetric, so try negating the scale
12041200
// This is the same as above, but keeping the old best if the new best is not better.
12051201
if (signed_scale && !fast) {
1206-
for (int i = 0; i < n; ++i) {
1207-
Laux[i] = k_heap->mid_k;
1208-
}
1202+
memset(Laux, k_heap->mid_k, n);
12091203

12101204
k_heap_set_x(k_heap, x, n, true);
12111205

12121206
float sumlx_n = -sumlx;
12131207
float suml2_n = suml2;
12141208
int best_n_i = -2; // not consecutive with 0..n_frac
1215-
int i = 0;
1216-
while (k_heap->n > 0) {
1209+
for (int i = 0; k_heap->n > 0; ++i) {
12171210
struct fraction frac = k_heap_pop(k_heap);
12181211
const int ii = frac.i;
12191212
const float w = weights ? weights[ii] : x[ii] * x[ii];
@@ -1229,13 +1222,10 @@ static float make_qkxh_nl_quants(int n, const float * GGML_RESTRICT x, const flo
12291222
// reduce copies for consecutive bests
12301223
L[ii] += x[ii] >= 0.0f ? -1 : 1;
12311224
} else {
1232-
for (int j = 0; j < n; ++j) {
1233-
L[j] = Laux[j];
1234-
}
1225+
memcpy(L, Laux, n);
12351226
}
12361227
best_n_i = i;
12371228
}
1238-
++i;
12391229
}
12401230
}
12411231

0 commit comments

Comments
 (0)