@@ -700,7 +700,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q4_0_b16_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+void quantize_row_q4_0_b16_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -738,7 +738,7 @@ void quantize_row_q4_0_b16_reference(const float * restrict x, block_q4_0 * rest
 }
 
 void quantize_row_q4_0_b16(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_0_b16_reference(x, y, k);
+    quantize_row_q4_0_b16_ref(x, y, k);
 }
 
 
@@ -1190,6 +1190,132 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 #endif
 }
 
+void quantize_row_q8_0_b16_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = x[i*QK8_0 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = (GGML_FP32_TO_BF16(d)).bits;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = x[i*QK8_0 + j]*id;
+
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
+void quantize_row_q8_0_b16(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+
+        y[i].d = (GGML_FP32_TO_BF16(d)).bits;
+
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_b16_ref(x, y, k);
+#endif
+}
+
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
     assert(QK8_1 == 32);
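
The new quantize_row_q8_0_b16_ref above follows the usual Q8_0 recipe — per 32-element block, take the absolute maximum, derive the scale d = amax/127, and round each value to an int8 — with the difference that the scale is stored as raw bf16 bits via GGML_FP32_TO_BF16 rather than the fp16 scale of the regular Q8_0 path. The standalone sketch below reproduces that arithmetic for a single block so the effect of the bf16 scale can be checked outside ggml; the local fp32/bf16 helpers are assumptions (a plain round-to-nearest-even conversion), not ggml's exact macros, and none of this code is part of the patch.

// Standalone sketch: Q8_0-style quantization of one 32-element block with a
// bf16-truncated scale. fp32_to_bf16/bf16_to_fp32 are local stand-ins.
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define QK 32

static uint16_t fp32_to_bf16(float f) {
    uint32_t u; memcpy(&u, &f, sizeof u);
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16); // round to nearest even
}

static float bf16_to_fp32(uint16_t b) {
    uint32_t u = (uint32_t)b << 16;
    float f; memcpy(&f, &u, sizeof f);
    return f;
}

int main(void) {
    float x[QK];
    for (int j = 0; j < QK; j++) x[j] = sinf(0.37f*(float)j); // sample block

    float amax = 0.0f;
    for (int j = 0; j < QK; j++) amax = fmaxf(amax, fabsf(x[j]));

    const float d  = amax / 127.0f;                // block scale, as in the reference
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    const uint16_t d_bf16 = fp32_to_bf16(d);       // what y[i].d would hold
    int8_t qs[QK];
    for (int j = 0; j < QK; j++) qs[j] = (int8_t)roundf(x[j]*id);

    // dequantize with the bf16 scale and report the worst-case error
    float max_err = 0.0f;
    for (int j = 0; j < QK; j++) {
        const float rec = bf16_to_fp32(d_bf16)*(float)qs[j];
        max_err = fmaxf(max_err, fabsf(rec - x[j]));
    }
    printf("bf16 scale: %g, max abs reconstruction error: %g\n",
           bf16_to_fp32(d_bf16), max_err);
    return 0;
}
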
@@ -3217,7 +3343,7 @@ static void quantize_row_q4_0_b16_impl(const float * restrict x, block_q4_0 * re
     static_assert(QK4_0 == 32, "QK4_0 must be 32");
 
     if (!quant_weights) {
-        quantize_row_q4_0_b16_reference(x, y, n_per_row);
+        quantize_row_q4_0_b16_ref(x, y, n_per_row);
         return;
     }
 
@@ -3258,7 +3384,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nr
 
 size_t quantize_q4_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_0_b16_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_0_b16_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row);
@@ -3433,7 +3559,7 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
 size_t quantize_q8_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0_B16, n_per_row);
-    quantize_row_q8_0_b16_reference(src, dst, (int64_t)nrow*n_per_row);
+    quantize_row_q8_0_b16_ref(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }
 
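
For completeness, a hedged usage sketch of the row-level entry point changed above. It assumes the headers from this branch: GGML_TYPE_Q8_0_B16 and quantize_q8_0_b16 are introduced by this change (not upstream API), and the sketch assumes they are declared in ggml-quants.h; ggml_row_size comes from ggml.h.

// Hedged usage sketch (not from the patch): quantize a 2 x 64 matrix with
// the Q8_0_B16 row quantizer introduced in the hunks above.
#include <stdio.h>
#include <stdlib.h>
#include "ggml.h"
#include "ggml-quants.h"   // assumed location of the quantize_q8_0_b16 declaration

int main(void) {
    const int64_t nrow = 2, n_per_row = 64;   // n_per_row must be a multiple of QK8_0 (32)

    float * src = malloc(sizeof(float)*nrow*n_per_row);
    for (int64_t i = 0; i < nrow*n_per_row; i++) src[i] = (float)i/64.0f;

    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0_B16, n_per_row);
    void * dst = malloc(nrow*row_size);

    // quant_weights is ignored by quantize_q8_0_b16 (see the hunk above), so NULL is fine
    const size_t written = quantize_q8_0_b16(src, dst, nrow, n_per_row, NULL);
    printf("wrote %zu bytes (%zu per row)\n", written, row_size);

    free(dst);
    free(src);
    return 0;
}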