From 310c7ee62fc184c80746b58b51b347a752af9797 Mon Sep 17 00:00:00 2001
From: ecyht2
Date: Thu, 17 Jul 2025 22:12:39 +0800
Subject: [PATCH] feat: Add user-friendly quantize type

---
 examples/quantize/quantize.cpp | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index ecf09e8..3a4c0ea 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -1,24 +1,25 @@
-#include "tts.h"
-#include "args.h"
 #include <thread>
 #include <vector>
-#include "ggml.h"
+#include <map>
 #include <algorithm>
 
-std::vector<ggml_type> valid_quantization_types = {
-    GGML_TYPE_F16,
-    GGML_TYPE_Q4_0,
-    GGML_TYPE_Q5_0,
-    GGML_TYPE_Q8_0,
+#include "args.h"
+#include "ggml.h"
+#include "tts.h"
+
+const std::map<std::string, ggml_type> valid_quantization_types = {
+    {"FP16", GGML_TYPE_F16},
+    {"Q4_0", GGML_TYPE_Q4_0},
+    {"Q5_0", GGML_TYPE_Q5_0},
+    {"Q8_0", GGML_TYPE_Q8_0},
 };
 
 int main(int argc, const char ** argv) {
-    int default_quantization = (int) GGML_TYPE_Q4_0;
     int default_n_threads = std::max((int)std::thread::hardware_concurrency(), 1);
     arg_list args;
     args.add_argument(string_arg("--model-path", "(REQUIRED) The local path of the gguf model file for Parler TTS mini v1 to quantize.", "-mp", true));
     args.add_argument(string_arg("--quantized-model-path", "(REQUIRED) The path to save the model in a quantized format.", "-qp", true));
-    args.add_argument(int_arg("--quantized-type", "(OPTIONAL) The ggml enum of the quantized type to convert compatible model tensors to. For more information see readme. Defaults to Q4_0 quantization (2).", "-qt", false, &default_quantization));
+    args.add_argument(string_arg("--quantized-type", "(OPTIONAL) The name of the quantized type to convert compatible model tensors to (FP16, Q4_0, Q5_0, or Q8_0). For more information see the readme. Defaults to Q4_0.", "-qt", false, "Q4_0"));
     args.add_argument(int_arg("--n-threads", "(OPTIONAL) The number of cpu threads to run the quantization process with. Defaults to known hardware concurrency.", "-nt", false, &default_n_threads));
     args.add_argument(bool_arg("--convert-dac-to-f16", "(OPTIONAL) Whether to convert the DAC audio decoder model to a 16 bit float.", "-df"));
     args.add_argument(bool_arg("--quantize-output-heads", "(OPTIONAL) Whether to quantize the output heads. Defaults to false and is true when passed (does not accept a parameter).", "-qh"));
@@ -31,12 +32,13 @@ int main(int argc, const char ** argv) {
         return 0;
     }
     args.validate();
-    enum ggml_type qtype = static_cast<ggml_type>(*args.get_int_param("--quantized-type"));
-    if (std::find(valid_quantization_types.begin(), valid_quantization_types.end(), qtype) == valid_quantization_types.end()) {
-        fprintf(stderr, "ERROR: %d is not a valid quantization type.\n", qtype);
+    std::string qtype = args.get_string_param("--quantized-type");
+    if (!valid_quantization_types.contains(qtype)) {
+        fprintf(stderr, "ERROR: %s is not a valid quantization type.\n",
+                qtype.c_str());
         exit(1);
     }
-    struct quantization_params * qp = new quantization_params((uint32_t) *args.get_int_param("--n-threads"), qtype);
+    struct quantization_params * qp = new quantization_params((uint32_t) *args.get_int_param("--n-threads"), valid_quantization_types.at(qtype));
     qp->quantize_output_heads = args.get_bool_param("--quantize-output-heads");
     qp->quantize_text_embeddings = args.get_bool_param("--quantize-text-embedding");
     qp->quantize_cross_attn_kv = args.get_bool_param("--quantize-cross-attn-kv");
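
Note on the approach: the change swaps the raw ggml enum integer for a named type, so with the patch applied the tool is invoked as --quantized-type Q4_0 rather than --quantized-type 2, and an unknown name fails with a readable error. Below is a minimal standalone sketch of the same name-to-type lookup, assuming a hypothetical stand-in enum (quant_type) and plain argv handling in place of the project's ggml_type and arg_list, since ggml.h and args.h are not reproduced here. It uses std::map::find, which performs the same membership test as the patch's std::map::contains but also builds on pre-C++20 toolchains (contains requires C++20).

    // Standalone sketch only: quant_type and quant_names are hypothetical stand-ins,
    // not the project's types.
    #include <cstdio>
    #include <map>
    #include <string>

    enum class quant_type { F16, Q4_0, Q5_0, Q8_0 };

    // Same shape as valid_quantization_types in the patch: user-facing name -> type.
    const std::map<std::string, quant_type> quant_names = {
        {"FP16", quant_type::F16},
        {"Q4_0", quant_type::Q4_0},
        {"Q5_0", quant_type::Q5_0},
        {"Q8_0", quant_type::Q8_0},
    };

    int main(int argc, char ** argv) {
        // Mirror the patch's default: Q4_0 when no type is passed.
        std::string requested = argc > 1 ? argv[1] : "Q4_0";
        auto it = quant_names.find(requested);   // membership test; contains() in C++20 is equivalent
        if (it == quant_names.end()) {
            fprintf(stderr, "ERROR: %s is not a valid quantization type.\n", requested.c_str());
            return 1;
        }
        printf("quantizing with %s (enum value %d)\n",
               it->first.c_str(), static_cast<int>(it->second));
        return 0;
    }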