diff --git a/ggml.c b/ggml.c index db8babbf71fcd..a081207c1a5ea 100644 --- a/ggml.c +++ b/ggml.c @@ -583,7 +583,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc typedef struct { float d; // delta - uint8_t qs[QK]; // nibbles / quants + int8_t qs[QK]; // quants } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding"); @@ -1060,9 +1060,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r for (int l = 0; l < QK; ++l) { const float v = x[i*QK + l]*id; - const uint8_t vi = (int8_t)roundf(v) + 128; - - y[i].qs[l] = vi; + y[i].qs[l] = roundf(v); } } } @@ -1095,8 +1093,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); - const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f)); - const int32x4_t vi = vcvtq_s32_f32(vf); + const int32x4_t vi = vcvtnq_s32_f32(v); y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0); y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); @@ -1104,6 +1101,90 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); } } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } #else // scalar quantize_row_q8_0_reference(x, y, k); @@ -2508,7 +2589,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const uint8x16_t m4b = vdupq_n_u8(0xf); const int8x16_t s8b = vdupq_n_s8(0x8); - const uint8x16_t u128b = vdupq_n_u8(128); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2526,21 +2606,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); // load y - const uint8x16_t v1_0l = vld1q_u8(y0->qs); - const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16); - const uint8x16_t v1_1l = vld1q_u8(y1->qs); - const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16); + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); // interleave - const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h); - const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h); - const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h); - const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h); - - const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b)); - const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b)); - const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b)); - const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b)); + const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h); + const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h); + const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); + const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t @@ -2578,6 +2653,94 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * } sumf = sum0 + sum1; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + + __m256i bx = bytesFromNibbles(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(bx, bx); + + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + + const __m256i ones = _mm256_set1_epi16(1); + __m256i xy_q = _mm256_madd_epi16(ones, dot); + + /* Convert to vectore of 8 int32_t to 8 floats */ + __m256 q = _mm256_cvtepi32_ps( xy_q ); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + + __m128i i32[2]; + for (int j = 0; j < 2; ++j) { + // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes + __m128i bx = bytesFromNibbles( x[i].qs + 8*j ); + __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j)); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m128i off = _mm_set1_epi8( 8 ); + bx = _mm_sub_epi8( bx, off ); + + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(bx, bx); + + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + + const __m128i ones = _mm_set1_epi16(1); + i32[j] = _mm_madd_epi16(ones, dot); + } + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] )); + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); #else // scalar for (int i = 0; i < nb; i++) { @@ -2585,7 +2748,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const float d1 = y[i].d; const uint8_t * restrict p0 = x[i].qs; - const uint8_t * restrict p1 = y[i].qs; + const int8_t * restrict p1 = y[i].qs; int sumi = 0; for (int j = 0; j < QK/2; j++) { @@ -2594,10 +2757,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int i0 = (int8_t) (v0 & 0xf) - 8; const int i1 = (int8_t) (v0 >> 4) - 8; - const int i2 = (int) p1[2*j + 0] - 128; - const int i3 = (int) p1[2*j + 1] - 128; - - /*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/ + const int i2 = p1[2*j + 0]; + const int i3 = p1[2*j + 1]; sumi += i0*i2 + i1*i3; } @@ -9923,7 +10084,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); } else #endif - cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + { + cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + } } else { GGML_ASSERT(false); }