diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index c6bf1b72362bc..5f8c43944dc7c 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -25,6 +25,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
     {"q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S},
     {"q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M},
     {"q6_K", LLAMA_FTYPE_MOSTLY_Q6_K},
+    {"qx_0", LLAMA_FTYPE_MOSTLY_QX_0},
 };
 
 bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
diff --git a/ggml.c b/ggml.c
index a13de511527bc..2fff7152387da 100644
--- a/ggml.c
+++ b/ggml.c
@@ -488,6 +488,44 @@ int64_t ggml_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
+//
+// bit manipulation helpers
+//
+
+// writes "bit_count" bits of "data" at bit offset "bit_offset" in "dst"
+// only used for data <= 16 bits; only useful to quantize_qx_0
+inline static void write_bits(uint32_t * dst, uint32_t bit_offset, uint16_t data, uint16_t bit_count) {
+    const uint32_t chunk_size = (sizeof(uint32_t) * 8);
+    const uint32_t chunk_id = bit_offset / chunk_size;
+
+    dst = dst + chunk_id;
+    bit_offset %= (sizeof(uint32_t) * 8);
+
+    if (bit_offset + bit_count > chunk_size) {
+        // the write crosses a chunk boundary: first fill the current chunk
+        uint16_t bitcount_1 = chunk_size - bit_offset;
+
+        uint32_t bitmask = ((1 << bitcount_1) - 1) << (bit_offset);
+        *dst &= ~bitmask;
+        *dst |= data << bit_offset;
+
+        // move onto the next chunk
+        data >>= bitcount_1;
+
+        bit_count -= bitcount_1;
+        bit_offset = 0;
+        dst += 1;
+
+        bitmask = ((1 << bit_count) - 1) << (bit_offset);
+        *dst &= ~bitmask;
+        *dst |= data << bit_offset;
+    } else {
+        uint32_t bitmask = ((1 << bit_count) - 1) << (bit_offset);
+        *dst &= ~bitmask;
+        *dst |= data << bit_offset;
+    }
+}
+
 //
 // quantization
 //
@@ -835,6 +873,25 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
+
+// max block size is 256 because some feed_forward tensors have a width of 11008 weights, which is not divisible by 512
+#define QKX_0 256
+
+// There is no byte-exact C struct to represent a QX_0 block, but a high-level representation of a block is:
+//     ggml_fp16_t delta;
+//     ggml_fp16_t min;
+//     uint8_t block_metadata;
+//     [bitstream of weights]
+
+
+// Quantization parameters for QX_0 (used only when running ./quantize, irrelevant during inference)
+// Quantization starts at QX_0_STARTING_QBITS bits, and then moves down to QX_0_START_OF_ATTEMPTED_QBITS
+// and tries lower and lower bit precisions from there
+// TODO: maybe move these to command-line arguments...?
+#define QX_0_STARTING_QBITS 4 +#define QX_0_START_OF_ATTEMPTED_QBITS 2 + + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { static const int qk = QK4_0; @@ -1530,6 +1587,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in } } +static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -1627,6 +1685,16 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, }, #endif + // GGML_TYPE_QX_0's quantize/dequantize functions aren't the same as other quantization methods' functions + // so we need to supply NULL instead and use if statements in the places where they are actually used + [GGML_TYPE_QX_0] = { + .dequantize_row_q = (dequantize_row_q_t) NULL, + .quantize_row_q = NULL, + .quantize_row_q_reference = (quantize_row_q_t) NULL, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_qx_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, }; // For internal test use @@ -3122,6 +3190,197 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * #endif } +__attribute__((optimize("unroll-loops"))) +static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + uint32_t nb = n / QKX_0; + GGML_ASSERT(QKX_0 % QK8_0 == 0); + + *s = 0; + + const uint8_t * quant_row = (const uint8_t *) vx; + const block_q8_0 * restrict column = vy; + uint32_t column_i = 0; // current index in column + + // row_data is a buffer which stores dequantized float values for a current block + float f32_row_data[QKX_0]; + + // __AVX2__ doesn't seem to actually make much of a difference, + // a lot of optimizing could possibly be done, including possibly using AVX2 + // for dequantization...? + + #if defined(__AVX2__) + __m256 rolling_sum = _mm256_setzero_ps(); + #endif + + float qvals[1 << 4]; + + for (uint32_t b = 0; b < nb; b++) { + float * row_ptr = f32_row_data; + + const uint64_t * block_start = (const uint64_t *) quant_row; + + const float min_value = GGML_FP16_TO_FP32(*((const uint16_t *) (block_start + (QKX_0 / 64)))); + float mult_value = GGML_FP16_TO_FP32(*((const uint16_t *) (block_start + (QKX_0 / 64)) + 1)); + const uint16_t * data_start = (const uint16_t *) (block_start + (QKX_0 / 64)) + 2; + const uint8_t qbits = *((const uint8_t *) data_start); + data_start = (const uint16_t*) ((const uint8_t*) data_start + 1); + + quant_row = (const uint8_t * ) data_start; + + // Any qbits are supported, but the size of qvals needs to be changed to 1 << max_expected_qbits. + // So if you have at most 7bit values, you can change qvals's declaration to qvals[1 << 7]. + // Additionally, the "fp_chooser == 0" optimized branch only works if qbits is "3" or a power of 2, + // so feel free to disable it entirely and run the slower "else" statement which works for pretty much + // any qbit value. 
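// For reference, the header parse above can be summarized with the following sketch. This is
// illustrative only and not part of the patch ("qx0_block_header" / "qx0_parse_block_header" are
// hypothetical names): a QKX_0-weight block starts with QKX_0/64 indicator words (one bit per
// weight, 1 = stored as fp16), followed by the fp16 min value, the fp16 step between quantized
// points, one metadata byte holding qbits, and then the packed weight bitstream.
typedef struct {
    const uint64_t * fp16_indicators; // QKX_0 / 64 words
    float            min;             // decoded from ggml_fp16_t
    float            mult;            // decoded from ggml_fp16_t; spacing between quantized points
    uint8_t          qbits;           // metadata byte: bit width of the quantized weights
    const uint8_t  * bitstream;       // packed fp16 / qbit weights, byte-aligned per block
} qx0_block_header;

static inline qx0_block_header qx0_parse_block_header(const uint8_t * p) {
    qx0_block_header h;
    h.fp16_indicators = (const uint64_t *) p;            p += (QKX_0 / 64) * sizeof(uint64_t);
    h.min   = GGML_FP16_TO_FP32(*(const uint16_t *) p);  p += sizeof(uint16_t);
    h.mult  = GGML_FP16_TO_FP32(*(const uint16_t *) p);  p += sizeof(uint16_t);
    h.qbits = *p;                                        p += sizeof(uint8_t);
    h.bitstream = p;
    return h;
}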
+ GGML_ASSERT(qbits <= 4); + + uint32_t offset = 0; + uint8_t data_offset = 0; + + // Cache quantized values + for (int i = 0; i < (1 << qbits); i++) { + qvals[i] = min_value + mult_value * i; + } + + // Parse in sub-blocks of 64 since they are managed by a single uint64_t which decides if a given weight + // is on 16bit or quantized. This means that we can do a fast fp16_indicator == 0 check (i.e. all weights are quantized) + // to speed up peformance + for (int subblock_i = 0; subblock_i < QKX_0 / 64; subblock_i++) { + uint64_t fp16_indicator = block_start[subblock_i]; + + // all weights are quantized in this section; ALSO this ONLY works when qbits is <= 4, since (qbits != 3) simply checks if qbits is a power of 2 + if (fp16_indicator == 0) { + if (qbits == 3) { + // same principle as on the regular data_offset branch, but this time the qbits cross byte boundaries, so we need to manage it by hand + for (int i = 0; i < 5; i++) { + for (int k = 0; k < 11; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[i * 11 + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += 2; // this is the same event as in if (data_start >= 16), but happening twice + data_offset += 1; // it's actually +33, but the "+32" is represented in data_start above, so the remainder is simply +1 + } + + for (int k = 0; k < 9; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[55 + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += 1; + data_offset += 9 * 3 - 16; + + if (data_offset >= 16) { + data_start += 1; + data_offset -= 16; + } + + } else if (data_offset == 0) { + // This only properly works for QBits = power of 2 + const uint8_t data_block_size = 64; + // we can take a full 64bit block + const uint8_t weights_per_u64_data_block = data_block_size / qbits; + const uint8_t num_of_data_blocks_needed = 64 / weights_per_u64_data_block; // because we have 64 qbit-sized weights here + + for (int i = 0; i < num_of_data_blocks_needed; i++) { + for (int k = 0; k < weights_per_u64_data_block; k ++) { + row_ptr[i * weights_per_u64_data_block + k] = qvals[(((const uint64_t *) data_start)[0] >> (k * qbits)) & ((1 << qbits) - 1)]; + } + + data_start += (data_block_size / 8) / sizeof(uint16_t); + } + } else { + // We are doing u32 instead of a simple u64, since data_offset may not be 0 and we need to account for that + const uint8_t data_block_size = 32; + const uint8_t weights_per_u32_data_block = data_block_size / qbits; + const uint8_t num_of_data_blocks_needed = 64 / weights_per_u32_data_block; + + for (int i = 0; i < num_of_data_blocks_needed; i++) { + for (int k = 0; k < weights_per_u32_data_block; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[i * weights_per_u32_data_block + k] = qvals[((((const uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += (data_block_size / 8) / sizeof(uint16_t); + } + } + + offset += qbits * 64; + } else { + for (int i = 0; i < 64; i++) { + if (fp16_indicator & 1) { + // Current weight is fp16 + offset += 16; + row_ptr[i] = GGML_FP16_TO_FP32((((const uint32_t *) data_start)[0] >> data_offset) & ((1 << 16) - 1)); + + data_start += 1; + } else { + // Current weight is quantized + offset += qbits; + row_ptr[i] = 
qvals[((((const uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1))]; + + data_offset += qbits; + + if (data_offset >= 16) { + data_start += 1; + data_offset -= 16; + } + } + + // Shift the fp16 indicator to the right, to move to the next weight + fp16_indicator >>= 1; + } + } + + for (int jb = 0; jb < 64 / QK8_0; jb++) { + #if defined(__AVX2__) + __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_i].d)); + + for (int i = 0; i < QK8_0/8; i++) { + __m128i test = _mm_loadu_si128((const __m128i *) (column[column_i].qs + i * 8)); + __m256i work = _mm256_cvtepi8_epi32(test); + __m256 workf = _mm256_cvtepi32_ps(work); + + // multiply with our 8 parts of the row at row_data + __m256 row = _mm256_loadu_ps(row_ptr + jb * QK8_0 + i * 8); + + workf = _mm256_mul_ps(workf, row); + rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum); + } + + #else + // scalar + float sub_sum = 0; + + for (int i = 0; i < QK8_0; i++) { + sub_sum += row_ptr[jb * QK8_0 + i] * column[column_i].qs[i]; + } + + sub_sum *= GGML_FP16_TO_FP32(column[column_i].d); + *s += sub_sum; + + #endif + + column_i += 1; + } + + row_ptr += 64; + } + + GGML_ASSERT(offset % 8 == 0); + quant_row += offset / 8; + } + + #if defined(__AVX2__) + float rolling_sum_vec[8]; + _mm256_store_ps(rolling_sum_vec, rolling_sum); + + for (int i = 0; i < 8; i++) { + *s += rolling_sum_vec[i]; + } + #endif +} + static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int qk = QK8_0; const int nb = n / qk; @@ -3514,11 +3773,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q6_K] = QK_K, [GGML_TYPE_Q8_K] = QK_K, #endif + // [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed block size [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3537,11 +3797,12 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q6_K] = sizeof(block_q6_K), [GGML_TYPE_Q8_K] = sizeof(block_q8_K), #endif + // [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed type size [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3559,11 +3820,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_K] = "q5_K", [GGML_TYPE_Q6_K] = "q6_K", [GGML_TYPE_Q8_K] = "q8_K", + [GGML_TYPE_QX_0] = "qx_0", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3580,11 +3842,12 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_K] = true, [GGML_TYPE_Q6_K] = true, [GGML_TYPE_Q8_K] = true, + [GGML_TYPE_QX_0] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", @@ -3890,6 +4153,7 
@@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_QX_0: wtype = GGML_TYPE_QX_0; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -4266,7 +4530,14 @@ struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = GGML_TYPE_SIZE[type]; - result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + + if (type == GGML_TYPE_QX_0) { + // QX_0 doesn't have a set stride size for a row; that value is stored in the "extra" part of the tensor + result->nb[1] = 0; + } else { + result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + } + for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -7719,6 +7990,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -8027,6 +8299,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8154,6 +8427,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: default: { GGML_ASSERT(false); @@ -10189,12 +10463,21 @@ static void ggml_compute_forward_mul_mat_q_f32( const int i2 = i02; const int i3 = i03; - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); + void * src0_row; - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + if (type == GGML_TYPE_QX_0) { + if (ir > 0) { + src0_row = (void *) ((char *) src0->data + ((uint64_t *) src0->extra)[ir - 1]); + } else { + src0_row = (void *) ((char *) src0->data); + } + } else { + src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + assert(ne00 % 32 == 0); + } - assert(ne00 % 32 == 0); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); for (int64_t ic = 0; ic < ne11; ++ic) { vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); @@ -10231,6 +10514,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -10419,6 +10703,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: default: { GGML_ASSERT(false); @@ -10589,6 +10874,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11141,6 +11427,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: + case GGML_TYPE_QX_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -11218,6 +11505,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: + 
case GGML_TYPE_QX_0:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
@@ -16236,7 +16524,276 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK8_0*sizeof(block_q8_0));
 }
 
-size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) {
+    assert(n % QKX_0 == 0);
+    assert(tensor_width % QKX_0 == 0);
+    const int nb = n / QKX_0;
+
+    uint8_t * dst_8 = dst;
+    uint64_t dst_offset = 0;
+
+    // define max quantization errors for every bit precision,
+    // i.e. max_quantization_errors[1] holds the max error for 1bit quantized weights,
+    //      max_quantization_errors[2] holds the max error for 2bit quantized weights,
+    //      max_quantization_errors[3] holds the max error for 3bit quantized weights,
+    //      etc.
+    //
+    // max quantization error here means that every single quantized weight is within
+    // said value (e.g. 0.004) of its original value
+    //
+    // this can be replaced with a max allowed RMSE, a set percentage of weights being within
+    // a certain range, etc... The current implementation here is pretty much just an example
+    float max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004};
+
+
+    // How the maximum quantization error is enforced here:
+    //
+    // Each block holds both fp16 and "qbit" quantized weights mixed together arbitrarily.
+    // This mixing is described by a few 64-bit indicator words at the start of each block; each bit
+    // indicates whether the weight corresponding to that bit is stored on 16bit or is quantized.
+    //
+    // There is a metadata byte which indicates the qbit precision of the current block. Its values
+    // are in [1,2,3,4], but this can easily be extended to allow any other bit precisions,
+    // such as 5, 6, 9, 13 bits or anything else.
+    //
+    // To guarantee that each weight is within max_quantization_error, we first need to look at what range
+    // of values this allows us to have. Since we have "qbits" bits, we have (1 << qbits) possible values
+    // the quantized weights can take. The maximum distance between two adjacent quantized points is
+    // "2 * max_quantization_error", since any weight situated between those two points is then
+    // <= max_quantization_error away from its closest point.
+    //
+    // A visual 2bit example would be: -->|<---->|<---->|<---->|<--
+    // where "|" are the quantized points, and "-->" represents max_quantization_error on the number line.
+    //
+    // Any value outside this range has to be kept on 16bit, since it cannot be within max_quantization_error
+    // of its quantized point.
+    //
+    //
+    // Note: Each block is kept byte-aligned for simplicity, which means that the number of 16bit weights and
+    // qbit weights in the bitstream has to be balanced such that the total number of bits is divisible by 8.
+    // e.g. if we have 3 4bit values and 253 16bit values, we need to revert a 4bit value to 16bit in order
+    // to keep the total number of bits divisible by 8. If we were to quantize a weight instead, we would lose
+    // the "max_quantization_error" guarantee. Byte alignment isn't actually required per block, though; it only
+    // has to hold for each row, so a big potential improvement could be made here, since we end up with quite
+    // a few unnecessary 16bit weights.
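// The range arithmetic described above, as a small illustrative sketch (not part of the patch;
// the helper names are hypothetical). Quantized points are spaced 2 * max_err apart around a
// center value (0 for the initial pass below, the block mean for the re-quantization passes),
// and anything farther than max_err * (1 << qbits) from the center has to stay fp16.
static inline float qx0_grid_min(float center, float max_err, int qbits) {
    return center - max_err * ((1 << qbits) - 1);                      // lowest quantized point
}
static inline float qx0_grid_point(float center, float max_err, int qbits, int q) {
    return qx0_grid_min(center, max_err, qbits) + 2.0f * max_err * q;  // q in [0, (1 << qbits) - 1]
}
static inline bool qx0_is_outlier(float x, float center, float max_err, int qbits) {
    const float thresh = max_err * (1 << qbits);                       // same "thresh" as in the code below
    return x < center - thresh || x > center + thresh;
}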
+ + for (int i = 0; i < nb; i++) { + // each 64bit value holds binary data of whether the current weight (corresponding to a specific bit) + // is stored on 16bit or is quantized. "QKX_0 / 64" is here since we need multiple 64bit numbers if + // the QX_0 block is larger than 64 weights. + uint64_t fp16_indicators[QKX_0 / 64]; + memset(fp16_indicators, 0, sizeof(uint64_t) * (QKX_0 / 64)); + + uint8_t qbits = QX_0_STARTING_QBITS; + float thresh = max_quantization_errors[qbits] * (1 << qbits); + + int fp16_count = 0; + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + if (fabsf(x) > thresh) { + // store this value on 16bits + fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64); + fp16_count += 1; + } + } + + uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; + + while ((total_bits % 8) != 0) { + total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one (needed for a block's byte alignment) + } + + float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1)); + float mult_range = 2 * max_quantization_errors[qbits]; + + // The quantizer starts at a QX_0_STARTING_QBITS quantized block (e.g. 4bits), but then + // attempts to move to a lower precision defined by QX_0_START_OF_ATTEMPTED_QBITS. + // It keeps looking to see if 3, 2 or 1 bit precision leads to a smaller file size. + // + // The decrease in precision does not always lead to a smaller file when we need to maintain + // a fixed max quantization error, since lower bits mean a smaller value range, which might lead + // to more values being moved to 16bits, which might in the end actually increase our block's size. + // + // If values are very close to the mean, then a lower precision is more advantageous since we don't + // need a large quantization range, but otherwise it's likely more beneficial to stay at a higher precision. + // The loop below calculates this ideal trade-off for us! 
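// A sketch of the size model that the search below minimizes (illustrative only, not part of the
// patch; "qx0_block_bits" is a hypothetical name): a block costs 16 bits per fp16 weight plus qbits
// per quantized weight, padded by promoting quantized weights to fp16 until the total is a multiple of 8.
static inline uint32_t qx0_block_bits(uint32_t n_fp16, uint32_t qbits) {
    uint32_t bits = n_fp16 * 16 + (QKX_0 - n_fp16) * qbits;
    while (bits % 8 != 0) {
        bits += 16 - qbits; // one quantized weight becomes an fp16 weight
    }
    return bits;
}
// e.g. with qbits = 2 and 3 fp16 outliers: 3*16 + 253*2 = 554 bits, padded to 568 bits (71 bytes).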
+ + for (uint8_t test_qbit = QX_0_START_OF_ATTEMPTED_QBITS; test_qbit >= 1; test_qbit--) { + // calculate the mean of non-fp16 values and define that as the center of the quantization range + float mean = 0; + for (int j = 0; j < QKX_0; j++) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x_fp32 = src[i * QKX_0 + j]; + mean += x_fp32; + } + } + mean /= (QKX_0 - fp16_count); + + uint16_t total_fp16s_in_test_qbit = 0; + thresh = max_quantization_errors[test_qbit] * (1 << test_qbit); + + for (int j = 0; j < QKX_0; j++) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x = src[i * QKX_0 + j]; + + // new outlier found for our current qbit + if (x < mean - thresh || x > mean + thresh) { + total_fp16s_in_test_qbit += 1; + } + } else { + total_fp16s_in_test_qbit += 1; + } + } + + uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit); + while ((total_bits_in_test_qbit % 8) != 0) { + total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one + } + + if (total_bits_in_test_qbit < total_bits) { + total_bits = total_bits_in_test_qbit; + qbits = test_qbit; + + min_value = mean - (max_quantization_errors[test_qbit] * ((1 << qbits) - 1)); + mult_range = 2 * max_quantization_errors[test_qbit]; + + for (int j = 0; j < QKX_0; j++) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x = src[i * QKX_0 + j]; + + // mark outlier as stored on 16bit + if (x < mean - thresh || x > mean + thresh) { + fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64); + fp16_count += 1; + } + } + } + + uint16_t total_test_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; + while ((total_test_bits % 8) != 0) { + total_test_bits += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one + } + + GGML_ASSERT(total_bits == total_test_bits); + } + } + + // keep converting the largest qbit values to fp16 until the block is byte-aligned + while (((QKX_0 - fp16_count) * qbits) % 8 != 0) { + float maxi = 0; + int target = -1; + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + // weight is not on 16bit + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float diff = fabsf(x); + if (diff > maxi || target == -1) { + maxi = diff; + target = j; + } + } + } + + GGML_ASSERT(target != -1); + fp16_indicators[target / 64] |= (uint64_t) 1 << (target % 64); + fp16_count += 1; + } + + // store the current byte-offset of the current row, if "i" indicates that this is the first + // block of a row + if (((i * QKX_0) % tensor_width == 0) && i != 0) { + uint32_t row = (i * QKX_0) / tensor_width; + extra_data[row - 1] = dst_offset; + } + + // write the fp16 indicators to dst + uint64_t * stored_fp16_indicators = (uint64_t *) (dst_8 + dst_offset); + + for (int j = 0; j < QKX_0 / 64; j++) { + stored_fp16_indicators[j] = fp16_indicators[j]; + } + + dst_offset += (QKX_0 / 64) * sizeof(uint64_t); + + // Each weight is stored as min_value + mult * quantized_weight + // Similar to Zero-point quantization, or Q4_1 + + // Write min value and multiplier to dst + *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value); + dst_offset += sizeof(uint16_t); + + *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range); + dst_offset += sizeof(uint16_t); + + // Store the "metadata" byte (for now it's just "qbits") + *((uint8_t*) (dst_8 + dst_offset)) = qbits; + dst_offset += sizeof(uint8_t); + + + 
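// Quick illustration (not part of the patch; "qx0_write_bits_example" is a name invented here) of
// the write_bits helper used just below: pack a 3-bit quantized value and a 16-bit value back to
// back, then read them back with plain shifts, the same way the dequantization path walks the stream.
static void qx0_write_bits_example(void) {
    uint32_t buf[1] = {0};
    write_bits(buf, 0, 0x5,    3);  // bits [0, 3):  a 3-bit quantized weight
    write_bits(buf, 3, 0xBEEF, 16); // bits [3, 19): a raw fp16 weight
    GGML_ASSERT((buf[0] & 0x7) == 0x5);
    GGML_ASSERT(((buf[0] >> 3) & 0xFFFF) == 0xBEEF);
}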
// Store the quantization pivots / points + // IMPORTANT: Change qvals's size depending on the maximum qbits expected + GGML_ASSERT(qbits <= 8); + float qvals[1 << 8]; + + for (int j = 0; j < (1 << qbits); j++) { + qvals[j] = min_value + (mult_range * j); + } + + uint64_t bit_offset = 0; + uint32_t * data = (uint32_t*) (dst_8 + dst_offset); + + int fp16_count_chk = 0; + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + if (fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) { + ggml_fp16_t x_f16 = ggml_fp32_to_fp16(x); + + // store the full fp16 weight + write_bits(data, bit_offset, x_f16, 16); + bit_offset += 16; + fp16_count_chk += 1; + } else { + uint8_t q = 0; + float min_dist = fabsf(x - qvals[0]); + + // find closest quantization point + for (int iv = 0; iv < (1 << qbits); iv++) { + float dist = fabsf(x - qvals[iv]); + if (dist < min_dist) { + q = iv; + min_dist = dist; + } + } + + write_bits(data, bit_offset, q, qbits); + bit_offset += qbits; + } + } + + // check that the reported fp16_count is coherent with the bits stored in fp16_indicators + GGML_ASSERT(fp16_count == fp16_count_chk); + + // check that the number of bits from quantized values is divisible by 8 + GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0); + + dst_offset += ((QKX_0 - fp16_count) * qbits) / 8; + dst_offset += fp16_count * 2; + } + + // store the total size of the tensor as the last element of extra_data + extra_data[n / tensor_width - 1] = dst_offset; + + return dst_offset; +} + +// Pass in additional information such as the tensor's "extra_data" and width, since QX_0 needs this info. We can't pass in a pointer to +// a ggml_tensor (since none exists where quantize_chunk is created), nor to llama_load_tensor since ggml.c doesn't have access to the struct +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) { size_t result = 0; switch (type) { case GGML_TYPE_Q4_0: @@ -16301,6 +16858,10 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = ggml_quantize_q6_K(src + start, block, n, n, hist); } break; #endif + case GGML_TYPE_QX_0: + { + result = ggml_quantize_qx_0(src, dst, n, hist, extra_data, tensor_width); + } break; default: assert(false); } diff --git a/ggml.h b/ggml.h index 1b26da3adca74..a474e3e8cf382 100644 --- a/ggml.h +++ b/ggml.h @@ -248,6 +248,7 @@ extern "C" { GGML_TYPE_Q5_K = 13, GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, + GGML_TYPE_QX_0 = 16, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -276,6 +277,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + GGML_FTYPE_MOSTLY_QX_0 = 15, // except 1d tensors }; // available tensor operations: @@ -1135,13 +1137,14 @@ extern "C" { // quantization // + GGML_API size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width); GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API 
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width); // // system info diff --git a/llama.cpp b/llama.cpp index f0f9124d8dafd..105e3e169742f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -342,8 +342,11 @@ struct llama_load_tensor_shard { enum ggml_type type; size_t file_idx; size_t file_off; + size_t extra_data_file_off; void calc_size() { + // For QX_0, the size is manually written-in, since it comes from extra_data + GGML_ASSERT(type != GGML_TYPE_QX_0); size = llama_calc_tensor_size(ne, type); } }; @@ -364,6 +367,7 @@ struct llama_load_tensor { size_t size; struct ggml_tensor * ggml_tensor = NULL; uint8_t * data; + uint64_t * extra_data = NULL; llama_load_tensor(const std::string & name) : name(name) {} @@ -424,7 +428,18 @@ struct llama_load_tensor { } void calc_size() { - size = llama_calc_tensor_size(ne, type); + // For QX_0 the size comes from extra_data, but since extra_data might not be initialized here + // we can take it from the shard instead + if (type == GGML_TYPE_QX_0) { + GGML_ASSERT(shards.size() == 1); + GGML_ASSERT(ne.size() == 2); + + size = shards.at(0).size; + + GGML_ASSERT(size != 0); + } else { + size = llama_calc_tensor_size(ne, type); + } } }; @@ -520,6 +535,7 @@ struct llama_file_loader { shard.ne.resize(n_dims); file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims); std::string name = file.read_string(name_len); + if (n_dims < 1 || n_dims > 2) { throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); } @@ -536,6 +552,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: break; default: { throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type)); @@ -546,12 +563,38 @@ struct llama_file_loader { // skip to the next multiple of 32 bytes file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); } + + if (shard.type == GGML_TYPE_QX_0) { + shard.extra_data_file_off = file.tell(); + + // seek until before the last element of extra_data + file.seek(sizeof(uint64_t) * (shard.ne[1] - 1), SEEK_CUR); + + // get the tensor's size from here + uint64_t tensor_size = 0; + file.read_raw(&tensor_size, sizeof(uint64_t)); + shard.size = tensor_size; + + // realign, just in case extra_data isn't a multiple of 32B + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } else { + shard.extra_data_file_off = 0; + } + shard.file_idx = file_idx; shard.file_off = file.tell(); - shard.calc_size(); + if (shard.type != GGML_TYPE_QX_0) { + shard.calc_size(); + } + file.seek(shard.size, SEEK_CUR); + // QX_0's data may not be 32-byte aligned + if (shard.type == GGML_TYPE_QX_0) { + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } + auto it = tensors_map.name_to_idx.find(name); size_t idx; if (it != tensors_map.name_to_idx.end()) { @@ -602,7 +645,9 @@ struct llama_file_saver { file.write_raw(&token_score.score, sizeof(token_score.score)); } } - void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { + + // pass extra_data by reference to avoid excessive copying + void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size, llama_buffer & extra_data) { switch (new_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -616,6 
+661,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: break; default: LLAMA_ASSERT(false); } @@ -624,9 +670,29 @@ struct llama_file_saver { file.write_u32(new_type); file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size()); file.write_raw(tensor.name.data(), tensor.name.size()); + + size_t tensor_size; + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - LLAMA_ASSERT(new_size == llama_calc_tensor_size(tensor.ne, new_type)); + + // The tensor's size for QX_0 is stored in the last element of extra_data + if (new_type == GGML_TYPE_QX_0) { + file.write_raw(extra_data.addr, sizeof(uint64_t) * tensor.ne[1]); + tensor_size = ((uint64_t *) extra_data.addr)[tensor.ne[1] - 1]; + + // realign, just in case extra_data isn't a multiple of 32B + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } else { + tensor_size = llama_calc_tensor_size(tensor.ne, new_type); + } + + LLAMA_ASSERT(new_size == tensor_size); file.write_raw(new_data, new_size); + + // QX_0 data may not be 32-byte aligned + if (new_type == GGML_TYPE_QX_0) { + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } } }; @@ -666,7 +732,7 @@ struct llama_model_loader { bool alignment_prevents_mmap() { for (const llama_load_tensor & lt : tensors_map.tensors) { for (const llama_load_tensor_shard & shard : lt.shards) { - if (shard.file_off & 3) { + if ((shard.file_off & 3)) { return true; } } @@ -725,6 +791,7 @@ struct llama_model_loader { tensor->backend = backend; lt.ggml_tensor = tensor; num_ggml_tensors_created++; + return tensor; } @@ -771,6 +838,13 @@ struct llama_model_loader { switch(lt.ggml_tensor->backend) { case GGML_BACKEND_CPU: lt.ggml_tensor->data = lt.data; + + if (lt.type == GGML_TYPE_QX_0) { + // QX_0 uses the extra field to store byte offsets in *data for each row except row 0 + // (so extra[0] stores where row 1 starts, extra[1] is for row 2, and the last element + // in extra stores the total tensor size) + lt.ggml_tensor->extra = lt.extra_data; + } if (use_mmap && lmlock) { lock_size += lt.size; lmlock->grow_to(lock_size); @@ -801,9 +875,17 @@ struct llama_model_loader { } void load_data_for(llama_load_tensor & lt) { + // QX_0 only supports mmap + GGML_ASSERT(use_mmap || lt.type != GGML_TYPE_QX_0); + if (use_mmap) { LLAMA_ASSERT(lt.shards.size() == 1); lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off; + + if (lt.shards.at(0).extra_data_file_off != 0) { + lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off); + } + } else if (lt.split_type == SPLIT_NONE) { llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file; file.seek(lt.shards.at(0).file_off, SEEK_SET); @@ -988,6 +1070,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + case LLAMA_FTYPE_MOSTLY_QX_0: return "mostly QX_0"; default: return "unknown, may not work"; } } @@ -1665,6 +1748,8 @@ static bool llama_eval_internal( lctx.n_p_eval += N; } + // fprintf(stderr, "\nmodel eval time: %ldms\n", (ggml_time_us() - t_start_us) / 1000); + // fflush(stderr); return true; } @@ -2309,6 +2394,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = 
GGML_TYPE_Q6_K; break;
+        case LLAMA_FTYPE_MOSTLY_QX_0: quantized_type = GGML_TYPE_QX_0; break;
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
 
@@ -2316,6 +2402,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         nthread = std::thread::hardware_concurrency();
     }
 
+    // Multithreaded QX_0 quantization is not compatible with the current multithreaded quantization impl.:
+    // since blocks have an unknown size in bytes, we cannot split the output data into exact chunks assigned
+    // to each thread. Multithreading would technically only be possible if we quantized multiple entire
+    // tensors at once, but the overall implementation doesn't seem to allow that to be done easily.
+    if (quantized_type == GGML_TYPE_QX_0) {
+        nthread = 1;
+        printf("Setting nthread to 1 due to the implementation for QX_0 quantization being single-threaded.\n");
+    }
+
     std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false, /*vocab_only*/ false));
     llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype);
@@ -2363,12 +2458,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         if (!params->quantize_output_tensor && tensor.name == "output.weight") {
             quantize = false;
         }
+
+        // Allow only attention and FFN matrices to be quantized under QX_0, since they only require vec_dot
+        // to be implemented. Output weights and other matrices require more functions to be implemented, so
+        // for simplicity we'll only quantize attn and ffn for now.
+        if (quantized_type == GGML_TYPE_QX_0) {
+            if (tensor.name.find("attention") == std::string::npos && tensor.name.find("feed_forward") == std::string::npos) {
+                quantize = false;
+            }
+        }
+
         quantize = quantize && quantized_type != tensor.type;
 
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
         llama_buffer work;
+        llama_buffer extra_data;
 
         if (!quantize) {
             new_type = tensor.type;
@@ -2421,11 +2527,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_data = work.addr;
             std::vector<int64_t> hist_cur(1 << 4, 0);
 
+            if (new_type == GGML_TYPE_QX_0) {
+                extra_data.resize(sizeof(uint64_t) * tensor.ne[1]);
+            }
+
             int chunk_size = 32 * 512;
             const int nchunk = (nelements + chunk_size - 1)/chunk_size;
             const int nthread_use = nthread > 1 ? 
std::max(1, std::min(nthread, nchunk)) : 1; + if (nthread_use < 2) { - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); + new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data(), (uint64_t *) extra_data.addr, tensor.ne[0]); } else { size_t counter = 0; new_size = 0; @@ -2449,7 +2560,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (local_hist.empty()) { local_hist.resize(hist_cur.size(), 0); } - local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data()); + + // pass in NULL for extra_data, since it's only required for QX_0, which doesn't support quantized threading + local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data(), NULL, 0); } }; if ((int) workers.size() < nthread_use - 1) { @@ -2480,7 +2593,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } total_size_org += tensor.size; total_size_new += new_size; - file_saver.write_tensor(tensor, new_type, new_data, new_size); + file_saver.write_tensor(tensor, new_type, new_data, new_size, extra_data); } printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); diff --git a/llama.h b/llama.h index 7c7fd481cba9c..920779d021401 100644 --- a/llama.h +++ b/llama.h @@ -113,6 +113,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + LLAMA_FTYPE_MOSTLY_QX_0 = 19, // except 1d tensors }; // model quantization parameters
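// Summary sketch of the QX_0 "extra" convention used throughout this patch (illustrative only;
// "qx0_row_data" is a hypothetical helper): ggml_quantize_qx_0 writes one uint64_t per row into
// extra_data, where entry r - 1 holds the byte offset of row r (row 0 starts at offset 0) and the
// last entry holds the total tensor size; ggml_compute_forward_mul_mat_q_f32 and the llama.cpp
// loader read those offsets back through tensor->extra.
static inline const void * qx0_row_data(const struct ggml_tensor * t, int64_t row) {
    const uint64_t * extra = (const uint64_t *) t->extra;
    return (const char *) t->data + (row == 0 ? 0 : extra[row - 1]);
}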