
Commit e9c07f7

ggml : use 8-bit precision for Q4_1 intermediate results (ARM)
1 parent 7cd5c4a
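
This makes Q4_1 take the same mixed-precision path as Q4_0: instead of multiplying two Q4_1 blocks in float, the second operand is first quantized to Q8_0 (8-bit, scale only, no offset) and the inner loop runs on integer intermediates. A minimal scalar sketch of the algebra the new kernel relies on (the helper name and signature are illustrative, not part of this commit; QK4_1 and QK8_0 are both 32 here, so the blocks pair up one-to-one):

    #include <stdint.h>

    // Illustrative reference, not code from this commit.
    // Q4_1: x_i = d0*q_i + m0, with unpacked nibbles q_i in [0, 15]
    // Q8_0: y_i = d1*s_i,      with s_i in [-128, 127]
    //
    //   sum_i x_i*y_i = d0*d1 * sum_i q_i*s_i  +  m0*d1 * sum_i s_i
    //
    // so one integer dot product and one integer sum over y suffice.
    static float ref_dot_q4_1_q8_0_block(float d0, float m0, const uint8_t q[32],
                                         float d1, const int8_t s[32]) {
        int32_t sumqs = 0; // scaled by d0*d1
        int32_t sums  = 0; // scaled by m0*d1
        for (int i = 0; i < 32; ++i) {
            sumqs += q[i]*s[i];
            sums  += s[i];
        }
        return (d0*d1)*sumqs + (m0*d1)*sums;
    }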

File tree: 1 file changed

ggml.c (130 additions, 187 deletions)
@@ -550,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
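
Note on the new vaddvq_s8 helper: like the surrounding vaddvq_u8 and vaddvq_s16 shims, it presumably covers NEON targets that lack the AArch64 horizontal-add intrinsics. Each lane is cast to int16_t before summing, so the 16-lane reduction is exact (|sum| <= 16*128 = 2048, well inside int16_t); the "cannot use vaddvq_s8" comment in the kernel further down refers to the native AArch64 intrinsic, whose int8_t result would wrap. A quick sanity-check sketch, assuming this shim is the definition in scope:

    #include <arm_neon.h>
    #include <assert.h>

    // Sanity-check sketch, not from the commit.
    void check_vaddvq_s8(void) {
        const int8x16_t v = vdupq_n_s8(-128);
        assert(vaddvq_s8(v) == -2048); // exact thanks to the int16_t widening
    }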
@@ -1535,9 +1547,8 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-//static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
@@ -1552,8 +1563,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
     },
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
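
With these two table entries, a Q4_1 mat-mul now quantizes the activation row with quantize_row_q8_0 and reduces every weight row with the mixed kernel. Roughly how the table is consumed (a simplified sketch, not ggml's actual mat-mul code; the variable names are illustrative):

    // Illustrative sketch, not ggml's actual mat-mul loop.
    // The weight type selects both the row quantizer and the dot kernel:
    const quantize_fns_t fns = quantize_fns[GGML_TYPE_Q4_1];

    // quantize the float activation row once, into the "dot" format (now Q8_0)
    fns.quantize_row_q_dot(src1_row_f32, src1_row_q8, ne00);

    // then reduce each Q4_1 weight row against it
    for (int ir = 0; ir < nrows; ++ir) {
        fns.vec_dot_q(ne00, &dst_col[ir], wrow_q4_1[ir], src1_row_q8);
    }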
@@ -2170,189 +2181,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_1;
-
-    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
-
-    float sumf = 0.0;
-
-#if defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
-
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
-
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
-
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
-
-        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
-
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
-
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
-
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
-
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
-
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
-#else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
-
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
-
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
-
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
-
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
-
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
-    }
-
-    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const float m0 = x[i].m;
-        const float m1 = y[i].m;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
-
-        for (int j = 0; j < QK4_1/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
-
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4) + m0;
-
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4) + m1;
-
-            sumf += f0*f2 + f1*f3;
-        }
-    }
-#endif
-
-    *s = sumf;
-}
-
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
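
For reference, the kernel deleted above multiplied two offset formats, so each product expanded into four terms:

    (d0*q + m0)*(d1*s + m1) = d0*d1*q*s + d0*m1*q + m0*d1*s + m0*m1

which is exactly the old sum11 / sum01 / sum10 / QK4_1*sum00 split, and the reason for the cross_scales and acc_offset machinery in the AVX2 branch. With Q8_0 on one side there is no m1, so both terms involving it vanish and the replacement kernel below only needs the integer dot product plus an integer sum over y.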
@@ -2549,6 +2377,121 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+        // Note: cannot use vaddvq_s8 because it overflows for 8-bit values
+        // TODO: is there a better way to do this?
+        sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) +
+                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs))));
+        sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) +
+                                vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs))));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+
+        sum10 += (x0->d*y0->d)*vaddvq_s32(p_0);
+        sum11 += (x1->d*y1->d)*vaddvq_s32(p_1);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        sum10 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum11 += x1->d*y1->d*vaddvq_s16(p_1);
+#endif
+    }
+
+    sumf = sum00 + sum01 + sum10 + sum11;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4) + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
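
One subtlety in the new NEON path is the "interleave" step: in this block layout a Q4_1 byte qs[j] packs element 2*j in its low nibble and element 2*j+1 in its high nibble (see the scalar fallback), so vandq/vshrq produce the even- and odd-indexed elements, while Q8_0 stores its 32 values contiguously. The vuzp1q_s8/vuzp2q_s8 pair therefore de-interleaves y into even and odd elements so the lanes line up. A scalar sketch of the same pairing, mirroring the kernel's own fallback (the helper name is illustrative):

    #include <stdint.h>

    // Illustrative helper, not from the commit.
    // low nibble of p0[j]  pairs with y element 2*j   (the vuzp1q half)
    // high nibble of p0[j] pairs with y element 2*j+1 (the vuzp2q half)
    static int32_t dot_q4_1_q8_0_pairs(const uint8_t p0[16], const int8_t p1[32]) {
        int32_t sum = 0;
        for (int j = 0; j < 16; ++j) {
            sum += (p0[j] & 0xf) * p1[2*j + 0];
            sum += (p0[j] >>  4) * p1[2*j + 1];
        }
        return sum; // integer part only: scale by d0*d1 and add the m0*d1 term
    }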