@@ -173,6 +173,132 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
173
173
return scale ;
174
174
}
175
175
176
+ static float make_qxg_quants (int n , int nmax , const float * restrict x , int8_t * restrict L , int rmse_type ,
177
+ const float * restrict qw ) {
178
+ float max = 0 ;
179
+ float amax = 0 ;
180
+ for (int i = 0 ; i < n ; ++ i ) {
181
+ float ax = fabsf (x [i ]);
182
+ if (ax > amax ) { amax = ax ; max = x [i ]; }
183
+ }
184
+ if (amax < GROUP_MAX_EPS ) { // all zero
185
+ for (int i = 0 ; i < n ; ++ i ) {
186
+ L [i ] = 0 ;
187
+ }
188
+ return 0.f ;
189
+ }
190
+ float iscale = - nmax / max ;
191
+ if (rmse_type == 0 ) {
192
+ for (int i = 0 ; i < n ; ++ i ) {
193
+ int l = nearest_int (iscale * x [i ]);
194
+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
195
+ }
196
+ return 1 /iscale ;
197
+ }
198
+ bool return_early = false;
199
+ if (rmse_type < 0 ) {
200
+ rmse_type = - rmse_type ;
201
+ return_early = true;
202
+ }
203
+ float sumlx = 0 ;
204
+ float suml2 = 0 ;
205
+ #ifdef HAVE_BUGGY_APPLE_LINKER
206
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
207
+ for (volatile int i = 0 ; i < n ; ++ i ) {
208
+ #else
209
+ for (int i = 0 ; i < n ; ++ i ) {
210
+ #endif
211
+ int l = nearest_int (iscale * x [i ]);
212
+ l = MAX (- nmax , MIN (nmax - 1 , l ));
213
+ L [i ] = l + nmax ;
214
+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
215
+ sumlx += w * x [i ]* l ;
216
+ suml2 += w * l * l ;
217
+ }
218
+ float scale = suml2 ? sumlx /suml2 : 0.0f ;
219
+ if (return_early ) return suml2 > 0 ? 0.5f * (scale + 1 /iscale ) : 1 /iscale ;
220
+ float best = scale * sumlx ;
221
+ float best_sumlx = sumlx , best_suml2 = suml2 ;
222
+ for (int is = -9 ; is <= 9 ; ++ is ) {
223
+ iscale = - (nmax + 0.1f * is ) / max ;
224
+ sumlx = suml2 = 0 ;
225
+ for (int i = 0 ; i < n ; ++ i ) {
226
+ int l = nearest_int (iscale * x [i ]);
227
+ l = MAX (- nmax , MIN (nmax - 1 , l ));
228
+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
229
+ sumlx += w * x [i ]* l ;
230
+ suml2 += w * l * l ;
231
+ }
232
+ if (suml2 > 0 && sumlx * sumlx > best * suml2 ) {
233
+ for (int i = 0 ; i < n ; ++ i ) {
234
+ int l = nearest_int (iscale * x [i ]);
235
+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
236
+ }
237
+ scale = sumlx /suml2 ; best = scale * sumlx ;
238
+ best_sumlx = sumlx ; best_suml2 = suml2 ;
239
+ }
240
+ iscale = (nmax - 1 + 0.1f * is ) / max ;
241
+ sumlx = suml2 = 0 ;
242
+ for (int i = 0 ; i < n ; ++ i ) {
243
+ int l = nearest_int (iscale * x [i ]);
244
+ l = MAX (- nmax , MIN (nmax - 1 , l ));
245
+ float w = qw ? qw [i ] : rmse_type == 1 ? x [i ] * x [i ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [i ]) : sqrtf (fabsf (x [i ]));
246
+ sumlx += w * x [i ]* l ;
247
+ suml2 += w * l * l ;
248
+ }
249
+ if (suml2 > 0 && sumlx * sumlx > best * suml2 ) {
250
+ for (int i = 0 ; i < n ; ++ i ) {
251
+ int l = nearest_int (iscale * x [i ]);
252
+ L [i ] = nmax + MAX (- nmax , MIN (nmax - 1 , l ));
253
+ }
254
+ scale = sumlx /suml2 ; best = scale * sumlx ;
255
+ best_sumlx = sumlx ; best_suml2 = suml2 ;
256
+ }
257
+ }
258
+
259
+ sumlx = best_sumlx ; suml2 = best_suml2 ;
260
+ for (int iter = 0 ; iter < n * (2 * nmax - 1 ); ++ iter ) {
261
+ float abs_gmax = 0 , gmax = 0 ;
262
+ int best_j = -1 ;
263
+ for (int j = 0 ; j < n ; ++ j ) {
264
+ float w = qw ? qw [j ] : rmse_type == 1 ? x [j ] * x [j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [j ]) : sqrtf (fabsf (x [j ]));
265
+ int l = L [j ] - nmax ;
266
+ float g = scale * w * (x [j ] - scale * l );
267
+ if ((g > 0 && l < nmax - 1 ) || (g < 0 && l > - nmax )) {
268
+ float ag = fabsf (g );
269
+ if (ag > abs_gmax ) {
270
+ abs_gmax = ag ; gmax = g ; best_j = j ;
271
+ }
272
+ }
273
+ }
274
+ if (best_j < 0 ) break ;
275
+
276
+ float new_sumlx = sumlx , new_suml2 = suml2 ;
277
+ float w = qw ? qw [best_j ] : rmse_type == 1 ? x [best_j ] * x [best_j ] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf (x [best_j ]) : sqrtf (fabsf (x [best_j ]));
278
+ int l = L [best_j ] - nmax ;
279
+ if (gmax > 0 ) {
280
+ new_sumlx += w * x [best_j ];
281
+ new_suml2 += w * (2 * l + 1 );
282
+ l += 1 ;
283
+ } else {
284
+ new_sumlx -= w * x [best_j ];
285
+ new_suml2 -= w * (2 * l - 1 );
286
+ l -= 1 ;
287
+ }
288
+ if (new_suml2 > 0 && new_sumlx * new_sumlx > best * new_suml2 ) {
289
+ sumlx = new_sumlx ; suml2 = new_suml2 ;
290
+ scale = sumlx /suml2 ; best = scale * sumlx ;
291
+ L [best_j ] = l + nmax ;
292
+ GGML_ASSERT (L [best_j ] >= 0 && L [best_j ] <= 2 * nmax - 1 );
293
+ }
294
+ else {
295
+ break ;
296
+ }
297
+
298
+ }
299
+ return scale ;
300
+ }
301
+
176
302
static float make_q3_quants (int n , int nmax , const float * restrict x , int8_t * restrict L , bool do_rmse ) {
177
303
float max = 0 ;
178
304
float amax = 0 ;
@@ -634,6 +760,194 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
634
760
}
635
761
}
636
762
763
// Lookup table for best_index_iq4nl(), indexed by (int)x - values[0].
// An entry e < 16 is the final index. An entry e >= 16 marks a slot where two neighbouring
// grid values are possible: the candidates are e - 16 and e - 15, and the caller picks the
// closer one.
static const int8_t iq4nl_index[241] = {
     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1, 17, 17,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 18,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3, 19,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4, 20,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
     5,  5, 21, 21,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, 22,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7, 23, 23,  8,  8,  8,  8,
     8,  8,  8,  8,  8,  8, 24,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
    13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
    14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
};

// Returns the index (0..15) into values[] selected for x via the iq4nl_index table.
//
// x below the table range maps to 0 and x above it maps to 15. NaN maps to 0.
// values must have at least 16 entries; the table is only meaningful for the grid it was
// generated for.
static inline int best_index_iq4nl(const int8_t * values, float x) {
    const int table_size = (int)(sizeof(iq4nl_index)/sizeof(iq4nl_index[0]));

    // Converting a float to int is undefined when the value does not fit (huge magnitudes,
    // infinities, NaN), so reject those before the cast. values[0] is an int8_t, hence any x
    // outside (INT8_MIN - 1, INT8_MAX + table_size) is outside the table whatever values[0] is,
    // and the results below match what the integer range check would produce.
    if (!(x > (float)(INT8_MIN - 1))) {
        return 0;   // also taken for NaN
    }
    if (x >= (float)(INT8_MAX + table_size)) {
        return 15;
    }

    int ix = (int)x - values[0];
    if (ix < 0) {
        return 0;
    }
    if (ix >= table_size) {
        return 15;
    }

    ix = iq4nl_index[ix];
    if (ix < 16) {
        return ix;
    }
    // ambiguous slot: choose the nearer of the two adjacent grid values
    const int lo = ix - 16;
    const int hi = ix - 15;
    return x - values[lo] < values[hi] - x ? lo : hi;
}
779
+
780
+ static void quantize_row_iq4_nl_g_impl (const int super_block_size , const int block_size , const float * restrict x ,
781
+ ggml_fp16_t * dh , uint8_t * q4 , uint16_t * scales_h , uint8_t * scales_l ,
782
+ float * scales , float * weight , uint8_t * L ,
783
+ const int8_t * values ,
784
+ const float * quant_weights ,
785
+ const int ntry ) {
786
+
787
+ float sigma2 = 0 ;
788
+ for (int j = 0 ; j < super_block_size ; ++ j ) sigma2 += x [j ]* x [j ];
789
+ sigma2 *= 2.f /super_block_size ;
790
+
791
+ memset (q4 , 0 , super_block_size /2 );
792
+ dh [0 ] = GGML_FP32_TO_FP16 (0.f );
793
+
794
+ float max_scale = 0 , amax_scale = 0 ;
795
+ for (int ib = 0 ; ib < super_block_size /block_size ; ++ ib ) {
796
+ const float * xb = x + ib * block_size ;
797
+ uint8_t * Lb = L + ib * block_size ;
798
+ if (quant_weights ) {
799
+ const float * qw = quant_weights + ib * block_size ;
800
+ for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = qw [j ] * sqrtf (sigma2 + xb [j ]* xb [j ]);
801
+ } else {
802
+ for (int j = 0 ; j < block_size ; ++ j ) weight [j ] = xb [j ]* xb [j ];
803
+ }
804
+ float amax = 0 , max = 0 ;
805
+ for (int j = 0 ; j < block_size ; ++ j ) {
806
+ float ax = fabsf (xb [j ]);
807
+ if (ax > amax ) {
808
+ amax = ax ; max = xb [j ];
809
+ }
810
+ }
811
+ if (amax < GROUP_MAX_EPS ) {
812
+ scales [ib ] = 0 ;
813
+ continue ;
814
+ }
815
+ float d = ntry > 0 ? - max /values [0 ] : max /values [0 ];
816
+ float id = 1 /d ;
817
+ float sumqx = 0 , sumq2 = 0 ;
818
+ for (int j = 0 ; j < block_size ; ++ j ) {
819
+ float al = id * xb [j ];
820
+ int l = best_index_iq4nl (values , al );
821
+ Lb [j ] = l ;
822
+ float q = values [l ];
823
+ float w = weight [j ];
824
+ sumqx += w * q * xb [j ];
825
+ sumq2 += w * q * q ;
826
+ }
827
+ d = sumqx /sumq2 ;
828
+ float best = d * sumqx ;
829
+ float best_sumqx = sumqx , best_sumq2 = sumq2 ;
830
+ for (int itry = - ntry ; itry <= ntry ; ++ itry ) {
831
+ id = (itry + values [0 ])/max ;
832
+ sumqx = sumq2 = 0 ;
833
+ for (int j = 0 ; j < block_size ; ++ j ) {
834
+ float al = id * xb [j ];
835
+ int l = best_index_iq4nl (values , al );
836
+ float q = values [l ];
837
+ float w = weight [j ];
838
+ sumqx += w * q * xb [j ];
839
+ sumq2 += w * q * q ;
840
+ }
841
+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
842
+ d = sumqx /sumq2 ; best = d * sumqx ;
843
+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
844
+ for (int j = 0 ; j < block_size ; ++ j ) {
845
+ float al = id * xb [j ];
846
+ Lb [j ] = best_index_iq4nl (values , al );
847
+ }
848
+ }
849
+ id = (itry + values [15 ])/max ;
850
+ sumqx = sumq2 = 0 ;
851
+ for (int j = 0 ; j < block_size ; ++ j ) {
852
+ float al = id * xb [j ];
853
+ int l = best_index_iq4nl (values , al );
854
+ float q = values [l ];
855
+ float w = weight [j ];
856
+ sumqx += w * q * xb [j ];
857
+ sumq2 += w * q * q ;
858
+ }
859
+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
860
+ d = sumqx /sumq2 ; best = d * sumqx ;
861
+ best_sumqx = sumqx ; best_sumq2 = sumq2 ;
862
+ for (int j = 0 ; j < block_size ; ++ j ) {
863
+ float al = id * xb [j ];
864
+ Lb [j ] = best_index_iq4nl (values , al );
865
+ }
866
+ }
867
+ }
868
+ sumqx = best_sumqx ; sumq2 = best_sumq2 ;
869
+ for (int iter = 0 ; iter < 32 * block_size ; ++ iter ) {
870
+ float min_step = INFINITY ;
871
+ int best_j = -1 ; int dir = 0 ;
872
+ for (int j = 0 ; j < block_size ; ++ j ) {
873
+ float w = weight [j ];
874
+ float g = d * w * (xb [j ] - d * values [Lb [j ]]);
875
+ if (g > 0 && Lb [j ] < 15 ) {
876
+ float step = (values [Lb [j ]+ 1 ] - values [Lb [j ]])/g ;
877
+ if (step < min_step ) {
878
+ min_step = step ; best_j = j ; dir = 1 ;
879
+ }
880
+ }
881
+ else if (g < 0 && Lb [j ] > 0 ) {
882
+ float step = (values [Lb [j ]- 1 ] - values [Lb [j ]])/g ;
883
+ if (step < min_step ) {
884
+ min_step = step ; best_j = j ; dir = -1 ;
885
+ }
886
+ }
887
+ }
888
+ if (best_j < 0 ) break ;
889
+
890
+ float new_sumqx = sumqx , new_sumq2 = sumq2 ;
891
+ float w = weight [best_j ];
892
+ new_sumqx += w * xb [best_j ]* (values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]);
893
+ new_sumq2 += w * (values [Lb [best_j ]+ dir ]* values [Lb [best_j ]+ dir ] - values [Lb [best_j ]]* values [Lb [best_j ]]);
894
+ if (new_sumq2 > 0 && new_sumqx * new_sumqx > best * new_sumq2 ) {
895
+ sumqx = new_sumqx ; sumq2 = new_sumq2 ;
896
+ d = sumqx /sumq2 ; best = d * sumqx ;
897
+ Lb [best_j ] += dir ;
898
+ }
899
+ else {
900
+ break ;
901
+ }
902
+ }
903
+
904
+ scales [ib ] = d ;
905
+ float abs_d = fabsf (d );
906
+ if (abs_d > amax_scale ) {
907
+ amax_scale = abs_d ; max_scale = d ;
908
+ }
909
+ }
910
+
911
+ if (super_block_size /block_size > 1 ) {
912
+ int nb = super_block_size /block_size ;
913
+ memset (scales_h , 0 , ((nb + 7 )/8 )* sizeof (uint16_t ));
914
+ float d = - max_scale /32 ;
915
+ dh [0 ] = GGML_FP32_TO_FP16 (d );
916
+ float id = d ? 1 /d : 0.f ;
917
+ for (int ib = 0 ; ib < super_block_size /block_size ; ++ ib ) {
918
+ int l = nearest_int (id * scales [ib ]);
919
+ l = MAX (-32 , MIN (31 , l ));
920
+ float dl = d * l ;
921
+ float idl = dl ? 1 /dl : 0.f ;
922
+ uint8_t * Lb = L + ib * block_size ;
923
+ const float * xb = x + ib * block_size ;
924
+ for (int j = 0 ; j < block_size ; ++ j ) {
925
+ Lb [j ] = best_index_iq4nl (values , idl * xb [j ]);
926
+ }
927
+ l += 32 ;
928
+ uint8_t l_l = l & 0xf ;
929
+ uint8_t l_h = l >> 4 ;
930
+ if (ib %2 == 0 ) scales_l [ib /2 ] = l_l ;
931
+ else scales_l [ib /2 ] |= (l_l << 4 );
932
+ scales_h [ib /8 ] |= (l_h << 2 * (ib %8 ));
933
+ }
934
+ } else {
935
+ dh [0 ] = GGML_FP32_TO_FP16 (scales [0 ]);
936
+ if (ntry > 0 ) {
937
+ float id = scales [0 ] ? 1 /scales [0 ] : 0 ;
938
+ for (int j = 0 ; j < super_block_size ; ++ j ) {
939
+ L [j ] = best_index_iq4nl (values , id * x [j ]);
940
+ }
941
+ }
942
+ }
943
+
944
+ for (int i = 0 ; i < super_block_size /32 ; ++ i ) {
945
+ for (int j = 0 ; j < 16 ; ++ j ) {
946
+ q4 [16 * i + j ] = L [32 * i + j ] | (L [32 * i + 16 + j ] << 4 );
947
+ }
948
+ }
949
+ }
950
+
637
951
// ---- Custom experiments ----
638
952
639
953
struct fraction {
@@ -2636,6 +2950,16 @@ void anyrize_qx(const float * x, const float * w, float * v, int ne0, int ne1, i
2636
2950
}
2637
2951
}
2638
2952
2953
// Quantizes each of the ne1 rows of x (ne0 floats per row) with make_qxg_quants() using
// rmse_type 1 and writes the dequantized values to v. w, when non-NULL, supplies per-element
// weights laid out like x.
void anyrize_qxg(const float * x, const float * w, float * v, int ne0, int ne1, int nmax) {
    int8_t L[ne0];
    for (int row = 0; row < ne1; ++row) {
        const float * x_row = x + ne0 * row;
        const float * w_row = w ? w + row * ne0 : NULL;
        float * v_row = v + row * ne0;

        const float scale = make_qxg_quants(ne0, nmax, x_row, L, 1, w_row);

        // stored levels carry an offset of +nmax
        for (int col = 0; col < ne0; ++col) {
            v_row[col] = (L[col] - nmax) * scale;
        }
    }
}
2962
+
2639
2963
void anyrize_qkxs (const float * x , const float * w , float * v , int ne0 , int ne1 , int nmin , int nmax , bool signed_scale ) {
2640
2964
struct fraction Faux [ne0 * MAX (abs (nmin ), abs (nmax ))];
2641
2965
int8_t L [ne0 ];
@@ -2832,6 +3156,23 @@ void anyrize_iq4nl(const float * x, const float * w, float * v, int ne0, int ne1
2832
3156
}
2833
3157
}
2834
3158
3159
+ void anyrize_iq4nl_g (const float * x , const float * w , float * v , int ne0 , int ne1 ) {
3160
+ uint8_t L [ne0 ];
3161
+ uint8_t Laux [ne0 ];
3162
+ ggml_fp16_t unused_dh ;
3163
+ uint8_t unused_q4 [ne0 ];
3164
+ uint16_t unused_h ;
3165
+ uint8_t * unused_l = NULL ;
3166
+ float weight [ne0 ];
3167
+ for (int i = 0 ; i < ne1 ; ++ i ) {
3168
+ float scale = 0.0f ;
3169
+ quantize_row_iq4_nl_g_impl (ne0 , ne0 , x + i * ne0 , & unused_dh , unused_q4 , & unused_h , unused_l , & scale , weight , L , kvalues_iq4nl , w ? w + i * ne0 : NULL , 7 );
3170
+ for (int j = 0 ; j < ne0 ; ++ j ) {
3171
+ v [i * ne0 + j ] = kvalues_iq4nl [L [j ]] * scale ;
3172
+ }
3173
+ }
3174
+ }
3175
+
2835
3176
void anyrize_qkxs_iq4nl (const float * x , const float * w , float * v , int ne0 , int ne1 ) {
2836
3177
uint8_t L [ne0 ];
2837
3178
uint8_t Laux [ne0 ];
0 commit comments