@@ -550,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -1535,9 +1547,8 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-//static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
@@ -1552,8 +1563,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
     },
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
@@ -2170,189 +2181,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_1;
-
-    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
-
-    float sumf = 0.0;
-
-#if defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
-
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
-
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
-
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
-
-        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
-
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
-
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
-
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
-
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
-
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
-#else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
-
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
-
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
-
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
-
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
-
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
-    }
-
-    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const float m0 = x[i].m;
-        const float m1 = y[i].m;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
-
-        for (int j = 0; j < QK4_1/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
-
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4) + m0;
-
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4) + m1;
-
-            sumf += f0*f2 + f1*f3;
-        }
-    }
-#endif
-
-    *s = sumf;
-}
-
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
@@ -2549,6 +2377,121 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+        // Note: cannot use vaddvq_s8 because it overflows for 8-bit values
+        // TODO: is there a better way to do this?
+        sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
+                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
+        sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
+                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+
+        sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
+        sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        sum10 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum11 += x1->d*y1->d*vaddvq_s16(p_1);
+#endif
+    }
+
+    sumf = sum00 + sum01 + sum10 + sum11;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4) + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
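
For reference, the per-block math implemented by the scalar fallback of the new ggml_vec_dot_q4_1_q8_0 can be written as a small standalone sketch. The block layouts below are assumptions inferred from the fields used in this diff (a scale d, an offset m and 32 packed 4-bit values for block_q4_1; a scale d and 32 int8 values for block_q8_0); they are illustrative stand-ins, not the exact ggml.c definitions, and the function name is hypothetical.

#include <stdint.h>

#define QK 32  // assumed block size, matching QK4_1 / QK8_0 in the diff

// Hypothetical stand-ins for the real ggml block structs.
typedef struct { float d; float m; uint8_t qs[QK / 2]; } blk_q4_1;
typedef struct { float d;          int8_t  qs[QK];     } blk_q8_0;

// Reference Q4_1 x Q8_0 dot product over nb blocks, mirroring the
// "// scalar" branch added in this commit: each packed 4-bit value v
// dequantizes to d0*v + m0, each 8-bit value to d1*q, and the low/high
// nibbles of p0[j] pair with the consecutive q8 values 2*j and 2*j+1.
static float dot_q4_1_q8_0_ref(const blk_q4_1 *x, const blk_q8_0 *y, int nb) {
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        for (int j = 0; j < QK/2; j++) {
            const uint8_t v0 = x[i].qs[j];
            const float f0 = x[i].d*(v0 & 0xf) + x[i].m;   // low nibble
            const float f1 = x[i].d*(v0 >> 4)  + x[i].m;   // high nibble
            sumf += f0*(y[i].d*y[i].qs[2*j + 0])
                  + f1*(y[i].d*y[i].qs[2*j + 1]);
        }
    }
    return sumf;
}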