From 31e2f5668ce10fc9c8a17e6c3d6ee6f6ef55382e Mon Sep 17 00:00:00 2001 From: Julia Bruckner Date: Tue, 23 Apr 2024 13:33:05 +0200 Subject: [PATCH 1/6] custom quantization schemas --- quant.cfg | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 quant.cfg diff --git a/quant.cfg b/quant.cfg new file mode 100644 index 0000000000000..f97dbe47e9aeb --- /dev/null +++ b/quant.cfg @@ -0,0 +1,115 @@ +ftype=15 + +blk.12.ffn_down.weight=11 +blk.12.ffn_up.weight=11 + +blk.13.ffn_down.weight=11 +blk.13.ffn_up.weight=11 + +blk.14.ffn_down.weight=11 +blk.14.ffn_up.weight=11 + +blk.15.ffn_down.weight=11 +blk.15.ffn_up.weight=11 + +blk.16.ffn_up.weight=10 +blk.17.ffn_up.weight=10 +blk.18.ffn_up.weight=10 +blk.19.ffn_up.weight=10 +blk.20.ffn_up.weight=10 +blk.21.ffn_up.weight=10 +blk.22.ffn_up.weight=10 +blk.23.ffn_up.weight=10 +blk.24.ffn_up.weight=10 +blk.25.ffn_up.weight=10 + +blk.16.ffn_down.weight=10 +blk.17.ffn_down.weight=10 +blk.18.ffn_down.weight=10 +blk.19.ffn_down.weight=10 +blk.20.ffn_down.weight=10 +blk.21.ffn_down.weight=10 +blk.22.ffn_down.weight=10 +blk.23.ffn_down.weight=10 +blk.24.ffn_down.weight=10 +blk.25.ffn_down.weight=10 + +blk.26.ffn_down.weight=10 +blk.26.ffn_up.weight=10 + +blk.27.ffn_down.weight=11 +blk.27.ffn_up.weight=11 + +blk.28.ffn_down.weight=11 +blk.28.ffn_up.weight=11 + +blk.29.ffn_down.weight=11 +blk.29.ffn_up.weight=11 + +token_embd.weight=21 +output.weight=21 + +# LLAMA_FTYPE_ALL_F32 = 0, +# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 +# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed +# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed +# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors + +# GGML_TYPE_F32 = 0, +# GGML_TYPE_F16 = 1, +# GGML_TYPE_Q4_0 = 2, +# GGML_TYPE_Q4_1 = 3, +# // GGML_TYPE_Q4_2 = 4, support has been removed +# // GGML_TYPE_Q4_3 = 5, support has been removed +# GGML_TYPE_Q5_0 = 6, +# GGML_TYPE_Q5_1 = 7, +# GGML_TYPE_Q8_0 = 8, +# GGML_TYPE_Q8_1 = 9, +# 
GGML_TYPE_Q2_K = 10,
+# GGML_TYPE_Q3_K = 11,
+# GGML_TYPE_Q4_K = 12,
+# GGML_TYPE_Q5_K = 13,
+# GGML_TYPE_Q6_K = 14,
+# GGML_TYPE_Q8_K = 15,
+# GGML_TYPE_IQ2_XXS = 16,
+# GGML_TYPE_IQ2_XS = 17,
+# GGML_TYPE_IQ3_XXS = 18,
+# GGML_TYPE_IQ1_S = 19,
+# GGML_TYPE_IQ4_NL = 20,
+# GGML_TYPE_IQ3_S = 21,
+# GGML_TYPE_IQ2_S = 22,
+# GGML_TYPE_IQ4_XS = 23,
+# GGML_TYPE_I8 = 24,
+# GGML_TYPE_I16 = 25,
+# GGML_TYPE_I32 = 26,
+# GGML_TYPE_I64 = 27,
+# GGML_TYPE_F64 = 28,
+# GGML_TYPE_IQ1_M = 29,
+

From dbe6483e7ef564f6900ecae209e3178af877e0c7 Mon Sep 17 00:00:00 2001
From: Julia Bruckner
Date: Tue, 23 Apr 2024 13:35:03 +0200
Subject: [PATCH 2/6] custom quantization schemas

---
 examples/quantize/quantize.cpp | 67 +++++++++++++++++++++++++++++++++-
 llama.cpp                      | 28 ++++++++++++--
 llama.h                        | 11 +++++-
 3 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 64cb6db19d004..2c22f84501a7b 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -49,11 +49,11 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q8_0",   LLAMA_FTYPE_MOSTLY_Q8_0,   " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
     { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "13.00G @ 7B", },
     { "F32",    LLAMA_FTYPE_ALL_F32,       "26.00G @ 7B", },
+    { "CUSTOM", LLAMA_FTYPE_CUSTOM,        "per-layer scheme from file (quant.cfg)", },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
     { "COPY",   LLAMA_FTYPE_ALL_F32,       "only copy tensors, no quantizing", },
 };
-
 static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
     std::string ftype_str;
@@ -247,6 +247,60 @@ static bool parse_kv_override(const char * data, std::vector<llama_model_kv_over
+static bool read_custom_quant_config(const std::string& filename, llama_model_quantize_ftype_override& override) {
+    std::ifstream file(filename);
+    std::string line;
+
+    std::vector<std::string> names;
+    std::vector<ggml_type> types;
+
+    printf("%s: reading custom quantization scheme from %s:\n", __func__, filename.c_str());
+
+    if (!file.is_open()) {
+        fprintf(stderr, "%s: failed to open file: '%s'\n", __func__, filename.c_str());
+        return false;
+    }
+
+    while (getline(file, line)) {
+        // Skip empty lines and comments
+        if (line.empty() || line[0] == '#') continue;
+        printf("  %s\n", line.c_str());
+
+        // default file type
+        if (line.find("ftype=") == 0) {
+            int ftype = std::stoi(line.substr(6));
+            override.default_ftype = static_cast<llama_ftype>(ftype);
+            printf("  default ftype = %i\n", ftype);
+            continue;
+        }
+
+        // tensor overrides
+        size_t pos = line.find('=');
+        if (pos != std::string::npos) {
+            std::string name = line.substr(0, pos);
+            int type = std::stoi(line.substr(pos + 1));
+            names.push_back(name);
+            types.push_back(static_cast<ggml_type>(type));
+            printf("  %s = %i\n", name.c_str(), type);
+        }
+    }
+
+    printf("\n");
+
+    // allocate memory for names and types
+    override.names = new const char*[names.size()];
+    override.types = new ggml_type[types.size()];
+    override.count = names.size();
+
+    for (size_t i = 0; i < names.size(); ++i) {
+        override.names[i] = strdup(names[i].c_str());
+        override.types[i] = types[i];
+    }
+
+    file.close();
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -352,13 +406,24 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: missing ftype\n", __func__);
         return 1;
     }
+
     if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
         fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
         return 1;
     }
+
     if (ftype_str == "COPY") {
         params.only_copy = true;
     }
+
+    if (ftype_str == "CUSTOM") {
+        params.override_ftype = new llama_model_quantize_ftype_override;
+        if(!read_custom_quant_config("quant.cfg", *params.override_ftype)) {
+            fprintf(stderr, "%s: 
failed to read custom quant config file!\n", __func__); + return 1; + } + } + arg_idx++; } diff --git a/llama.cpp b/llama.cpp index a25d115c1d82a..bf48e38e3ef93 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3610,6 +3610,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + // Custom quantization scheme + case LLAMA_FTYPE_CUSTOM: return "CUSTOM"; + default: return "unknown, may not work"; } } @@ -14195,9 +14198,13 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; - llama_ftype ftype = params->ftype; - switch (params->ftype) { + llama_ftype ftype = + params->override_ftype + ? params->override_ftype->default_ftype + : params->ftype; + + switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; @@ -14279,7 +14286,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // copy the KV pairs from the input file gguf_set_kv (ctx_out, ml.meta); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out, "general.file_type", ftype); + gguf_set_val_u32(ctx_out, "general.file_type", params->ftype); // Remove split metadata gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str()); @@ -14417,6 +14424,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = params->output_tensor_type; } + // look up tensor name in type override map, if not found use default + // type as determined by the ftype. + if(params->override_ftype) { + for (uint32_t i = 0; i < params->override_ftype->count; ++i) { + if (strcmp(params->override_ftype->names[i], tensor->name) == 0) { + //LLAMA_LOG_INFO("\n%s: %s %s ---> %s\n", __func__, tensor->name, ggml_type_name(new_type), ggml_type_name(params->override_ftype->types[i])); + new_type = params->override_ftype->types[i]; + break; + } + } + } + // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. 
quantize = tensor->type != new_type;
@@ -14886,7 +14905,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.only_copy                  =*/ false,
         /*.pure                       =*/ false,
         /*.imatrix                    =*/ nullptr,
-        /*.kv_overrides               =*/ nullptr,
+        /*.kv_overrides               =*/ nullptr,
+        /*.override_ftype             =*/ nullptr
     };

     return result;
diff --git a/llama.h b/llama.h
index 4effca42cc65d..ea40345b3223f 100644
--- a/llama.h
+++ b/llama.h
@@ -122,6 +122,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ2_M  = 29, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
+        LLAMA_FTYPE_CUSTOM        = 32, // except 1d tensors

         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -278,6 +279,13 @@ extern "C" {
         void * abort_callback_data;
     };

+    typedef struct llama_model_quantize_ftype_override {
+        enum llama_ftype default_ftype; // default type if not overridden
+        uint32_t         count;         // number of overrides
+        const char    ** names;         // tensor names
+        enum ggml_type * types;         // tensor type override
+    } llama_model_quantize_custom_ftype;
+
     // model quantization parameters
     typedef struct llama_model_quantize_params {
         int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -286,10 +294,11 @@ extern "C" {
         enum ggml_type token_embedding_type; // itoken embeddings tensor type
         bool allow_requantize; // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor; // quantize output.weight
-        bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+        bool only_copy; // only copy tensors - ftype,override_ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure; // quantize all tensors to the default type
         void * imatrix; // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
+        struct llama_model_quantize_ftype_override * override_ftype; // custom quantization scheme
     } llama_model_quantize_params;

     // grammar types

From 054e73e02155fd2d5df24c950674f04d18084ede Mon Sep 17 00:00:00 2001
From: Julia Bruckner
Date: Tue, 23 Apr 2024 13:39:16 +0200
Subject: [PATCH 3/6] fix spaces

---
 llama.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.h b/llama.h
index ea40345b3223f..aaba8e576dc5b 100644
--- a/llama.h
+++ b/llama.h
@@ -294,7 +294,7 @@ extern "C" {
         enum ggml_type token_embedding_type; // itoken embeddings tensor type
         bool allow_requantize; // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor; // quantize output.weight
-        bool only_copy; // only copy tensors - ftype,override_ftype, allow_requantize and quantize_output_tensor are ignored
+        bool only_copy; // only copy tensors - ftype, override_ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure; // quantize all tensors to the default type
         void * imatrix; // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides

From 6e09a2650446dc5a331fd90039c4f284c1875ef3 Mon Sep 17 00:00:00 2001
From: Julia Bruckner
Date: Wed, 24 Apr 2024 11:17:22 +0200
Subject: [PATCH 4/6] allow wildcards for tensor names

---
 examples/quantize/quantize.cpp |   1 -
 llama.cpp                      |  24 +++++++-
 quant.cfg                      | 100 +++++++++++++++++----------------
 3 files changed, 75 insertions(+), 50 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 2c22f84501a7b..6a6892a0564ba 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -263,7 +263,6 @@ static bool read_custom_quant_config(const std::string& filename, llama_model_qu
     while (getline(file, line)) {
         // Skip empty lines and comments
         if (line.empty() || line[0] == '#') continue;
-        printf("  %s\n", line.c_str());

         // default file type
         if (line.find("ftype=") == 0) {
diff --git a/llama.cpp b/llama.cpp
index bf48e38e3ef93..c2c5be35d293b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -14196,6 +14196,26 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
     return new_size;
 }

+static bool match_string(const std::string& str, const std::string& pattern, uint32_t string_index = 0, uint32_t pattern_index = 0) {
+    // if both index pointers reach the end of str and pattern respectively
+    if (string_index == str.size() && pattern_index == pattern.size()) {
+        return true;
+    }
+
+    // if pattern character is '*', it can match any sequence of characters
+    if (pattern_index < pattern.size() && pattern[pattern_index] == '*') {
+        // either '*' matches the empty sequence (advance the pattern index), or it consumes one more character (advance the string index)
+        return match_string(str, pattern, string_index, pattern_index + 1) || (string_index < str.size() && match_string(str, pattern, string_index + 1, pattern_index));
+    }
+
+    // if current characters match or pattern character is '?'
+    if (string_index < str.size() && pattern_index < pattern.size() && (str[string_index] == pattern[pattern_index] || pattern[pattern_index] == '?')) {
+        return match_string(str, pattern, string_index + 1, pattern_index + 1);
+    }
+
+    return false;
+}
+
 static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type default_type;

@@ -14428,8 +14448,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // type as determined by the ftype.
     if(params->override_ftype) {
         for (uint32_t i = 0; i < params->override_ftype->count; ++i) {
-            if (strcmp(params->override_ftype->names[i], tensor->name) == 0) {
-                //LLAMA_LOG_INFO("\n%s: %s %s ---> %s\n", __func__, tensor->name, ggml_type_name(new_type), ggml_type_name(params->override_ftype->types[i]));
+            if (match_string(tensor->name, params->override_ftype->names[i])) {
+                // printf("\n -----> %s, %s\n", params->override_ftype->names[i], tensor->name);
                 new_type = params->override_ftype->types[i];
                 break;
             }
diff --git a/quant.cfg b/quant.cfg
index f97dbe47e9aeb..1861667165f8e 100644
--- a/quant.cfg
+++ b/quant.cfg
@@ -1,54 +1,60 @@
-ftype=15
+# this defines the default ftype (the quantization mix code,
+# that you pass to quantize if you're not using a custom mix).
+# tensors that are not overridden below will be quantized
+# according to this scheme.
-blk.12.ffn_down.weight=11 -blk.12.ffn_up.weight=11 - -blk.13.ffn_down.weight=11 -blk.13.ffn_up.weight=11 - -blk.14.ffn_down.weight=11 -blk.14.ffn_up.weight=11 - -blk.15.ffn_down.weight=11 -blk.15.ffn_up.weight=11 - -blk.16.ffn_up.weight=10 -blk.17.ffn_up.weight=10 -blk.18.ffn_up.weight=10 -blk.19.ffn_up.weight=10 -blk.20.ffn_up.weight=10 -blk.21.ffn_up.weight=10 -blk.22.ffn_up.weight=10 -blk.23.ffn_up.weight=10 -blk.24.ffn_up.weight=10 -blk.25.ffn_up.weight=10 - -blk.16.ffn_down.weight=10 -blk.17.ffn_down.weight=10 -blk.18.ffn_down.weight=10 -blk.19.ffn_down.weight=10 -blk.20.ffn_down.weight=10 -blk.21.ffn_down.weight=10 -blk.22.ffn_down.weight=10 -blk.23.ffn_down.weight=10 -blk.24.ffn_down.weight=10 -blk.25.ffn_down.weight=10 - -blk.26.ffn_down.weight=10 -blk.26.ffn_up.weight=10 - -blk.27.ffn_down.weight=11 -blk.27.ffn_up.weight=11 - -blk.28.ffn_down.weight=11 -blk.28.ffn_up.weight=11 +ftype=7 + +# allowed values: +# LLAMA_FTYPE_ALL_F32 = 0, +# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 +# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed +# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed +# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors -blk.29.ffn_down.weight=11 -blk.29.ffn_up.weight=11 +# this defines an override for tensors with names matching +# a given string. filters are processed in order given, and the +# first matching will be used. +# Wildcards are allowed: +# ? 
single character
+# * multiple characters

-token_embd.weight=21
-output.weight=21
+blk.10.ffn_up.weight=7
+blk.1?.ffn_up.weight=10
+blk.2?.ffn_up.weight=10
+blk.1?.attn*=23
+blk.2?.attn*=23
+*down*=14
+*gate*=12

+# allowed values:
 # LLAMA_FTYPE_ALL_F32 = 0,
 # LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
 # // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
 # // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
 # LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
 # LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors

From 238551ed8c8c06f4667a841def3c9bf620cfa814 Mon Sep 17 00:00:00 2001
From: Julia Bruckner
Date: Thu, 25 Apr 2024 11:42:09 +0200
Subject: [PATCH 5/6] parse ggml_type and llama_ftype, allow specifying cfg
 file

---
 examples/quantize/quantize.cpp |  79 +++++++++++++-----
 quant.cfg                      | 141 +++++++--------------------------
 2 files changed, 86 insertions(+), 134 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 6a6892a0564ba..a33bc915afbd7 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -32,34 +32,55 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
     { "IQ3_S",  LLAMA_FTYPE_MOSTLY_IQ3_S,  " 3.44 bpw quantization", },
     { "IQ3_M",  LLAMA_FTYPE_MOSTLY_IQ3_M,  " 3.66 bpw quantization mix", },
-    { "Q3_K",   LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
+    { "Q3_K",    LLAMA_FTYPE_MOSTLY_Q3_K_M,  "alias for Q3_K_M" },
     { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , },
     { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
     { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
     { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
     { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
     { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
-    { "Q4_K",   LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
+    { "Q4_K",    LLAMA_FTYPE_MOSTLY_Q4_K_M,  "alias for Q4_K_M", },
     { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
     { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
-    { "Q5_K",   LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", },
+    { "Q5_K",    LLAMA_FTYPE_MOSTLY_Q5_K_M,  "alias for Q5_K_M", },
     { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
     { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
     { "Q6_K",   LLAMA_FTYPE_MOSTLY_Q6_K,   " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
     { "Q8_0",   LLAMA_FTYPE_MOSTLY_Q8_0,   " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
-    { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "13.00G @ 7B", },
-    { "F32",    LLAMA_FTYPE_ALL_F32,       "26.00G @ 7B", },
-    { "CUSTOM", LLAMA_FTYPE_CUSTOM,        "per-layer scheme from file (quant.cfg)", },
+    { "F16",     LLAMA_FTYPE_MOSTLY_F16,     "13.00G @ 7B", },
+    { "F32",     LLAMA_FTYPE_ALL_F32,        "26.00G @ 7B", },
+    { "CUSTOM",  LLAMA_FTYPE_CUSTOM,         "[:filename] Custom quant config (quant.cfg if not specified)", },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
-    { "COPY",   LLAMA_FTYPE_ALL_F32,       "only copy tensors, no quantizing", },
+    { "COPY",    LLAMA_FTYPE_ALL_F32,        "only copy tensors, no quantizing", },
 };

-static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
+static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out, std::string & custom_cfg_filename_out) {
     std::string ftype_str;

     for (auto ch : ftype_str_in) {
         ftype_str.push_back(std::toupper(ch));
     }
+
+    if (ftype_str.find("CUSTOM:") == 0) {
+        // custom quant mix
+        ftype = LLAMA_FTYPE_CUSTOM;
+        ftype_str_out = "CUSTOM";
+        if (ftype_str.length() > 7) {
+            // extract config filename (take from ftype_str_in to get original casing)
+            std::string custom_cfg = ftype_str_in.substr(7);
+            custom_cfg_filename_out = custom_cfg;
+        } else {
+            return false;
+        }
+        return true;
+    } else if (ftype_str.find("CUSTOM") == 0) {
+        // custom quant mix with default config
+        ftype = LLAMA_FTYPE_CUSTOM;
+        ftype_str_out = "CUSTOM";
+        custom_cfg_filename_out = "quant.cfg";
+        return true;
+    }
+
     for (auto & it : QUANT_OPTIONS) {
         if (it.name == ftype_str) {
             ftype = it.ftype;
@@ -203,7 +224,7 @@ static ggml_type parse_ggml_type(const char * arg) {
     for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
         auto type = ggml_type(j);
         const auto * name = ggml_type_name(type);
-        if (name && strcmp(arg, name) == 0) {
+        if (name && strcasecmp(arg, name) == 0) {
             result = type; break;
         }
     }
@@ -253,7 +274,7 @@ static bool read_custom_quant_config(const std::string& filename, llama_model_qu
     std::vector<std::string> names;
     std::vector<ggml_type> types;

-    printf("%s: reading custom quantization scheme from %s:\n", __func__, filename.c_str());
+    printf("reading custom quantization mix from %s:\n", filename.c_str());

     if (!file.is_open()) {
         fprintf(stderr, "%s: failed to open file: '%s'\n", __func__, filename.c_str());
@@ -261,25 +282,41 @@ static bool read_custom_quant_config(const std::string& filename, llama_model_qu
     }

     while (getline(file, line)) {
-        // Skip empty lines and comments
+        // skip empty lines and comments
         if (line.empty() || line[0] == '#') continue;

         // default file type
         if (line.find("ftype=") == 0) {
-            int ftype = std::stoi(line.substr(6));
+            std::string ftype_str = line.substr(6);
+            std::string ftype_name;
+            std::string custom_quant_config_filename;
+            llama_ftype ftype;
+            if(!try_parse_ftype(ftype_str, ftype, ftype_name, custom_quant_config_filename)) {
+                fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, ftype_str.c_str());
+                file.close();
+                return false;
+            }
+
             override.default_ftype = static_cast<llama_ftype>(ftype);
-            printf("  default ftype = %i\n", ftype);
+            printf("  default ftype = %i (%s)\n", ftype, ftype_name.c_str());
             continue;
         }

         // tensor overrides
         size_t pos = line.find('=');
         if (pos != std::string::npos) {
-            std::string name = line.substr(0, pos);
-            int type = std::stoi(line.substr(pos + 1));
-            names.push_back(name);
+            std::string tensor_name = line.substr(0, pos);
+            std::string type_name = line.substr(pos + 1);
+            ggml_type type = parse_ggml_type(type_name.c_str());
+            if(type < 0 || type >= GGML_TYPE_COUNT) {
+                fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, type_name.c_str());
+                file.close();
+                return false;
+            }
+            names.push_back(tensor_name);
             types.push_back(static_cast<ggml_type>(type));
-            printf("  %s = %i\n", name.c_str(), type);
+            printf("  %s = %i (%s)\n", tensor_name.c_str(), type, type_name.c_str());
         }
     }
@@ -383,9 +420,10 @@ int main(int argc, char ** argv) {
     const std::string fname_inp = argv[arg_idx];
     arg_idx++;
     std::string fname_out;
+    std::string custom_quant_config_filename;

     std::string ftype_str;
-    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
+    if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str, custom_quant_config_filename)) {
         std::string fpath;
         const size_t pos = fname_inp.find_last_of("/\\");
         if (pos != std::string::npos) {
@@ -406,7 +444,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
+    if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str, custom_quant_config_filename)) {
         fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
         return 1;
     }
@@ -417,8 +455,7 @@ int main(int argc, char ** argv) {

     if (ftype_str == "CUSTOM") {
         params.override_ftype = new llama_model_quantize_ftype_override;
-        if(!read_custom_quant_config("quant.cfg", *params.override_ftype)) {
-            fprintf(stderr, "%s: failed to read custom quant config file!\n", __func__);
+        if(!read_custom_quant_config(custom_quant_config_filename, *params.override_ftype)) {
             return 1;
         }
     }
diff --git a/quant.cfg b/quant.cfg
index 1861667165f8e..282442b4feed4 100644
--- a/quant.cfg
+++ b/quant.cfg
@@ -1,121 +1,36 @@
-# this defines the default ftype (the quantization mix code,
+# Defines the default ftype (the quantization mix code,
 # that you pass to quantize if you're not using a custom mix).
 # tensors that are not overridden below will be quantized
-# according to this scheme.
+# according to this mix.
+#
+# Must be one of
+# Q4_0, Q4_1, Q5_0, Q5_1, IQ2_XXS, IQ2_XS, IQ2_S, IQ2_M,
+# IQ1_S, IQ1_M, Q2_K, Q2_K_S, IQ3_XXS, IQ3_S, IQ3_M, Q3_K,
+# IQ3_XS, Q3_K_S, Q3_K_M, Q3_K_L, IQ4_NL, IQ4_XS, Q4_K,
+# Q4_K_S, Q4_K_M, Q5_K, Q5_K_S, Q5_K_M, Q6_K, Q8_0, F16

-ftype=7
-
-# allowed values:
-# LLAMA_FTYPE_ALL_F32 = 0,
-# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
-# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
-# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
+ftype=Q6_K

-# this defines an override for tensors with names matching
-# a given string. filters are processed in order given, and the
-# first matching will be used.
+# Defines overrides for tensors with names matching a given
+# string. Filters are processed in the order given; the first
+# match wins.
+#
 # Wildcards are allowed:
 # ? single character
 # * multiple characters
+#
+# Type must be one of
+# F16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, Q2_K, Q3_K,
+# Q4_K, Q5_K, Q6_K, Q8_K, IQ2_XXS, IQ2_XS, IQ3_XXS,
+# IQ1_S, IQ4_NL, IQ3_S, IQ2_S, IQ4_XS, IQ1_M

-blk.10.ffn_up.weight=7
-blk.1?.ffn_up.weight=10
-blk.2?.ffn_up.weight=10
-blk.1?.attn*=23
-blk.2?.attn*=23
-*down*=14
-*gate*=12
-
-# allowed values:
-# LLAMA_FTYPE_ALL_F32 = 0,
-# LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-# // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
-# // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
-# LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
-# LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors

-# GGML_TYPE_F32 = 0,
-# GGML_TYPE_F16 = 1,
-# GGML_TYPE_Q4_0 = 2,
-# GGML_TYPE_Q4_1 = 3,
-# // GGML_TYPE_Q4_2 = 4, support has been removed
-# // GGML_TYPE_Q4_3 = 5, support has been removed
-# GGML_TYPE_Q5_0 = 6,
-# GGML_TYPE_Q5_1 = 7,
-# GGML_TYPE_Q8_0 = 8,
-# GGML_TYPE_Q8_1 = 9,
-# GGML_TYPE_Q2_K = 10,
-# GGML_TYPE_Q3_K = 11,
-# GGML_TYPE_Q4_K = 12,
-# GGML_TYPE_Q5_K = 13,
-# GGML_TYPE_Q6_K = 14,
-# GGML_TYPE_Q8_K = 15,
-# GGML_TYPE_IQ2_XXS = 16,
-# GGML_TYPE_IQ2_XS = 17,
-# GGML_TYPE_IQ3_XXS = 18,
-# GGML_TYPE_IQ1_S = 19,
-# GGML_TYPE_IQ4_NL = 20,
-# GGML_TYPE_IQ3_S = 21,
-# GGML_TYPE_IQ2_S = 22,
-# GGML_TYPE_IQ4_XS = 23,
-# GGML_TYPE_I8 = 24,
-# GGML_TYPE_I16 = 25,
-# GGML_TYPE_I32 = 26,
-# GGML_TYPE_I64 = 27,
-# GGML_TYPE_F64 = 28,
-# GGML_TYPE_IQ1_M = 29,
-
+blk.10.ffn_up.weight=Q5_K
+blk.1?.ffn_up.weight=Q4_K
+blk.23.*=Q2_K
+blk.24.*=Q2_K
+blk.25.*=Q2_K
+blk.2?.ffn_up.weight=Q4_K
+*_gate*=Q4_K
+*.attn*=IQ4_XS
+*_down*=IQ3_S
+output.weight=Q5_K

From 20b22433f0cf941c1b43e27c086e2ef71798fd57 Mon Sep 17 00:00:00 2001
From: Brian
Date: Thu, 9 May 2024 
23:25:54 +1000 Subject: [PATCH 6/6] Update llama.h LLAMA_FTYPE_CUSTOM=33 --- llama.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.h b/llama.h index 84b15b1fc9bd0..23a12ff577c80 100644 --- a/llama.h +++ b/llama.h @@ -140,7 +140,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_CUSTOM = 32, // except 1d tensors + LLAMA_FTYPE_CUSTOM = 33, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };
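
Usage sketch (not from the series itself: the model and config filenames below
are illustrative; the CUSTOM[:filename] argument, the first-match-wins rule
order, and the type names are those introduced in patches 4 and 5):

    $ cat my-mix.cfg
    # default mix for every tensor not matched by a rule below
    ftype=Q4_K_M
    # overrides: the first matching rule wins, so specific names go before wildcards
    blk.2?.ffn_up.weight=Q3_K
    *_down*=Q5_K
    output.weight=Q6_K

    # quantize with the custom mix; a bare "CUSTOM" falls back to quant.cfg
    $ ./quantize ggml-model-f16.gguf ggml-model-custom.gguf CUSTOM:my-mix.cfg

Note that read_custom_quant_config() treats only whole lines beginning with '#'
as comments, so values must not carry trailing comments on the same line.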
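
The wildcard matcher added in patch 4 can also be exercised on its own; the
sketch below copies match_string from the patch and checks it against a few
illustrative tensor names (the test strings are assumptions, not from the PR):

    #include <cassert>
    #include <cstdint>
    #include <string>

    // match_string as added to llama.cpp in patch 4: '?' matches exactly one
    // character, '*' matches any (possibly empty) sequence of characters
    static bool match_string(const std::string & str, const std::string & pattern,
                             uint32_t string_index = 0, uint32_t pattern_index = 0) {
        if (string_index == str.size() && pattern_index == pattern.size()) {
            return true; // both exhausted -> match
        }
        if (pattern_index < pattern.size() && pattern[pattern_index] == '*') {
            // '*' matches the empty sequence (advance the pattern) or
            // one more character (advance the string, keep the pattern)
            return match_string(str, pattern, string_index, pattern_index + 1) ||
                   (string_index < str.size() && match_string(str, pattern, string_index + 1, pattern_index));
        }
        if (string_index < str.size() && pattern_index < pattern.size() &&
            (str[string_index] == pattern[pattern_index] || pattern[pattern_index] == '?')) {
            return match_string(str, pattern, string_index + 1, pattern_index + 1);
        }
        return false;
    }

    int main() {
        // rules like those in the example quant.cfg against typical tensor names
        assert( match_string("blk.17.ffn_up.weight",  "blk.1?.ffn_up.weight"));
        assert( match_string("blk.23.attn_k.weight",  "blk.2?.attn*"));
        assert( match_string("blk.3.ffn_down.weight", "*_down*"));
        assert(!match_string("blk.7.ffn_up.weight",   "blk.1?.ffn_up.weight")); // '?' must consume one char
        return 0;
    }

Because the '*' branch recurses both ways without memoization it is exponential
in the worst case, but for tensor-name-sized inputs this is negligible.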