diff --git a/CMakeLists.txt b/CMakeLists.txt index 8fe3267..1661c33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.14) project("tts.cpp" C CXX) include(CheckIncludeFileCXX) -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) diff --git a/README.md b/README.md index 6d475d6..5ac0df6 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Additional Model support will initially be added based on open source model perf #### Requirements: * Local GGUF format model file (see [py-gguf](./py-ggufs/README.md) for information on how to convert the hugging face models to GGUF). -* C++17 and C17 +* C++23 and C11 * XCode Command Line Tools (via `xcode-select --install`) should suffice for OS X * CMake (>=3.14) * GGML pulled locally @@ -60,7 +60,7 @@ We are currently [working on upstreaming some of these operations inorder to dep #### Build: Assuming that the above requirements are met the library and basic CLI example can be built by running the following command in the repository's base directory: -```commandline +```bash cmake -B build cmake --build build --config Release ``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 92cb14f..9a3171e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,12 +1,19 @@ # examples -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +add_library(examples_common + args.cpp + args.h + args_common.cpp + args_common.h + audio_file.h +) +target_include_directories(examples_common PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(examples_common PUBLIC ggml tts) -if (EMSCRIPTEN) -else() +if (NOT EMSCRIPTEN) add_subdirectory(cli) add_subdirectory(perf_battery) add_subdirectory(quantize) add_subdirectory(server) add_subdirectory(phonemize) -endif() +endif () diff --git a/examples/args.cpp b/examples/args.cpp new file mode 100644 index 0000000..d141994 --- /dev/null +++ 
b/examples/args.cpp @@ -0,0 +1,76 @@ +#include "args.h" + +#include +#include + +void arg::print_help() const { + cout << "--" << full_name; + if (*abbreviation) { + cout << " (-" << abbreviation << ")"; + } + if (*description) { + cout << (required ? ":\n (REQUIRED) " : ":\n (OPTIONAL) ") << description << ".\n"; + } else { + cout << (required ? " is a required parameter.\n" : " is an optional parameter.\n"); + } +} + +void arg::parse(span & argv) { + required = false; + if (const auto bool_param{get_if(&value)}) { + *bool_param = true; + return; + } + if (argv.empty()) { + fprintf(stderr, "The option '--%s' requires an argument\n", full_name); + exit(1); + } + const str a = argv[0]; + argv = argv.subspan(1); + if (const auto string_param{get_if(&value)}) { + *string_param = a; + } else if (const auto int_param{get_if(&value)}) { + istringstream{a} >> *int_param; + } else if (const auto float_param{get_if(&value)}) { + istringstream{a} >> *float_param; + } +} + +void arg_list::parse(int argc, str argv_[]) { + TTS_ASSERT(argc); + span argv{argv_, static_cast(argc)}; + argv = argv.subspan(1); + while (!argv.empty()) { + str name{argv[0]}; + if (*name != '-') { + fprintf(stderr, "Only named arguments are supported\n"); + exit(1); + } + ++name; + const map * lookup = &abbreviations; + if (*name == '-') { + ++name; + lookup = &full_names; + if (name == "help"sv) { + for (const size_t i : full_names | views::values) { + args[i].print_help(); + } + exit(0); + } + } + const auto found = lookup->find(sv{name}); + if (found == lookup->end()) { + fprintf(stderr, "argument '%s' is not a valid argument. 
" + "Call '--help' for information on all valid arguments.\n", argv[0]); + exit(1); + } + argv = argv.subspan(1); + args[found->second].parse(argv); + } + for (const arg & x : args) { + if (x.required) { + fprintf(stderr, "argument '--%s' is required.\n", x.full_name); + exit(1); + } + } +} diff --git a/examples/args.h b/examples/args.h new file mode 100644 index 0000000..c4c102d --- /dev/null +++ b/examples/args.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include + +#include "imports.h" + +/** + * Holder of one argument. + */ +class arg { + variant value; + bool required; + + void print_help() const; + + void parse(span & argv); + + friend class arg_list; + +public: + const str full_name; + const str abbreviation; + const str description; + + template + constexpr arg(T default_value, str full_name, str abbreviation, str description, bool required = false) + : value{default_value}, required{required}, + full_name{full_name}, abbreviation{abbreviation}, description{description} { + TTS_ASSERT(full_name[0] != '-'); + TTS_ASSERT(abbreviation[0] != '-'); + } + + template + requires is_same_v || is_same_v || is_same_v || is_same_v + // ReSharper disable once CppNonExplicitConversionOperator // We want this to automatically cast + constexpr operator T() const { // NOLINT(*-explicit-constructor) + return get(value); + } +}; + +class arg_list { + vector args{}; + map full_names{}; + map abbreviations{}; + +public: + void add(const arg & x) { + const size_t i{args.size()}; + args.push_back(x); + TTS_ASSERT(!full_names.contains(args[i].full_name)); + full_names[args[i].full_name] = i; + if (*args[i].abbreviation) { + abbreviations[args[i].abbreviation] = i; + } + } + + void parse(int argc, str argv_[]); + + constexpr const arg & operator [](sv full_name) const noexcept { + TTS_ASSERT(full_name[0] != '-'); + return args[full_names.at(full_name)]; + } +}; diff --git a/examples/args_common.cpp b/examples/args_common.cpp new file mode 100644 index 0000000..15b9b7d 
--- /dev/null +++ b/examples/args_common.cpp @@ -0,0 +1,86 @@ +#include "args_common.h" + +#include "tts.h" + +void add_baseline_args(arg_list & args) { + // runner_from_file + args.add({"", "model-path", "mp", "The local path of the gguf model(s) to load", true}); + args.add({ + max(static_cast(thread::hardware_concurrency()), 1), "n-threads", "nt", + "The number of CPU threads to run calculations with. Defaults to known hardware concurrency. " + "If hardware concurrency cannot be determined then it defaults to 1" + }); +} + +static constexpr generation_configuration default_config{}; + +void add_common_args(arg_list & args) { + add_baseline_args(args); + // generation_configuration + args.add({!default_config.use_cross_attn, "no-cross-attn", "ca", "Whether to not include cross attention"}); + args.add({default_config.temperature, "temperature", "t", "The temperature to use when generating outputs"}); + args.add({ + default_config.repetition_penalty, "repetition-penalty", "r", + "The per-channel repetition penalty to be applied the sampled output of the model" + }); + args.add({ + default_config.top_p, "top-p", "mt", + "The sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0" + }); + args.add({ + default_config.top_k, "topk", "tk", + "When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleus size. " + "Defaults to 50" + }); + args.add({ + default_config.max_tokens, "max-tokens", "mt", + "The max audio tokens or token batches to generate where each represents approximates 11 ms of audio. " + "Only applied to Dia generation. If set to zero as is its default then the default max generation size. " + "Warning values under 15 are not supported" + }); + args.add({ + default_config.voice, "voice", "v", + "The voice to use to generate the audio. 
This is only used for models with voice packs" + }); + add_espeak_voice_arg(args); + // runner_from_file + args.add({false, "use-metal", "m", "Whether to use metal acceleration"}); +} + +generation_configuration parse_generation_config(const arg_list & args) { + const generation_configuration config{ + .use_cross_attn{!args["no-cross-attn"]}, + .temperature{args["temperature"]}, + .repetition_penalty{args["repetition-penalty"]}, + .top_p{args["top-p"]}, + .top_k{args["topk"]}, + .max_tokens{args["max-tokens"]}, + .voice{args["voice"]}, + .espeak_voice_id{args["espeak-voice-id"]} + }; + if (config.top_p > 1.0f || config.top_p <= 0.0f) { + fprintf(stderr, "The '--top-p' value must be between 0.0 and 1.0. It was set to '%.6f'.\n", config.top_p); + exit(1); + } + return config; +} + +tts_runner * runner_from_args(const arg_list & args, const generation_configuration & config) { + return runner_from_file(args["model-path"], args["n-threads"], config, !args["use-metal"]); +} + +void add_text_encoder_arg(arg_list & args) { + args.add({ + "", "text-encoder-path", "tep", + "The local path of the text encoder gguf model for conditional generation" + }); +} + +void add_espeak_voice_arg(arg_list & args) { + args.add({ + default_config.espeak_voice_id, "espeak-voice-id", "eid", + "The eSpeak voice id to use for phonemization. " + "This should only be specified when the correct eSpeak voice cannot be inferred from the Kokoro voice. 
" + "See MultiLanguage Configuration in the README for more info" + }); +} diff --git a/examples/args_common.h b/examples/args_common.h new file mode 100644 index 0000000..7183fed --- /dev/null +++ b/examples/args_common.h @@ -0,0 +1,13 @@ +#pragma once + +#include "args.h" +#include "common.h" + +void add_baseline_args(arg_list & args); +void add_common_args(arg_list & args); + +generation_configuration parse_generation_config(const arg_list & args); +tts_runner * runner_from_args(const arg_list & args, const generation_configuration & config); + +void add_text_encoder_arg(arg_list & args); +void add_espeak_voice_arg(arg_list & args); diff --git a/include/audio_file.h b/examples/audio_file.h similarity index 100% rename from include/audio_file.h rename to examples/audio_file.h diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt index 2aaefcf..55133b4 100644 --- a/examples/cli/CMakeLists.txt +++ b/examples/cli/CMakeLists.txt @@ -16,4 +16,4 @@ if (SDL2_FOUND) set_source_files_properties(playback.cpp PROPERTIES COMPILE_FLAGS -DSDL2_INSTALL=1) endif() -target_link_libraries(${TARGET} PRIVATE ggml tts) +target_link_libraries(${TARGET} PRIVATE examples_common) diff --git a/examples/cli/README.md b/examples/cli/README.md index 0fe687f..71dd5dc 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -10,57 +10,56 @@ This simple example cli tool can be used to generate speach from a text prompt a ### Usage In order to get a detailed breakdown the functionality currently available you can call the cli with the `--help` parameter. This will return a breakdown of all parameters: -```bash -./cli --help - ---temperature (-t): - The temperature to use when generating outputs. Defaults to 1.0. ---repetition-penalty (-r): - The by channel repetition penalty to be applied the sampled output of the model. defaults to 1.0. ---top-p (tp): - (OPTIONAL) the sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0. 
---n-threads (-nt): - The number of cpu threads to run generation with. Defaults to hardware concurrency. If hardware concurrency cannot be determined then it defaults to 1. ---topk (-tk): - (OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleaus size. Defaults to 50. +```console +$ ./tts-cli --help +--conditional-prompt (-cp): + (OPTIONAL) A distinct conditional prompt to use for generating. If none is provided the preencoded prompt is used. '--text-encoder-path' must be set to use conditional generation. +--espeak-voice-id (-eid): + (OPTIONAL) The eSpeak voice id to use for phonemization. This should only be specified when the correct eSpeak voice cannot be inferred from the Kokoro voice. See MultiLanguage Configuration in the README for more info. --max-tokens (-mt): (OPTIONAL) The max audio tokens or token batches to generate where each represents approximates 11 ms of audio. Only applied to Dia generation. If set to zero as is its default then the default max generation size. Warning values under 15 are not supported. ---use-metal (-m): - (OPTIONAL) Whether to use metal acceleration +--model-path (-mp): + (REQUIRED) The local path of the gguf model(s) to load. +--n-threads (-nt): + (OPTIONAL) The number of CPU threads to run calculations with. Defaults to known hardware concurrency. If hardware concurrency cannot be determined then it defaults to 1. --no-cross-attn (-ca): - (OPTIONAL) Whether to not include cross attention ---vad (-va): - (OPTIONAL) whether to apply voice inactivity detection (VAD) and strip silence form the end of the output (particularly useful for Parler TSS). By default, no VAD is applied. + (OPTIONAL) Whether to not include cross attention. --play: - (OPTIONAL) Whether to play back the audio immediately instead of saving it to file. ---model-path (-mp): - (REQUIRED) The local path of the gguf model file for Parler TTS mini or large v1, Dia, or Kokoro. 
+ (OPTIONAL) Whether to play back the audio immediately instead of saving it to file.. --prompt (-p): - (REQUIRED) The text prompt for which to generate audio in quotation markers. + (REQUIRED) The text prompt for which to generate audio. +--repetition-penalty (-r): + (OPTIONAL) The per-channel repetition penalty to be applied the sampled output of the model. --save-path (-sp): - (OPTIONAL) The path to save the audio output to in a .wav format. Defaults to TTS.cpp.wav ---conditional-prompt (-cp): - (OPTIONAL) A distinct conditional prompt to use for generating. If none is provided the preencoded prompt is used. '--text-encoder-path' must be set to use conditional generation. + (OPTIONAL) The path to save the audio output to in a .wav format. +--temperature (-t): + (OPTIONAL) The temperature to use when generating outputs. --text-encoder-path (-tep): - (OPTIONAL) The local path of the text encoder gguf model for conditional generaiton. + (OPTIONAL) The local path of the text encoder gguf model for conditional generation. +--top-p (-mt): + (OPTIONAL) The sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0. +--topk (-tk): + (OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleus size. Defaults to 50. +--use-metal (-m): + (OPTIONAL) Whether to use metal acceleration. +--vad (-va): + (OPTIONAL) Whether to apply voice inactivity detection (VAD) and strip silence form the end of the output. This is particularly useful for Parler TTS. By default, no VAD is applied. --voice (-v): (OPTIONAL) The voice to use to generate the audio. This is only used for models with voice packs. ---espeak-voice-id (-eid): - (OPTIONAL) The espeak voice id to use for phonemization. This should only be specified when the correct espeak voice cannot be inferred from the kokoro voice ( see MultiLanguage Configuration in the README for more info). ``` General usage should follow from these possible parameters. 
E.G. The following command will save generated speech to the `/tmp/test.wav` file. ```bash -./cli --model-path /model/path/to/gguf_file.gguf --prompt "I am saying some words" --save-path /tmp/test.wav +./tts-cli --model-path /model/path/to/gguf_file.gguf --prompt "I am saying some words" --save-path /tmp/test.wav ``` #### Dia Generation Arguments Currently the default cli arguments are not aligned with Dia's default sampling settings. Specifically the temperature and topk settings should be changed to `1.3` and `35` respectively when generating with Dia like so: -```base -./cli --model-path /model/path/to/Dia.gguf --prompt "[S1] Hi, I am Dia, this is how I talk." --save-path /tmp/test.wav --topk 35 --temperature 1.3 +```bash +./tts-cli --model-path /model/path/to/Dia.gguf --prompt "[S1] Hi, I am Dia, this is how I talk." --save-path /tmp/test.wav --topk 35 --temperature 1.3 ``` #### Conditional Generation @@ -87,7 +86,7 @@ Kokoro supports multiple langauges with distinct voices, and, by default, the st Each voice has a language assigned and gender assigned to it where the first letter of the pack represents the language and the second the gender (e.g. `af_alloy` is an American English Female voice; `a` corresponds to American Enlgish and `f` to Female). 
Below is a list of all currently supported langauges mapped to their respective codes: -``` +```text # πŸ‡ΊπŸ‡Έ 'a' => American English, πŸ‡¬πŸ‡§ 'b' => British English # πŸ‡ͺπŸ‡Έ 'e' => Spanish es # πŸ‡«πŸ‡· 'f' => French fr-fr @@ -103,4 +102,3 @@ By default when a voice of a specific language is used, phonemization for that l ```bash espeak-ng --voices ``` - diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 103f216..6b85603 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -1,10 +1,9 @@ -#include "tts.h" -#include "args.h" -#include "common.h" +#include +#include "args_common.h" #include "playback.h" +#include "tts.h" #include "vad.h" #include "write_file.h" -#include class tts_timing_printer { const int64_t start_us{[] { @@ -21,76 +20,51 @@ class tts_timing_printer { int main(int argc, const char ** argv) { const tts_timing_printer _{}; - float default_temperature = 1.0f; - int default_n_threads = std::max((int)std::thread::hardware_concurrency(), 1); - int default_top_k = 50; - int default_max_tokens = 0; - float default_repetition_penalty = 1.0f; - float default_top_p = 1.0f; - arg_list args; - args.add_argument(string_arg("--model-path", "(REQUIRED) The local path of the gguf model file for Parler TTS mini or large v1, Dia, or Kokoro.", "-mp", true)); - args.add_argument(string_arg("--prompt", "(REQUIRED) The text prompt for which to generate audio in quotation markers.", "-p", true)); - args.add_argument(string_arg("--save-path", "(OPTIONAL) The path to save the audio output to in a .wav format. Defaults to TTS.cpp.wav", "-sp", false, "TTS.cpp.wav")); - args.add_argument(float_arg("--temperature", "The temperature to use when generating outputs. Defaults to 1.0.", "-t", false, &default_temperature)); - args.add_argument(int_arg("--n-threads", "The number of cpu threads to run generation with. Defaults to hardware concurrency. 
If hardware concurrency cannot be determined then it defaults to 1.", "-nt", false, &default_n_threads)); - args.add_argument(int_arg("--topk", "(OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleaus size. Defaults to 50.", "-tk", false, &default_top_k)); - args.add_argument(float_arg("--repetition-penalty", "The by channel repetition penalty to be applied the sampled output of the model. defaults to 1.0.", "-r", false, &default_repetition_penalty)); - args.add_argument(bool_arg("--use-metal", "(OPTIONAL) Whether to use metal acceleration", "-m")); - args.add_argument(bool_arg("--no-cross-attn", "(OPTIONAL) Whether to not include cross attention", "-ca")); - args.add_argument(string_arg("--conditional-prompt", "(OPTIONAL) A distinct conditional prompt to use for generating. If none is provided the preencoded prompt is used. '--text-encoder-path' must be set to use conditional generation.", "-cp", false)); - args.add_argument(string_arg("--text-encoder-path", "(OPTIONAL) The local path of the text encoder gguf model for conditional generaiton.", "-tep", false)); - args.add_argument(string_arg("--voice", "(OPTIONAL) The voice to use to generate the audio. This is only used for models with voice packs.", "-v", false, "af_alloy")); - args.add_argument(bool_arg("--vad", "(OPTIONAL) whether to apply voice inactivity detection (VAD) and strip silence form the end of the output (particularly useful for Parler TTS). By default, no VAD is applied.", "-va")); - args.add_argument(string_arg("--espeak-voice-id", "(OPTIONAL) The espeak voice id to use for phonemization. This should only be specified when the correct espeak voice cannot be inferred from the kokoro voice ( see MultiLanguage Configuration in the README for more info).", "-eid", false)); - args.add_argument(int_arg("--max-tokens", "(OPTIONAL) The max audio tokens or token batches to generate where each represents approximates 11 ms of audio. 
Only applied to Dia generation. If set to zero as is its default then the default max generation size. Warning values under 15 are not supported.", "-mt", false, &default_max_tokens)); - args.add_argument(float_arg("--top-p", "(OPTIONAL) the sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0.", "-tp", false, &default_top_p)); + arg_list args{}; + add_common_args(args); + args.add({"", "prompt", "p", "The text prompt for which to generate audio", true}); + args.add({"TTS.cpp.wav", "save-path", "sp", "The path to save the audio output to in a .wav format"}); + args.add({ + "", "conditional-prompt", "cp", + "A distinct conditional prompt to use for generating. " + "If none is provided the preencoded prompt is used. " + "'--text-encoder-path' must be set to use conditional generation" + }); + add_text_encoder_arg(args); + args.add({ + false, "vad", "va", + "Whether to apply voice inactivity detection (VAD) and strip silence form the end of the output. " + "This is particularly useful for Parler TTS. By default, no VAD is applied" + }); register_play_tts_response_args(args); args.parse(argc, argv); - if (args.for_help) { - args.help(); - exit(0); - } - args.validate(); - std::string conditional_prompt = args.get_string_param("--conditional-prompt"); - std::string text_encoder_path = args.get_string_param("--text-encoder-path"); - if (conditional_prompt.size() > 0 && text_encoder_path.size() <= 0) { + const str conditional_prompt{args["conditional-prompt"]}; + const str text_encoder_path{args["text-encoder-path"]}; + if (*conditional_prompt && !*text_encoder_path) { fprintf(stderr, "The '--text-encoder-path' must be specified when '--condtional-prompt' is passed.\n"); exit(1); } - if (*args.get_float_param("--top-p") > 1.0f || *args.get_float_param("--top-p") <= 0.0f) { - fprintf(stderr, "The '--top-p' value must be between 0.0 and 1.0. 
It was set to '%.6f'.\n", *args.get_float_param("--top-p")); - exit(1); - } - - generation_configuration * config = new generation_configuration( - args.get_string_param("--voice"), - *args.get_int_param("--topk"), - *args.get_float_param("--temperature"), - *args.get_float_param("--repetition-penalty"), - !args.get_bool_param("--no-cross-attn"), - args.get_string_param("--espeak-voice-id"), - *args.get_int_param("--max-tokens"), - *args.get_float_param("--top-p")); - - struct tts_runner * runner = runner_from_file(args.get_string_param("--model-path"), *args.get_int_param("--n-threads"), config, !args.get_bool_param("--use-metal")); + const generation_configuration config{parse_generation_config(args)}; + tts_runner * const runner{runner_from_args(args, config)}; - if (conditional_prompt.size() > 0) { + if (*conditional_prompt) { update_conditional_prompt(runner, text_encoder_path, conditional_prompt, true); } tts_response data; - generate(runner, args.get_string_param("--prompt"), &data, config); + const str prompt{args["prompt"]}; + generate(runner, prompt, data, config); if (data.n_outputs == 0) { - fprintf(stderr, "Got empty response for prompt, '%s'.\n", args.get_string_param("--prompt").c_str()); + fprintf(stderr, "Got empty response for prompt, '%s'.\n", prompt); exit(1); } - if (args.get_bool_param("--vad")) { + if (args["vad"]) { apply_energy_voice_inactivity_detection(data, runner->sampling_rate); } if (!play_tts_response(args, data, runner->sampling_rate)) { - write_audio_file(data, args.get_string_param("--save-path"), runner->sampling_rate); + write_audio_file(data, args["save-path"], runner->sampling_rate); } return 0; } diff --git a/examples/cli/playback.cpp b/examples/cli/playback.cpp index 1659c85..06a47de 100644 --- a/examples/cli/playback.cpp +++ b/examples/cli/playback.cpp @@ -1,4 +1,7 @@ #include +#ifdef SDL2_INSTALL +#include "SDL.h" +#endif #include "playback.h" #ifndef SDL2_INSTALL @@ -10,13 +13,12 @@ bool play_tts_response(arg_list & args, 
const tts_response & data, float sample_ return false; } #else -#include "SDL.h" void register_play_tts_response_args(arg_list & args) { - args.add_argument(bool_arg("--play", "(OPTIONAL) Whether to play back the audio immediately instead of saving it to file.")); + args.add({false, "play", "", "Whether to play back the audio immediately instead of saving it to file."}); } bool play_tts_response(arg_list & args, const tts_response & data, float sample_rate) { - if (!args.get_bool_param("--play")) { + if (!args["play"]) { return false; } diff --git a/examples/cli/write_file.cpp b/examples/cli/write_file.cpp index 9393572..0c8c285 100644 --- a/examples/cli/write_file.cpp +++ b/examples/cli/write_file.cpp @@ -1,9 +1,9 @@ #include -#include "write_file.h" #include "audio_file.h" +#include "write_file.h" -void write_audio_file(const tts_response & data, std::string path, float sample_rate) { - fprintf(stdout, "Writing audio file: %s\n", path.c_str()); +void write_audio_file(const tts_response & data, str path, float sample_rate) { + fprintf(stdout, "Writing audio file: %s\n", path); AudioFile file; file.setSampleRate(sample_rate); file.samples[0] = std::vector(data.data, data.data + data.n_outputs); diff --git a/examples/cli/write_file.h b/examples/cli/write_file.h index 017bca9..935beeb 100644 --- a/examples/cli/write_file.h +++ b/examples/cli/write_file.h @@ -2,4 +2,4 @@ #include "common.h" -void write_audio_file(const tts_response & data, std::string path = "TTS.cpp.wav", float sample_rate = 44100.0f); +void write_audio_file(const tts_response & data, str path = "TTS.cpp.wav", float sample_rate = 44100.0f); diff --git a/examples/perf_battery/CMakeLists.txt b/examples/perf_battery/CMakeLists.txt index bb630eb..faec950 100644 --- a/examples/perf_battery/CMakeLists.txt +++ b/examples/perf_battery/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(perf_battery perf_battery.cpp) -target_link_libraries(perf_battery PRIVATE ggml tts) +target_link_libraries(perf_battery PRIVATE 
examples_common) diff --git a/examples/perf_battery/README.md b/examples/perf_battery/README.md index 65e881c..0ea3004 100644 --- a/examples/perf_battery/README.md +++ b/examples/perf_battery/README.md @@ -10,26 +10,39 @@ This script runs a series of benchmarks to test the generative throughput of the ### Usage In order to get a detailed breakdown the functionality currently available you can call the cli with the `--help` parameter. This will return a breakdown of all parameters: -```commandline -./perf_battery --help - +```console +$ ./perf_battery --help +--espeak-voice-id (-eid): + (OPTIONAL) The eSpeak voice id to use for phonemization. This should only be specified when the correct eSpeak voice cannot be inferred from the Kokoro voice. See MultiLanguage Configuration in the README for more info. +--max-tokens (-mt): + (OPTIONAL) The max audio tokens or token batches to generate where each represents approximates 11 ms of audio. Only applied to Dia generation. If set to zero as is its default then the default max generation size. Warning values under 15 are not supported. +--model-path (-mp): + (REQUIRED) The local path of the gguf model(s) to load. --n-threads (-nt): - The number of cpu threads to run generation with. Defaults to 10. ---use-metal (-m): - (OPTIONAL) whether or not to use metal acceleration. + (OPTIONAL) The number of CPU threads to run calculations with. Defaults to known hardware concurrency. If hardware concurrency cannot be determined then it defaults to 1. --no-cross-attn (-ca): - (OPTIONAL) Whether to not include cross attention ---model-path (-mp): - (REQUIRED) The local path of the gguf model file for Parler TTS mini v1. + (OPTIONAL) Whether to not include cross attention. +--repetition-penalty (-r): + (OPTIONAL) The per-channel repetition penalty to be applied the sampled output of the model. +--temperature (-t): + (OPTIONAL) The temperature to use when generating outputs. 
+--top-p (-mt): + (OPTIONAL) The sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0. +--topk (-tk): + (OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleus size. Defaults to 50. +--use-metal (-m): + (OPTIONAL) Whether to use metal acceleration. +--voice (-v): + (OPTIONAL) The voice to use to generate the audio. This is only used for models with voice packs. ``` General usage should follow from these possible parameters. E.G. The following command will save generated speech to the `/tmp/test.wav` file. -```commandline +```bash ./perf_battery --model-path /model/path/to/gguf_file.gguf --use-metal ``` the output will look like the following: -``` +```text Mean Stats for arch Parler-TTS: Generation Time (ms): 12439.43255 @@ -43,7 +56,7 @@ Mean Stats for arch Parler-TTS: The currently measured performance breakdown for Parler Mini v1.0 with Q5_0 quantization without cross attention (i.e. the fastest stable generation with the Parler model) and 32bit floating point weights in the audio decoder: -``` +```text Mean Stats: Generation Time (ms): 8599.550347 diff --git a/examples/perf_battery/perf_battery.cpp b/examples/perf_battery/perf_battery.cpp index 36d0cbc..e9cd6cd 100644 --- a/examples/perf_battery/perf_battery.cpp +++ b/examples/perf_battery/perf_battery.cpp @@ -1,22 +1,13 @@ -#include "tts.h" -#include "args.h" -#include "common.h" -#include #include -#include -#include - +#include -std::vector ARCH_LOOKUP = { - "parler-tts", - "kokoro", -}; - -using perf_cb = std::function; +#include "args_common.h" +#include "tts.h" -double benchmark_ms(perf_cb func) { +namespace { +double benchmark_ms(auto lambda) { auto start = std::chrono::steady_clock::now(); - func(); + lambda(); auto end = std::chrono::steady_clock::now(); std::chrono::duration duration = end - start; return duration.count(); @@ -26,7 +17,7 @@ double benchmark_ms(perf_cb func) { * These are the 'Harvard Sentences' 
(https://en.wikipedia.org/wiki/Harvard_sentences). They are phonetically * balanced sentences typically used for standardized testing of voice over cellular and telephone systems. */ -std::vector TEST_SENTENCES = { +constexpr array TEST_SENTENCES = { "The birch canoe slid on the smooth planks.", "Glue the sheet to the dark blue background.", "It's easy to tell the depth of a well.", @@ -67,57 +58,40 @@ double mean(std::vector series) { return (double) sum / series.size(); } -std::string benchmark_printout(tts_arch arch, std::vector generation_samples, std::vector output_times) { - std::string arch_name = ARCH_LOOKUP[(int)arch]; - double gen_mean = mean(generation_samples); +void benchmark_printout(tts_arch arch, const vector & generation_samples, const vector & output_times) { + const str arch_name = SUPPORTED_ARCHITECTURES[arch]; + const double gen_mean = mean(generation_samples); std::vector gen_output; - for (int i = 0; i < (int) output_times.size(); i++) { + for (size_t i = 0; i < output_times.size(); i++) { gen_output.push_back(generation_samples[i]/output_times[i]); } double gen_out_mean = mean(gen_output); - std::string printout = (std::string) "Mean Stats for arch " + arch_name + ":\n\n" + (std::string) " Generation Time (ms): " + std::to_string(gen_mean) + (std::string) "\n"; - printout += (std::string) " Generation Real Time Factor (ms): " + std::to_string(gen_out_mean) + (std::string) "\n"; - return printout; + cout << "Mean Stats for arch " << arch_name << ":\n\n Generation Time (ms): "; + cout << gen_mean << endl; + cout << " Generation Real Time Factor (ms): " << gen_out_mean << endl; +} } - int main(int argc, const char ** argv) { - float default_temperature = 1.0f; - int default_n_threads = std::max((int)std::thread::hardware_concurrency(), 1); - int default_top_k = 50; - float default_repetition_penalty = 1.0f; - arg_list args; - args.add_argument(string_arg("--model-path", "(REQUIRED) The local path of the gguf model file for Parler TTS mini v1.", 
"-mp", true)); - args.add_argument(int_arg("--n-threads", "The number of cpu threads to run generation with. Defaults to hardware concurrency. If hardware concurrency cannot be determined it defaults to 1.", "-nt", false, &default_n_threads)); - args.add_argument(float_arg("--temperature", "The temperature to use when generating outputs. Defaults to 1.0.", "-t", false, &default_temperature)); - args.add_argument(int_arg("--topk", "(OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleaus size. Defaults to 50.", "-tk", false, &default_top_k)); - args.add_argument(string_arg("--voice", "(OPTIONAL) The voice to use to generate the audio. This is only used for models with voice packs.", "-v", false, "af_alloy")); - args.add_argument(float_arg("--repetition-penalty", "The by channel repetition penalty to be applied the sampled output of the model. defaults to 1.0.", "-r", false, &default_repetition_penalty)); - args.add_argument(bool_arg("--use-metal", "(OPTIONAL) whether or not to use metal acceleration.", "-m")); - args.add_argument(bool_arg("--no-cross-attn", "(OPTIONAL) Whether to not include cross attention", "-ca")); + arg_list args{}; + add_common_args(args); args.parse(argc, argv); - if (args.for_help) { - args.help(); - return 0; - } - args.validate(); - - generation_configuration * config = new generation_configuration(args.get_string_param("--voice"), *args.get_int_param("--topk"), *args.get_float_param("--temperature"), *args.get_float_param("--repetition-penalty"), !args.get_bool_param("--no-cross-attn")); - struct tts_runner * runner = runner_from_file(args.get_string_param("--model-path"), *args.get_int_param("--n-threads"), config, !args.get_bool_param("--use-metal")); + const generation_configuration config{parse_generation_config(args)}; + tts_runner * const runner{runner_from_args(args, config)}; std::vector generation_samples; std::vector output_times; - for (std::string sentence : TEST_SENTENCES) { + 
for (const str sentence : TEST_SENTENCES) { tts_response response; - perf_cb cb = [&]{ - generate(runner, sentence, &response, config); + const auto cb = [&]{ + generate(runner, sentence, response, config); }; double generation_ms = benchmark_ms(cb); - output_times.push_back((double)(response.n_outputs / 44.1)); + output_times.push_back(response.n_outputs / 44.1); generation_samples.push_back(generation_ms); } - fprintf(stdout, "%s", benchmark_printout(runner->arch, generation_samples, output_times).c_str()); + benchmark_printout(runner->arch, generation_samples, output_times); return 0; } diff --git a/examples/phonemize/CMakeLists.txt b/examples/phonemize/CMakeLists.txt index 9fb6c89..3aef1ec 100644 --- a/examples/phonemize/CMakeLists.txt +++ b/examples/phonemize/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(phonemize phonemize.cpp) -target_link_libraries(phonemize PRIVATE ggml tts) +target_link_libraries(phonemize PRIVATE examples_common) diff --git a/examples/phonemize/README.md b/examples/phonemize/README.md index 40c3fcb..26a485d 100644 --- a/examples/phonemize/README.md +++ b/examples/phonemize/README.md @@ -10,22 +10,19 @@ This is a simple cli for running TTS.cpp phonemization on a pass text string. Fo ### Usage In order to get a detailed breakdown the functionality currently available you can call the cli with the `--help` parameter. This will return a breakdown of all parameters: -```commandline -./build/bin/phonemize --help - ---use-espeak (-ue): - (OPTIONAL) Whether to use espeak to generate phonems. +```console +$ ./phonemize --help +--espeak-voice-id (-eid): + (OPTIONAL) The eSpeak voice id to use for phonemization. This should only be specified when the correct eSpeak voice cannot be inferred from the Kokoro voice. See MultiLanguage Configuration in the README for more info. --phonemizer-path (-mp): - (OPTIONAL) The local path of the gguf phonemiser file for TTS.cpp phonemizer. This is required if not using espeak. 
+ (OPTIONAL) The local path of the gguf phonemiser file for TTS.cpp phonemizer. Omit this to use eSpeak to generate phonemes. --prompt (-p): (REQUIRED) The text prompt to phonemize. ---espeak-voice-id (-eid): - (OPTIONAL) The voice id to use for espeak phonemization. Defaults to 'gmw/en-US'. ``` General usage should follow from these possible parameters. E.G. The following command will return the phonemized IPA text for the prompt via the TTS.cpp phonemizer. -```commandline +```bash ./build/bin/phonemize --phonemizer-path "/path/to/tts_phonemizer.gguf" --prompt "this is a test." ``` @@ -33,6 +30,6 @@ General usage should follow from these possible parameters. E.G. The following c To use espeak phonemization you must first install the TTS with espeak linked. Phonemization can then be accomplished via the following: -```commandlinecommandline +```bash ./build/bin/phonemize --prompt "this is a test." --use-espeak ``` diff --git a/examples/phonemize/phonemize.cpp b/examples/phonemize/phonemize.cpp index 83d551d..fc8b0b4 100644 --- a/examples/phonemize/phonemize.cpp +++ b/examples/phonemize/phonemize.cpp @@ -1,31 +1,27 @@ +#include + +#include "args_common.h" #include "phonemizer.h" -#include "args.h" -#include int main(int argc, const char ** argv) { - arg_list args; - args.add_argument(string_arg("--phonemizer-path", "(OPTIONAL) The local path of the gguf phonemiser file for TTS.cpp phonemizer. This is required if not using espeak.", "-mp")); - args.add_argument(string_arg("--prompt", "(REQUIRED) The text prompt to phonemize.", "-p", true)); - args.add_argument(bool_arg("--use-espeak", "(OPTIONAL) Whether to use espeak to generate phonems.", "-ue")); - args.add_argument(string_arg("--espeak-voice-id", "(OPTIONAL) The voice id to use for espeak phonemization. 
Defaults to 'gmw/en-US'.", "-eid", false, "gmw/en-US")); + arg_list args{}; + args.add({"", "prompt", "p", "The text prompt to phonemize", true}); + args.add({ + "", "phonemizer-path", "mp", + "The local path of the gguf phonemiser file for TTS.cpp phonemizer. " + "Omit this to use eSpeak to generate phonemes" + }); + add_espeak_voice_arg(args); args.parse(argc, argv); - if (args.for_help) { - args.help(); - return 0; - } - args.validate(); - if (!args.get_bool_param("--use-espeak") && args.get_string_param("--phonemizer-path") == "") { - fprintf(stderr, "The '--phonemizer-path' must be specified when '--use-espeak' is not true.\n"); - exit(1); - } + const str phonemizer_path{args["phonemizer-path"]}; phonemizer * ph; - if (args.get_bool_param("--use-espeak")) { - ph = espeak_phonemizer(false, args.get_string_param("--espeak-voice-id")); + if (*phonemizer_path) { + ph = phonemizer_from_file(phonemizer_path); } else { - ph = phonemizer_from_file(args.get_string_param("--phonemizer-path")); + ph = espeak_phonemizer(false, args["espeak-voice-id"]); } - std::string response = ph->text_to_phonemes(args.get_string_param("--prompt")); + const string response{ph->text_to_phonemes(string{args["prompt"]})}; fprintf(stdout, "%s\n", response.c_str()); return 0; } diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt index fda21c8..9a50f50 100644 --- a/examples/quantize/CMakeLists.txt +++ b/examples/quantize/CMakeLists.txt @@ -1,2 +1,6 @@ -add_executable(quantize quantize.cpp) -target_link_libraries(quantize PRIVATE ggml tts) +add_executable(quantize + quantize.cpp + quantize_impl.cpp + quantize_impl.h +) +target_link_libraries(quantize PRIVATE examples_common) diff --git a/examples/quantize/README.md b/examples/quantize/README.md index ad04e8c..0a86172 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -11,30 +11,27 @@ This script converts a 32bit floating point TTS.cpp GGUF model file to a quantiz ### Usage -**Please 
note** Quantization and lower precision conversion is currently only supported for Parler TTS models. - In order to get a detailed breakdown of the functionality currently available you can call the cli with the `--help` parameter. This will return a breakdown of all parameters: -```bash -./quantize --help - ---quantized-type (-qt): - (OPTIONAL) The ggml enum of the quantized type to convert compatible model tensors to. For more information see readme. Defaults to Q4_0 quantization (2). ---n-threads (-nt): - (OPTIONAL) The number of cpu threads to run the quantization process with. Defaults to known hardware concurrency. +```console +$ ./quantize --help --convert-dac-to-f16 (-df): (OPTIONAL) Whether to convert the DAC audio decoder model to a 16 bit float. ---quantize-output-heads (-qh): - (OPTIONAL) Whether to quantize the output heads. Defaults to false and is true when passed (does not accept a parameter). ---quantize-text-embedding (-qe): - (OPTIONAL) Whether to quantize the input text embededings (only applicable for Parler TTS). Defaults to false and is true when passed (does not accept a parameter). ---quantize-cross-attn-kv (-qkv): - (OPTIONAL) Whether to quantize the cross attention keys and values (only applicable for Parler TTS). Defaults to false and is true when passed (does not accept a parameter). --convert-non-quantized-to-f16 (-nqf): - (OPTIONAL) Whether or not to convert quantization incompatible tensors to 16 bit precision. Only currently applicable to Kokoror. defaults to false. + (OPTIONAL) Whether or not to convert quantization incompatible tensors to 16 bit precision. Only currently applicable to Kokoro. --model-path (-mp): - (REQUIRED) The local path of the gguf model file for Parler TTS mini v1 to quantize. + (REQUIRED) The local path of the gguf model(s) to load. +--n-threads (-nt): + (OPTIONAL) The number of CPU threads to run calculations with. Defaults to known hardware concurrency. 
If hardware concurrency cannot be determined then it defaults to 1. +--quantize-cross-attn-kv (-qkv): + (OPTIONAL) Whether to quantize the cross attention keys and values (only applicable for Parler TTS). +--quantize-output-heads (-qh): + (OPTIONAL) Whether to quantize the output heads. +--quantize-text-embedding (-qe): + (OPTIONAL) Whether to quantize the input text embeddings. --quantized-model-path (-qp): (REQUIRED) The path to save the model in a quantized format. +--quantized-type (-qt): + (OPTIONAL) The ggml enum of the quantized type to convert compatible model tensors to. For more information see readme. Defaults to Q4_0 quantization (2). ``` General usage should follow from these possible parameters. E.G. The following command will save a quantized version of the model using Q4_0 quantization to `/model/path/to/new/gguf_file_q.gguf`: @@ -98,7 +95,7 @@ The following approaches were experimented with: A clear improvement in tokens per second via the generative model is observed with quantization. Seen below Parler TTS mini with Q5_0 quantization, the model is capable of completing its generation in real time (it generates tokens faster than it takes to listen to them), and the model's TPS has improved from ~693 to ~986.
-``` +```text Mean Stats: Generation Time (ms): 1945.434146 diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index ecf09e8..11efb75 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,47 +1,52 @@ -#include "tts.h" -#include "args.h" -#include +#include #include + +#include "args_common.h" #include "ggml.h" -#include +#include "tts.h" +#include "quantize_impl.h" -std::vector valid_quantization_types = { +namespace { +constexpr array VALID_QUANTIZATION_TYPES{ GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0, }; +} int main(int argc, const char ** argv) { - int default_quantization = (int) GGML_TYPE_Q4_0; - int default_n_threads = std::max((int)std::thread::hardware_concurrency(), 1); - arg_list args; - args.add_argument(string_arg("--model-path", "(REQUIRED) The local path of the gguf model file for Parler TTS mini v1 to quantize.", "-mp", true)); - args.add_argument(string_arg("--quantized-model-path", "(REQUIRED) The path to save the model in a quantized format.", "-qp", true)); - args.add_argument(int_arg("--quantized-type", "(OPTIONAL) The ggml enum of the quantized type to convert compatible model tensors to. For more information see readme. Defaults to Q4_0 quantization (2).", "-qt", false, &default_quantization)); - args.add_argument(int_arg("--n-threads", "(OPTIONAL) The number of cpu threads to run the quantization process with. Defaults to known hardware concurrency.", "-nt", false, &default_n_threads)); - args.add_argument(bool_arg("--convert-dac-to-f16", "(OPTIONAL) Whether to convert the DAC audio decoder model to a 16 bit float.", "-df")); - args.add_argument(bool_arg("--quantize-output-heads", "(OPTIONAL) Whether to quantize the output heads. Defaults to false and is true when passed (does not accept a parameter).", "-qh")); - args.add_argument(bool_arg("--quantize-text-embedding", "(OPTIONAL) Whether to quantize the input text embededings (only applicable for Parler TTS). 
Defaults to false and is true when passed (does not accept a parameter).", "-qe")); - args.add_argument(bool_arg("--quantize-cross-attn-kv", "(OPTIONAL) Whether to quantize the cross attention keys and values (only applicable for Parler TTS). Defaults to false and is true when passed (does not accept a parameter).", "-qkv")); - args.add_argument(bool_arg("--convert-non-quantized-to-f16", "(OPTIONAL) Whether or not to convert quantization incompatible tensors to 16 bit precision. Only currently applicable to Kokoro. defaults to false.", "-nqf")); + arg_list args{}; + add_baseline_args(args); + args.add({"", "quantized-model-path", "qp", "The path to save the model in a quantized format", true}); + args.add({ + GGML_TYPE_Q4_0, "quantized-type", "qt", + "The ggml enum of the quantized type to convert compatible model tensors to. For more information see readme. " + "Defaults to Q4_0 quantization (2)" + }); + args.add({false, "convert-dac-to-f16", "df", "Whether to convert the DAC audio decoder model to a 16 bit float"}); + args.add({false, "quantize-output-heads", "qh", "Whether to quantize the output heads"}); + args.add({false, "quantize-text-embedding", "qe", "Whether to quantize the input text embeddings"}); + args.add({ + false, "quantize-cross-attn-kv", "qkv", + "Whether to quantize the cross attention keys and values (only applicable for Parler TTS)" + }); + args.add({ + false, "convert-non-quantized-to-f16", "nqf", + "Whether or not to convert quantization incompatible tensors to 16 bit precision.
" + "Only currently applicable to Kokoro" + }); args.parse(argc, argv); - if (args.for_help) { - args.help(); - return 0; - } - args.validate(); - enum ggml_type qtype = static_cast(*args.get_int_param("--quantized-type")); - if (std::find(valid_quantization_types.begin(), valid_quantization_types.end(), qtype) == valid_quantization_types.end()) { - fprintf(stderr, "ERROR: %d is not a valid quantization type.\n", qtype); - exit(1); - } - struct quantization_params * qp = new quantization_params((uint32_t) *args.get_int_param("--n-threads"), qtype); - qp->quantize_output_heads = args.get_bool_param("--quantize-output-heads"); - qp->quantize_text_embeddings = args.get_bool_param("--quantize-text-embedding"); - qp->quantize_cross_attn_kv = args.get_bool_param("--quantize-cross-attn-kv"); - qp->convert_dac_to_f16 = args.get_bool_param("--convert-dac-to-f16"); - qp->convert_non_quantizable_to_f16 = args.get_bool_param("--convert-non-quantized-to-f16"); - quantize_gguf(args.get_string_param("--model-path"), args.get_string_param("--quantized-model-path"), qp); + const quantization_params qp{ + .n_threads{static_cast(static_cast(args["n-threads"]))}, + .quantize_type{static_cast(static_cast(args["--quantized-type"]))}, + .quantize_output_heads{args["quantize-output-heads"]}, + .quantize_text_embeddings{args["quantize-text-embedding"]}, + .quantize_cross_attn_kv{args["quantize-cross-attn-kv"]}, + .convert_dac_to_f16{args["convert-dac-to-f16"]}, + .convert_non_quantizable_to_f16{args["convert-non-quantized-to-f16"]} + }; + TTS_ASSERT(ranges::contains(VALID_QUANTIZATION_TYPES, qp.quantize_type)); + quantize_gguf(args["model-path"], args["--quantized-model-path"], qp); return 0; } diff --git a/examples/quantize/quantize_impl.cpp b/examples/quantize/quantize_impl.cpp new file mode 100644 index 0000000..7b9484c --- /dev/null +++ b/examples/quantize/quantize_impl.cpp @@ -0,0 +1,275 @@ +#include +#include "quantize_impl.h" + +#include +#include + +#include "util.h" + +namespace { 
+bool kokoro_is_f16_compatible(std::string name) { + return name.find("voice_tensors") == std::string::npos && + name.find("bias") == std::string::npos && + name.find("gamma") == std::string::npos && + name.find("beta") == std::string::npos && + name.find("alpha") == std::string::npos && + !has_suffix(name, "embd") && + !has_suffix(name, "norm"); +} + +bool kokoro_is_quantizable(str name) { + // A list of all of the top level GGUF names under kokoro.duration_predictor that have quantization compatible tensors. + constexpr std::array DURATION_PREDICTOR_QUANTIZATION_COMPATIBLE_PARTS = { + "duration_proj", + "encode", + "shared_lstm", + "duration_lstm", + "layers" + }; + if (kokoro_is_f16_compatible(name)) { + if (has_prefix(name, "kokoro.albert") || has_prefix(name, "kokoro.text_encoder.lstm")) { + return true; + } else if (has_prefix(name, "kokoro.duration_predictor.")) { + std::vector parts = split(name, "."); + for (std::string part : DURATION_PREDICTOR_QUANTIZATION_COMPATIBLE_PARTS) { + if (part == parts[2]) { + return true; + } + } + } + } + return false; +} + +bool dia_is_quantizable(str name, const quantization_params & params) { + // The DAC audio encoder / decoder is not compatible with quantization and normalization tensors should not be quantized. + bool quantizable = !has_prefix(name, "audio_encoder") && !has_suffix(name, "norm"); + if (!params.quantize_output_heads) { + quantizable = quantizable && !has_prefix(name, "dia.decoder.heads"); + } + return quantizable; +} + +bool parler_is_quanitizable(str name, const quantization_params & params) { + // the DAC audio encoder / decoder is not compatible with quantization, normalization weight shouldn't be quantized, and the text encoding shouldn't be normalized. 
+ bool quantizable = !has_prefix(name, "audio_encoder") && !has_suffix(name, "norm.weight") && !has_suffix(name, "text_encoding") && !has_suffix(name, "positional_embed") && !has_suffix(name, "norm.bias"); + if (!params.quantize_output_heads) { + quantizable = quantizable && !has_suffix(name, "weight.head"); + } + if (!params.quantize_text_embeddings) { + quantizable = quantizable && !has_suffix(name, "embed_prompts"); + } + if (!params.quantize_cross_attn_kv) { + quantizable = quantizable && !has_suffix(name, "encoder_attn.k_proj.weight") && !has_suffix(name, "encoder_attn.v_proj.weight"); + } + return quantizable; +} + +bool is_quantizable(tts_arch arch, str name, const quantization_params & params) { + switch(arch) { + case PARLER_TTS_ARCH: + return parler_is_quanitizable(name, params); + case DIA_ARCH: + return dia_is_quantizable(name, params); + case KOKORO_ARCH: + return kokoro_is_quantizable(name); + default: + TTS_ABORT("%s failed. The architecture '%d' is not supported.", __func__, arch); + } +} + +size_t quantize_tensor(void * new_data, struct ggml_tensor * tensor, const float * imatrix, enum ggml_type qtype, uint32_t n_threads) { + // much of this is form copied from llama.cpp + int chunk_size_multiplier = 1; + if (qtype == GGML_TYPE_Q4_0_4_4 || qtype == GGML_TYPE_Q4_0_4_8 || qtype == GGML_TYPE_Q4_0_8_8) { + if ((qtype == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) qtype = GGML_TYPE_Q4_0; + else if (tensor->ne[1] % 4 != 0) qtype = GGML_TYPE_Q4_0; + if (qtype == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; + else if (qtype == GGML_TYPE_Q4_0_4_4 || qtype == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; + } + size_t out_size = 0; + const int32_t d3_step = tensor->ne[0] * tensor->ne[1]; + const int32_t n_per_row = tensor->ne[0]; + const int32_t nrows = tensor->ne[1]; + static const int32_t min_chunk_size = 32 * 512; + const int32_t chunk_size = (n_per_row >= min_chunk_size ? 
n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * chunk_size_multiplier; + uint32_t thread_count = std::max(1, std::min((int)n_threads, (int)(d3_step + chunk_size - 1) / chunk_size)); + std::mutex mutex; + + for (int32_t d3_index = 0; d3_index < tensor->ne[2]; d3_index++) { + const float * f32_data_d3 = ((float *) tensor->data) + d3_index * d3_step; + void * new_data_d3 = (char *)new_data + ggml_row_size(qtype, tensor->ne[0]) * d3_index * nrows; + const float * imatrix_03 = imatrix ? imatrix + d3_index * tensor->ne[0] : nullptr; + if (thread_count <= 1) { + // not threaded + out_size += ggml_quantize_chunk(qtype, f32_data_d3, new_data_d3, 0, nrows, n_per_row, imatrix); + } else { + std::vector threads; + int64_t counter = 0; + size_t new_size = 0; + bool valid = true; + for (uint32_t t = 0; t < thread_count; t++) { + auto func = [&mutex, &counter, &new_size, &valid, qtype, f32_data_d3, new_data_d3, chunk_size, nrows, n_per_row, imatrix]() { + const int64_t nrows_per_chunk = chunk_size / n_per_row; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + int64_t first_row = counter; + counter += nrows_per_chunk; + if (first_row >= nrows) { + if (local_size > 0) { + new_size += local_size; + } + break; + } + lock.unlock(); + const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); + size_t this_size = ggml_quantize_chunk(qtype, f32_data_d3, new_data_d3, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data; I am not sure how this would occur, but there is always the safe fallback on doing this single threaded. 
+ const size_t row_size = ggml_row_size(qtype, n_per_row); + void * this_data = (char *) new_data_d3 + first_row * row_size; + if (!ggml_validate_row_data(qtype, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } + } + }; + threads.push_back(std::thread(func)); + } + for (auto & t : threads) t.join(); + + if (!valid) { + TTS_ABORT("Validation of quantized data failed. Please try again and/or switch to single thread quantization.\n"); + } + out_size += new_size; + } + } + return out_size; +} + +void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} + +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; +} + +void quantize_gguf(str ifile, str ofile, const quantization_params & params) { + ggml_context * weight_ctx = NULL; + struct gguf_init_params gguf_params = { + /*.no_alloc =*/ false, + /*.ctx =*/ &weight_ctx, + }; + gguf_context * meta_ctx = gguf_init_from_file(ifile, gguf_params); + str arch = "parler-tts"; // only parler-tts gguf files should lack an explicit architecture. + + if (int arch_key = gguf_find_key(meta_ctx, "general.architecture"); arch_key != -1) { + arch = gguf_get_val_str(meta_ctx, arch_key); + } + const tts_arch arch_type{parse_arch_type(ifile, arch)}; + + switch (params.quantize_type) { + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + break; + default: + fprintf(stdout, "Warning, %s is untested for quantization type '%d'. 
Use at your own risk.\n", arch, params.quantize_type); + } + + gguf_context_ptr ctx_out { gguf_init_empty() }; + + // copy the KV pairs from the input file + gguf_set_kv(ctx_out.get(), meta_ctx); + gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); + gguf_set_val_u32(ctx_out.get(), "general.quantization_type", params.quantize_type); + for (ggml_tensor * tensor = ggml_get_first_tensor(weight_ctx); tensor; tensor = ggml_get_next_tensor(weight_ctx, tensor)) { + if (*ggml_get_name(tensor)) { + gguf_add_tensor(ctx_out.get(), tensor); + } + } + + std::vector> work; + + std::ofstream fout; + auto close_ofstream = [&]() { + // Write metadata and close file handler + if (fout.is_open()) { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_out.get())); + gguf_get_meta_data(ctx_out.get(), data.data()); + fout.write((const char *) data.data(), data.size()); + fout.close(); + } + }; + auto new_ofstream = [&]() { + std::string fname = ofile; + fout = std::ofstream(fname, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + const size_t meta_size = gguf_get_meta_size(ctx_out.get()); + // placeholder for the meta data + ::zeros(fout, meta_size); + }; + new_ofstream(); + for (ggml_tensor * cur = ggml_get_first_tensor(weight_ctx); cur; cur = ggml_get_next_tensor(weight_ctx, cur)) { + const size_t align = GGUF_DEFAULT_ALIGNMENT; + ggml_type new_type; + void * new_data; + size_t new_size; + str name = ggml_get_name(cur); + + if (!*name) { + continue; + } + + if (is_quantizable(arch_type, name, params)) { + if ((cur->type) != GGML_TYPE_F32) { + TTS_ABORT("ERROR: All quantized tensors must be transformed from 32bit floats. 
Tensor, '%s', has improper type, '%d'\n", cur->name, cur->type); + } + new_type = params.quantize_type; + if ((new_type >= GGML_TYPE_IQ2_XXS && new_type <= GGML_TYPE_IQ4_XS)) { + TTS_ABORT("ERROR: Quantization type '%d' requires an importance matrix.\n", new_type); + } + const int64_t nelement_size = ggml_nelements(cur) * 4; + if (work.size() < (size_t)nelement_size) { + work.resize(nelement_size); // upper bound on size + } + new_data = work.data(); + new_size = quantize_tensor(new_data, cur, nullptr, new_type, params.n_threads); + } else if ((params.convert_non_quantizable_to_f16 && kokoro_is_f16_compatible(name)) || (params.convert_dac_to_f16 && has_prefix(name, "audio_encoder") && !has_suffix(name, "alpha"))) { + if ((cur->type) != GGML_TYPE_F32) { + TTS_ABORT("ERROR: All converted tensors must be transformed from 32bit floats. Tensor, '%s', has improper type, '%d'\n", cur->name, cur->type); + } + new_type = GGML_TYPE_F16; + const int64_t nelement_size = ggml_nelements(cur) * 4; + if (work.size() < (size_t)nelement_size) { + work.resize(nelement_size); // upper bound on size + } + new_data = work.data(); + new_size = quantize_tensor(new_data, cur, nullptr, new_type, params.n_threads); + } else { + new_type = cur->type; + new_data = cur->data; + new_size = ggml_nbytes(cur); + } + + gguf_set_tensor_type(ctx_out.get(), name, new_type); + gguf_set_tensor_data(ctx_out.get(), name, new_data, new_size); + fprintf(stdout, "At tensor: '%s' with new size: %zu bytes\n", name, new_size); + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } + close_ofstream(); +} diff --git a/examples/quantize/quantize_impl.h b/examples/quantize/quantize_impl.h new file mode 100644 index 0000000..67b5ae8 --- /dev/null +++ b/examples/quantize/quantize_impl.h @@ -0,0 +1,16 @@ +#pragma once + +#include "ggml.h" +#include "common.h" + +struct quantization_params { + uint32_t n_threads; + ggml_type 
quantize_type; + bool quantize_output_heads; + bool quantize_text_embeddings; + bool quantize_cross_attn_kv; + bool convert_dac_to_f16; + bool convert_non_quantizable_to_f16; +}; + +void quantize_gguf(str ifile, str ofile, const quantization_params & params); diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index e9c60f0..704851a 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -34,9 +34,7 @@ endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) -target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) -target_link_libraries(${TARGET} PRIVATE ggml tts) - +target_link_libraries(${TARGET} PRIVATE examples_common) if (LLAMA_SERVER_SSL) find_package(OpenSSL REQUIRED) @@ -47,5 +45,3 @@ endif() if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() - -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/server/README.md b/examples/server/README.md index 6f30f15..d0abd9f 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -6,45 +6,44 @@ This script runs a simple restful HTTP server which supports an OpenAI like `/v1 In order to get a detailed breakdown of the functionality currently available you can call the tts-server with the `--help` parameter. This will return a breakdown of all parameters: -```bash -./build/bin/tts-server --help - ---temperature (-t): - (OPTIONAL) The temperature to use when generating outputs. Defaults to 1.0. ---repetition-penalty (-r): - The by channel repetition penalty to be applied the sampled output of the model. defaults to 1.0. ---top-p (tp): - (OPTIONAL) the default sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0. ---topk (-tk): - (OPTIONAL) when set to an integer value greater than 0 generation uses nucleus sampling over topk nucleaus size. Defaults to 50. 
+```console +$ ./tts-server --help +--default-model (-dm): + (OPTIONAL) The default model to use when multiple models (a directory with multiple GGUF files) are provided. This can be set by giving the path to the model (./models/Kokoro_no_espeak.gguf), the filename (Kokoro_no_espeak.gguf), or the model ID itself (Kokoro_no_espeak). +--espeak-voice-id (-eid): + (OPTIONAL) The eSpeak voice id to use for phonemization. This should only be specified when the correct eSpeak voice cannot be inferred from the Kokoro voice. See MultiLanguage Configuration in the README for more info. +--host (-h): + (OPTIONAL) The hostname of the server. Defaults to 127.0.0.1. +--max-tokens (-mt): + (OPTIONAL) The max audio tokens or token batches to generate where each represents approximates 11 ms of audio. Only applied to Dia generation. If set to zero as is its default then the default max generation size. Warning values under 15 are not supported. +--model-path (-mp): + (REQUIRED) The local path of the gguf model(s) to load. +--n-http-threads (-ht): + (OPTIONAL) The number of http threads to use. Defaults to hardware concurrency minus 1. +--n-parallelism (-np): + (OPTIONAL) The number of parallel models to run asynchronously. Defaults to 1. --n-threads (-nt): - The number of cpu threads to run generation with. Defaults to hardware concurrency. + (OPTIONAL) The number of CPU threads to run calculations with. Defaults to known hardware concurrency. If hardware concurrency cannot be determined then it defaults to 1. +--no-cross-attn (-ca): + (OPTIONAL) Whether to not include cross attention. --port (-p): (OPTIONAL) The port to use. Defaults to 8080. ---n-http-threads (-ht): - (OPTIONAL) The number of http threads to use. Defaults to hardware concurrency minus 1. +--repetition-penalty (-r): + (OPTIONAL) The per-channel repetition penalty to be applied the sampled output of the model. +--temperature (-t): + (OPTIONAL) The temperature to use when generating outputs. 
+--text-encoder-path (-tep): + (OPTIONAL) The local path of the text encoder gguf model for conditional generation. --timeout (-t): (OPTIONAL) The server side timeout on http calls in seconds. Defaults to 300 seconds. ---n-parallelism (-np): - (OPTIONAL) the number of parallel models to run asynchronously. Deafults to 1. +--top-p (-mt): + (OPTIONAL) The sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0. +--topk (-tk): + (OPTIONAL) When set to an integer value greater than 0 generation uses nucleus sampling over topk nucleus size. Defaults to 50. --use-metal (-m): - (OPTIONAL) Whether to use metal acceleration ---no-cross-attn (-ca): - (OPTIONAL) Whether to not include cross attention ---model-path (-mp): - (REQUIRED) The local path of the gguf model file for Parler TTS mini or large v1, Dia, or Kokoro. ---text-encoder-path (-tep): - (OPTIONAL) The local path of the text encoder gguf model for conditional generaiton. ---ssl-file-cert (-sfc): - (OPTIONAL) The local path to the PEM encoded ssl cert. ---ssl-file-key (-sfk): - (OPTIONAL) The local path to the PEM encoded ssl private key. ---host (-h): - (OPTIONAL) the hostname of the server. Defaults to '127.0.0.1'. + (OPTIONAL) Whether to use metal acceleration. --voice (-v): - (OPTIONAL) the default voice to use when generating audio. Only used with applicable models. ---espeak-voice-id (-eid): - (OPTIONAL) The espeak voice id to use for phonemization. This should only be specified when the correct espeak voice cannot be inferred from the kokoro voice (see #MultiLanguage Configuration in the cli README for more info). + (OPTIONAL) The voice to use to generate the audio. This is only used for models with voice packs. ``` Important configuration here includes `--n-parallelism` which describes how may models for asynchronous processing and `--model-path` which describes from where to load the model locally. 
diff --git a/examples/server/public/index.html b/examples/server/public/index.html index ffaa29c..d5bc84f 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -567,14 +567,9 @@

TTS.cpp Server API

if (!response.ok) { let errorMessage = `API error: ${response.status} ${response.statusText}`; try { - const errorData = await response.json(); - if (errorData?.error?.message) { - errorMessage = `API error: ${errorData.error.message}`; - } else if (typeof errorData === 'string') { - errorMessage = `API error: ${errorData}`; - } + const errorData = await response.text(); + errorMessage = `API error: ${errorData}`; } catch (jsonError) { - console.error('Failed to parse error response:', jsonError); } throw new Error(errorMessage); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 92aae79..7e7407d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2,7 +2,6 @@ #include "ggml.h" #include "util.h" #include -#include #include #define JSON_ASSERT GGML_ASSERT #include "json.hpp" @@ -11,452 +10,274 @@ #define MIMETYPE_AIFF "audio/aiff" #define MIMETYPE_JSON "application/json; charset=utf-8" #define MIMETYPE_HTML "text/html; charset=utf-8" +#define MIMETYPE_TXT "text/plain" #include #include -#include -#include #include -#include #include -#include +#include #include #include #include -#include #include #include "tts.h" #include "audio_file.h" -#include "args.h" -#include "common.h" +#include "args_common.h" #include "tts_server_threading_osx.h" #include "index.html.hpp" -enum server_state { - LOADING, // Server is starting up / model loading - READY, // Server is ready -}; - -// These are form copied from llama.cpp which copied them from openAI chat: -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -// In testing, openAI TTS endpoints make use of the same behavior. 
-enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, // not currently supported as auth keys are not built in yet - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, // not currently supported as auth keys are not built in yet - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; - -enum task_type { - TTS, - CONDITIONAL_PROMPT, -}; - +namespace { using json = nlohmann::ordered_json; -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { - fprintf(stderr, "Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); - return default_value; - } - } else { - return default_value; - } +void res_ok_json_str(httplib::Response & res, str output) { + res.set_content(output, MIMETYPE_JSON); + res.status = 200; } -bool write_audio_data(float * data, size_t length, std::vector & output, AudioFileFormat format = AudioFileFormat::Wave, float sample_rate = 44100.f, float frequency = 440.f, int channels = 1) { - AudioFile file; - file.setBitDepth(16); - file.setSampleRate(sample_rate); - file.setNumChannels(channels); - int samples = (int) (length / channels); - file.setNumSamplesPerChannel(samples); - for (int channel = 0; channel < channels; channel++) { - for (int i = 0; i < samples; i++) { - file.samples[channel][i] = data[i]; - } - } - return file.writeData(output, format); +string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); } -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - if (req.path == "/v1/health") { - return; - } - - fprintf(stdout, "request: %s %s %s %d\n", req.method.c_str(), 
req.path.c_str(), req.remote_addr.c_str(), res.status); +void res_ok_audio(httplib::Response & res, const vector & audio, str mime_type) { + res.set_content(reinterpret_cast(audio.data()), audio.size(), mime_type); + res.status = 200; } -struct simple_text_prompt_task { - simple_text_prompt_task(task_type task, std::string prompt): task(task), prompt(prompt) { - id = rand(); - time = std::chrono::steady_clock::now(); - } - - task_type task; - int id; - std::string prompt; - generation_configuration * gen_config; - void * response; - size_t length; - bool success = false; - std::string message; - std::chrono::time_point time; - float sample_rate = 44100.0f; - std::string model; - - bool timed_out(int t) { - auto now = std::chrono::steady_clock::now(); - std::chrono::duration> duration = now - time; - return (int) duration.count() > t; - } -}; - -struct simple_task_queue { - std::mutex rw_mutex; - std::condition_variable condition; - std::deque queue; - bool running = true; - - struct simple_text_prompt_task * get_next() { - struct simple_text_prompt_task * resp; - std::unique_lock lock(rw_mutex); - condition.wait(lock, [&]{ - return !queue.empty() || !running; - }); - if (!running) { - return nullptr; - } - resp = queue.front(); - queue.pop_front(); - lock.unlock(); - return resp; - } +void res_error(httplib::Response & res, str err) { + res.set_content(err, MIMETYPE_TXT); + res.status = 500; +} - void terminate() { - std::lock_guard lock(rw_mutex); - running = false; - condition.notify_all(); +class simple_task_queue; + +class simple_text_prompt_task { + mutex condition_mutex{}; + condition_variable condition{}; + friend simple_task_queue; +public: + str prompt{""}; + str conditional_prompt{""}; + str model{""}; + AudioFileFormat format{}; + generation_configuration gen_config{}; + atomic> time{}; + + vector response; + bool success{}; + atomic locked_by_worker{}; + + bool timed_out(int cleanup_timeout) const { + const auto start{time.load(memory_order_relaxed)}; + 
return chrono::duration_cast(chrono::steady_clock::now() - start).count() > cleanup_timeout; } - void push(struct simple_text_prompt_task * task) { - std::lock_guard lock(rw_mutex); - queue.push_back(task); + void respond() { + lock_guard lock{condition_mutex}; + locked_by_worker.store(false); condition.notify_one(); } }; -struct simple_response_map { - std::mutex rw_mutex; - std::condition_variable updated; - int cleanup_timeout = 300; - std::atomic running = true; - std::thread * cleanup_thread; - - std::map completed; - - void cleanup_routine() { - std::unique_lock lock(rw_mutex); - while(true) { - updated.wait(lock, [&]{ - return completed.size() > 100 || !running; - }); - if (!running) { - return; - } - auto now = std::chrono::steady_clock::now(); - std::vector deletable; - for (auto const& [key, task] : completed) { - if (task->timed_out(cleanup_timeout)) { - deletable.push_back(key); - } +struct worker; + +class simple_task_queue { + mutex rw_mutex{}; + condition_variable condition{}; + deque> queue{}; +public: + vector> workers{}; + atomic running{true}; + atomic startup_fence{}; + int cleanup_timeout{300}; + str text_encoder_path{""}; + + shared_ptr get_next() { + unique_lock lock(rw_mutex); + condition.wait(lock, [&] { + return !queue.empty() || !running.load(); + }); + if (!running.load()) { + return {}; + } + do { + shared_ptr result = queue.front().lock(); + queue.pop_front(); + if (!result) { + continue; } - for (auto const id : deletable) { - completed.erase(id); + if (result->timed_out(cleanup_timeout)) { + result->respond(); + continue; } - } - } - - void terminate() { - std::lock_guard lock(rw_mutex); - running = false; - updated.notify_all(); - } - - void push(struct simple_text_prompt_task * task) { - std::unique_lock lock(rw_mutex); - completed[task->id] = task; - lock.unlock(); - updated.notify_all(); + return result; + } while (!queue.empty()); + return {}; } - struct simple_text_prompt_task * get(int id) { - std::unique_lock lock(rw_mutex); 
- struct simple_text_prompt_task * resp = nullptr; - try { - return completed.at(id); - } catch (const std::out_of_range& e) { - updated.wait(lock, [&]{ - return completed.find(id) != completed.end() || !running; - }); - if (!running) { - return nullptr; + void terminate(); + + void request(shared_ptr & task) { + unique_lock lock{task->condition_mutex}; + task->time.store(chrono::steady_clock::now(), memory_order_relaxed); + { + unique_lock lock2{rw_mutex}; + task->response.clear(); + task->success = false; + } + task->locked_by_worker.store(true, memory_order_relaxed); + { + lock_guard lock2(rw_mutex); + queue.emplace_back(task); + condition.notify_one(); + } + do { + if (condition.wait_for(lock, chrono::seconds(1), [&] { + return !task->locked_by_worker.load() || !running.load(); + })) { + return; } - return completed.at(id); - } + } while (!task->timed_out(cleanup_timeout)); } }; -void init_response_map(simple_response_map * rmap) { - rmap->cleanup_routine(); -} - struct worker { - worker(struct simple_task_queue * task_queue, struct simple_response_map * response_map, std::string text_encoder_path = "", int task_timeout = 300): task_queue(task_queue), response_map(response_map), text_encoder_path(text_encoder_path), task_timeout(task_timeout) {}; - ~worker() { - for (auto &[_, runner]: runners) { - delete runner; - } + worker(simple_task_queue & q, const arg_list & args, const unordered_map & model_map) + : q{q}, args{args}, model_map{model_map} { } - struct simple_task_queue * task_queue; - struct simple_response_map * response_map; + reference_wrapper q; + reference_wrapper args; + reference_wrapper> model_map; - std::unordered_map runners; - std::string text_encoder_path; - std::atomic running = true; - tts_server_threading::native_thread * thread = nullptr; + unordered_map> runners{}; + tts_server_threading::native_thread worker_thread{}; - int task_timeout; + void loop() { + const arg_list & args_ = args.get(); + const int n_threads{args_["n-threads"]}; + 
const generation_configuration startup_config{parse_generation_config(args_)}; + const bool cpu_only{!args_["use-metal"]}; - void terminate() { - running = false; - } + for (const auto & [id, path]: model_map.get()) { + runners[id].reset(runner_from_file(path.c_str(), n_threads, startup_config, cpu_only)); + } + q.get().startup_fence.fetch_sub(1, memory_order_acq_rel); - void loop() { - while (running) { - struct simple_text_prompt_task * task = task_queue->get_next(); - if (task) { - process_task(task); + while (q.get().running.load()) { + if (shared_ptr const task{q.get().get_next()}) { + process_task(*task); + task->respond(); } } } - void process_task(struct simple_text_prompt_task * task) { - if (task->timed_out(task_timeout)) { - return; + void process_task(simple_text_prompt_task & task) { + tts_runner * runner = &*runners[task.model]; + if (*task.conditional_prompt) { + TTS_ASSERT(*q.get().text_encoder_path); + update_conditional_prompt(runner, q.get().text_encoder_path, task.conditional_prompt); } - int outcome; - tts_response * data = nullptr; - tts_runner* runner = runners[task->model]; - switch(task->task) { - case TTS: - data = new tts_response; - outcome = generate(runner, task->prompt, data, task->gen_config); - task->response = (void*) data->data; - task->length = data->n_outputs; - task->sample_rate = runner->sampling_rate; - task->success = outcome == 0; - response_map->push(task); - break; - case CONDITIONAL_PROMPT: - if (text_encoder_path.size() == 0) { - task->message = "A text encoder path must be specified on server initialization in order to support conditional prompting."; - response_map->push(task); - break; - } - update_conditional_prompt(runner, text_encoder_path, task->prompt); - task->success = true; - response_map->push(task); - break; + tts_response data; + task.success = !generate(runner, task.prompt, data, task.gen_config); + if (!task.success) { + return; } - } -}; -void init_worker(std::unordered_map* model_path, int n_threads, 
bool cpu_only, generation_configuration * config, worker * w) { - for (const auto &[id, path] : *model_path) { - w->runners[id] = runner_from_file(path, n_threads, config, cpu_only); + AudioFile file{}; + file.setSampleRate(runner->sampling_rate); + file.samples[0] = vector(data.data, data.data + data.n_outputs); + const bool write_audio_data_result{file.writeData(task.response, task.format)}; + TTS_ASSERT(write_audio_data_result); } - w->loop(); -} - -typedef std::vector worker_pool; +}; -void terminate(worker_pool * pool) { - for (auto w : *pool) { - w->terminate(); - } - if (pool->size() > 0) { - (*pool)[0]->task_queue->terminate(); - (*pool)[0]->response_map->terminate(); +void simple_task_queue::terminate() { + if (workers.empty()) { + return; } -} - -void complete(worker_pool * pool) { - for (auto w : *pool) { - if (w->thread) { - w->thread->join(); - } - delete w; + { + lock_guard lock{rw_mutex}; + running.store(false); + condition.notify_all(); } -} - -static std::string safe_json_to_str(json data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - -// this function maybe used outside of server_task_result_error -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; + for (const auto & w : workers) { + w->worker_thread.join(); } - 
return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; + workers.clear(); } -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; +std::function shutdown_handler; -inline void signal_handler(int signal) { +void signal_handler(int /*signal*/) { + static atomic_flag is_terminating{}; if (is_terminating.test_and_set()) { // in case it hangs, we can force terminate the server by hitting Ctrl+C twice // this is for better developer experience, we can remove when the server is stable enough fprintf(stderr, "Received second interrupt, terminating immediately.\n"); exit(1); } - - shutdown_handler(signal); + shutdown_handler(); +} } int main(int argc, const char ** argv) { - int default_n_threads = std::max((int)std::thread::hardware_concurrency(), 1); - int default_http_threads = std::max((int)std::thread::hardware_concurrency() - 1, 3); - int default_n_parallel = 1; - int default_port = 8080; - int default_timeout = 300; - std::string default_host = "127.0.0.1"; - float default_temperature = 1.0f; - int default_top_k = 50; - float default_repetition_penalty = 1.0f; - float default_top_p = 1.0f; - - arg_list args; - args.add_argument(float_arg("--temperature", "(OPTIONAL) The temperature to use when generating outputs. Defaults to 1.0.", "-t", false, &default_temperature)); - args.add_argument(int_arg("--topk", "(OPTIONAL) when set to an integer value greater than 0 generation uses nucleus sampling over topk nucleaus size. Defaults to 50.", "-tk", false, &default_top_k)); - args.add_argument(float_arg("--repetition-penalty", "The by channel repetition penalty to be applied the sampled output of the model. 
defaults to 1.0.", "-r", false, &default_repetition_penalty)); - args.add_argument(string_arg("--model-path", "(REQUIRED) The local path of the gguf model file or a directory containing only gguf model files for Parler TTS mini or large v1, Dia, or Kokoro.", "-mp", true)); - args.add_argument(string_arg("--default-model", "(OPTIONAL) The default model to use when multiple models (a directory with multiple GGUF files) are provided. This can be set by giving the path to the model (./models/Kokoro_no_espeak.gguf), the filename (Kokoro_no_espeak.gguf), or the model ID itself (Kokoro_no_espeak).", "-dm", false)); - args.add_argument(int_arg("--n-threads", "The number of cpu threads to run generation with. Defaults to hardware concurrency.", "-nt", false, &default_n_threads)); - args.add_argument(bool_arg("--use-metal", "(OPTIONAL) Whether to use metal acceleration", "-m")); - args.add_argument(bool_arg("--no-cross-attn", "(OPTIONAL) Whether to not include cross attention", "-ca")); - args.add_argument(string_arg("--text-encoder-path", "(OPTIONAL) The local path of the text encoder gguf model for conditional generaiton.", "-tep", false)); - args.add_argument(string_arg("--ssl-file-cert", "(OPTIONAL) The local path to the PEM encoded ssl cert.", "-sfc", false)); - args.add_argument(string_arg("--ssl-file-key", "(OPTIONAL) The local path to the PEM encoded ssl private key.", "-sfk", false)); - args.add_argument(int_arg("--port", "(OPTIONAL) The port to use. Defaults to 8080.", "-p", false, &default_port)); - args.add_argument(string_arg("--host", "(OPTIONAL) the hostname of the server. Defaults to '127.0.0.1'.", "-h", false, default_host)); - args.add_argument(int_arg("--n-http-threads", "(OPTIONAL) The number of http threads to use. Defaults to hardware concurrency minus 1.", "-ht", false, &default_http_threads)); - args.add_argument(int_arg("--timeout", "(OPTIONAL) The server side timeout on http calls in seconds. 
Defaults to 300 seconds.", "-t", false, &default_timeout)); - args.add_argument(int_arg("--n-parallelism", "(OPTIONAL) the number of parallel models to run asynchronously. Deafults to 1.", "-np", false, &default_n_parallel)); - args.add_argument(string_arg("--voice", "(OPTIONAL) the default voice to use when generating audio. Only used with applicable models.", "-v", false, "af_alloy")); - args.add_argument(string_arg("--espeak-voice-id", "(OPTIONAL) The espeak voice id to use for phonemization. This should only be specified when the correct espeak voice cannot be inferred from the kokoro voice (see #MultiLanguage Configuration in the cli README for more info).", "-eid", false)); - args.add_argument(float_arg("--top-p", "(OPTIONAL) the default sum of probabilities to sample over. Must be a value between 0.0 and 1.0. Defaults to 1.0.", "-tp", false, &default_top_p)); - + simple_task_queue q{}; + arg_list args{}; + add_common_args(args); + args.add({ + "", "default-model", "dm", + "The default model to use when multiple models (a directory with multiple GGUF files) are provided. " + "This can be set by giving the path to the model (./models/Kokoro_no_espeak.gguf), " + "the filename (Kokoro_no_espeak.gguf), or the model ID itself (Kokoro_no_espeak)" + }); + add_text_encoder_arg(args); + args.add({8080, "port", "p", "The port to use. Defaults to 8080"}); + args.add({"127.0.0.1", "host", "h", "The hostname of the server. Defaults to 127.0.0.1"}); + args.add({ + max(static_cast(thread::hardware_concurrency()) - 1, 3), "n-http-threads", "ht", + "The number of http threads to use. Defaults to hardware concurrency minus 1" + }); + args.add({300, "timeout", "t", "The server side timeout on http calls in seconds. Defaults to 300 seconds"}); + args.add({1, "n-parallelism", "np", "The number of parallel models to run asynchronously. 
Defaults to 1"}); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + args.add({"", "ssl-file-cert", "sfc", "The local path to the PEM encoded SSL certificate"}); + args.add({"", "ssl-file-key", "sfk", "The local path to the PEM encoded SSL private key"}); +#endif args.parse(argc, argv); - if (args.for_help) { - args.help(); - return 0; - } - args.validate(); - - if (*args.get_float_param("--top-p") > 1.0f || *args.get_float_param("--top-p") <= 0.0f) { - fprintf(stderr, "The '--top-p' value must be between 0.0 and 1.0. It was set to '%.6f'.\n", *args.get_float_param("--top-p")); - exit(1); - } + q.startup_fence.store(args["n-parallelism"], memory_order_relaxed); + q.cleanup_timeout = args["timeout"]; + q.text_encoder_path = args["text-encoder-path"]; - generation_configuration * default_generation_config = new generation_configuration( - args.get_string_param("--voice"), - *args.get_int_param("--topk"), - *args.get_float_param("--temperature"), - *args.get_float_param("--repetition-penalty"), - !args.get_bool_param("--no-cross-attn"), - args.get_string_param("--espeak-voice-id"), - 0, - *args.get_float_param("--top-p")); - - worker_pool * pool = nullptr; - struct simple_task_queue * tqueue = new simple_task_queue; - struct simple_response_map * rmap = new simple_response_map; - - bool conditional_prompt_viable = args.get_string_param("--text-encoder-path").size() > 0 && *args.get_int_param("--n-parallelism") <= 1; - - std::unique_ptr svr; + unique_ptr svr; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (args.get_string_param("--ssl-file-cert") != "" && args.get_string_param("--ssl-file-key") != "") { - fprintf(stdout, "Running with SSL: key = %s, cert = %s\n", args.get_string_param("--ssl-file-key").c_str(), args.get_string_param("--ssl-file-cert").c_str()); - svr.reset(new httplib::SSLServer(args.get_string_param("--ssl-file-key").c_str(), args.get_string_param("--ssl-file-cert").c_str())); - } else { - fprintf(stdout, "Running without SSL\n"); - svr.reset(new httplib::Server()); + { + 
const str cert{args["ssl-file-cert"]}; + const str key{args["ssl-file-key"]}; + if (*cert) { + TTS_ASSERT(*key); + fprintf(stdout, "Running with SSL: key = %s, cert = %s\n", key, cert); + svr = make_unique(key, cert); + } else { + TTS_ASSERT(!*key); + fprintf(stdout, "Running without SSL\n"); + svr = make_unique(); + } } #else - if (args.get_string_param("--ssl-file-cert") != "" && args.get_string_param("--ssl-file-key") != "") { - fprintf(stderr, "Server is built without SSL support\n"); - return 1; - } - svr.reset(new httplib::Server()); + svr = make_unique(); #endif std::unordered_map model_map = {}; - const std::string model_path = args.get_string_param("--model-path"); - if (std::filesystem::is_directory(model_path)) { + if (const str model_path{args["model-path"]}; filesystem::is_directory(model_path)) { for (auto const &entry : std::filesystem::directory_iterator(model_path)) { if (!entry.is_directory() && entry.path().extension() == ".gguf") { const std::string id = entry.path().stem(); model_map[id] = entry.path().string(); } } - if (model_map.size() == 0) { - fprintf(stderr, "No model found in directory %s", model_path.c_str()); + if (model_map.empty()) { + fprintf(stderr, "No model found in directory %s", model_path); return 1; } } else { @@ -464,96 +285,49 @@ int main(int argc, const char ** argv) { model_map[path.stem()] = path; } - auto model_creation = std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(); - - std::string default_model = ""; - if (args.get_string_param("--default-model") != "") { - const std::string model = std::filesystem::path { args.get_string_param("--default-model") }.stem(); - if (model_map.contains(model)) { - default_model = model; + str default_model{args["default-model"]}; + if (*default_model) { + const string model{filesystem::path{default_model}.stem()}; + if (auto found = model_map.find(model); found != model_map.end()) { + default_model = found->first.c_str(); } else { 
fprintf(stderr, "Invalid Default Model Provided: %s", model.c_str()); return 1; } } else { - default_model = model_map.begin()->first; - } - - std::vector models = {}; - for (const auto &[id, _] : model_map) { - json model = {{"id", ""}, - {"object", "model"}, - {"created", 0}, - {"owned_by", "tts.cpp"}}; - model["id"] = id; - model["created"] = model_creation; - models.push_back(model); + default_model = model_map.begin()->first.c_str(); } - const json models_json = {{"object", "list"}, {"data", models}}; - - std::atomic state{LOADING}; - - svr->set_logger(log_server_request); - - auto res_error = [](httplib::Response & res, const json & error_data) { - json final_response {{"error", error_data}}; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); - }; - auto res_ok_html = [](httplib::Response & res, const char * const & data) { - res.set_content(data, MIMETYPE_HTML); - res.status = 200; - }; - - auto res_ok_json = [](httplib::Response & res, const json & data) { - res.set_content(safe_json_to_str(data), MIMETYPE_JSON); - res.status = 200; - }; - - auto res_ok_audio = [](httplib::Response & res, const std::vector & audio, std::string mime_type) { - res.set_content((char*)audio.data(), audio.size(), mime_type); - res.status = 200; - }; - - svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { - std::string message; - try { - std::rethrow_exception(ep); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) 
{ - message = "Unknown Exception"; + const string models_json_output{[&model_map] { + vector models = {}; + const auto model_creation{chrono::system_clock::now().time_since_epoch().count()}; + for (const auto & id: model_map | views::keys) { + json model{ + {"id", id}, + {"object", "model"}, + {"created", model_creation}, + {"owned_by", "tts.cpp"} + }; + models.push_back(model); + } + return safe_json_to_str({{"object", "list"}, {"data", models}}); + }()}; + + svr->set_logger([](const httplib::Request & req, const httplib::Response & res) { + if (req.path == "/v1/health") { + return; } - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - fprintf(stderr, "got exception: %s\n", formatted_error.dump().c_str()); - res_error(res, formatted_error); - }); - - svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } + fprintf(stdout, "request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); }); // set timeouts and change hostname and port - svr->set_read_timeout(*args.get_int_param("--timeout")); - svr->set_write_timeout(*args.get_int_param("--timeout")); - - auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - if (current_state == LOADING) { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - return false; - } - return true; - }; + const int timeout{args["timeout"]}; + svr->set_read_timeout(timeout); + svr->set_write_timeout(timeout); // register server middlewares - svr->set_pre_routing_handler([&middleware_server_state](const httplib::Request & req, httplib::Response & res) { + svr->set_pre_routing_handler([&q](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", 
req.get_header_value("Origin")); // If this is OPTIONS request, skip validation because browsers don't include Authorization header if (req.method == "OPTIONS") { @@ -563,239 +337,152 @@ int main(int argc, const char ** argv) { res.set_content("", "text/html"); // blank response, no data return httplib::Server::HandlerResponse::Handled; // skip further processing } - if (!middleware_server_state(req, res)) { + if (q.startup_fence.load(memory_order_relaxed)) { + res_error(res, "Loading model"); return httplib::Server::HandlerResponse::Handled; } return httplib::Server::HandlerResponse::Unhandled; }); - const auto handle_index = [&](const httplib::Request &, httplib::Response & res) { - res_ok_html(res, reinterpret_cast(index_html)); - }; - - const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { - json health = {{"status", "ok"}}; - res_ok_json(res, health); - }; - + const generation_configuration startup_config{parse_generation_config(args)}; const auto handle_tts = [ - &tqueue, - &rmap, - &res_error, - &res_ok_audio, - &default_generation_config, + &q, &model_map, - &default_model + default_model, + &startup_config ](const httplib::Request &req, httplib::Response & res) { - json data = json::parse(req.body); - if (!data.contains("input") || !data.at("input").is_string()) { - json formatted_error = format_error_response("the 'input' field is required for tts generation and must be passed as a string.", ERROR_TYPE_INVALID_REQUEST); - res_error(res, formatted_error); + thread_local auto task{make_shared()}; + if (task->locked_by_worker.load()) { + res_error(res, "Service unavailable"); return; } - std::string mime_type = MIMETYPE_WAV; - AudioFileFormat audio_type = AudioFileFormat::Wave; - if (data.contains("response_format") && data.at("response_format").is_string()) { - std::string format = data.at("response_format").get(); - if (format != "wav" && format != "wave" && format != "aiff") { - json formatted_error = 
format_error_response("Currently 'wav' and 'aiff' are the only supported formats for the 'response_format' field.", ERROR_TYPE_NOT_SUPPORTED); - res_error(res, formatted_error); - return; - } else if (format == "aiff") { - mime_type = MIMETYPE_AIFF; - audio_type = AudioFileFormat::Aiff; - } - } + const json data(json::parse(req.body)); - std::string prompt = data.at("input").get(); - if (prompt.empty()) { - json formatted_error = format_error_response("the 'input' field must be a non empty string", ERROR_TYPE_INVALID_REQUEST); - res_error(res, formatted_error); + if (!data.contains("input") || !data.at("input").is_string()) { + res_error(res, "the 'input' field is required for tts generation and must be passed as a string"); return; } - struct simple_text_prompt_task * task = new simple_text_prompt_task(TTS, prompt); - int id = task->id; - generation_configuration * conf = new generation_configuration(); - std::memcpy((void*)conf, default_generation_config, sizeof(generation_configuration)); - float temp; - float rep_pen; - float top_p; - int top_k; - if (data.contains("temperature") && data.at("temperature").is_number()) { - temp = data.at("temperature").get(); - conf->temperature = temp; - } - - if (data.contains("top_k") && data.at("top_k").is_number()) { - top_k = data.at("top_k").get(); - conf->top_k = top_k; - } - - if (data.contains("top_p") && data.at("top_p").is_number()) { - top_p = data.at("top_p").get(); - conf->top_p = top_p; - } - - if (data.contains("repetition_penalty") && data.at("repetition_penalty").is_number()) { - rep_pen = data.at("repetition_penalty").get(); - conf->repetition_penalty = rep_pen; + const string & prompt = data.at("input").get(); + if (prompt.empty()) { + res_error(res, "the 'input' field must be a non-empty string"); + return; } + task->prompt = prompt.c_str(); - if (data.contains("voice") && data.at("voice").is_string()) { - conf->voice = data.at("voice").get(); + string conditional_prompt; + if 
(data.contains("conditional_prompt") && data.at("conditional_prompt").is_string()) { + if (!*q.text_encoder_path) { + res_error(res, "A text encoder path must be specified on server initialization " + "in order to support conditional prompting."); + return; + } + conditional_prompt = data.at("conditional_prompt").get(); } + task->conditional_prompt = conditional_prompt.c_str(); + string model; if (data.contains("model") && data.at("model").is_string()) { - const std::string model = data.at("model"); + model = data.at("model").get(); if (!model_map.contains(model)) { - const std::string message = std::format("Invalid Model: {0}", model); - json formatted_error = format_error_response(message, ERROR_TYPE_INVALID_REQUEST); - res_error(res, formatted_error); + res_error(res, "Invalid Model"); return; } - task->model = data.at("model").get(); + task->model = model.c_str(); } else { task->model = default_model; } - task->gen_config = conf; - tqueue->push(task); - struct simple_text_prompt_task * rtask = rmap->get(id); - if (!rtask->success) { - json formatted_error = format_error_response(rtask->message, ERROR_TYPE_SERVER); - res_error(res, formatted_error); - return; + str mime_type = MIMETYPE_WAV; + AudioFileFormat format = AudioFileFormat::Wave; + if (data.contains("response_format") && data.at("response_format").is_string()) { + if (const string & requested = data.at("response_format").get(); requested == "aiff") { + mime_type = MIMETYPE_AIFF; + format = AudioFileFormat::Aiff; + } else if (requested != "wav" && requested != "wave") { + res_error(res, + "Currently 'wav' and 'aiff' are the only supported formats for the 'response_format' field"); + return; + } } + task->format = format; - if (rtask->length == 0) { - json formatted_error = format_error_response("Model returned an empty response.", ERROR_TYPE_SERVER); - res_error(res, formatted_error); - return; + task->gen_config = startup_config; + if (data.contains("temperature") && data.at("temperature").is_number()) 
{ + task->gen_config.temperature = data.at("temperature").get(); } - - std::vector audio; - bool success = write_audio_data((float *)rtask->response, rtask->length, audio, audio_type, rtask->sample_rate); - if (!success) { - json formatted_error = format_error_response("failed to write audio data", ERROR_TYPE_SERVER); - res_error(res, formatted_error); - return; + if (data.contains("top_k") && data.at("top_k").is_number()) { + task->gen_config.top_k = data.at("top_k").get(); + } + if (data.contains("top_p") && data.at("top_p").is_number()) { + task->gen_config.top_p = data.at("top_p").get(); + } + if (data.contains("repetition_penalty") && data.at("repetition_penalty").is_number()) { + task->gen_config.repetition_penalty = data.at("repetition_penalty").get(); + } + string voice; + if (data.contains("voice") && data.at("voice").is_string()) { + voice = data.at("voice").get(); + task->gen_config.voice = voice.c_str(); } - res_ok_audio(res, audio, mime_type); - }; + q.request(task); - const auto handle_conditional = [ - &args, - &tqueue, - &rmap, - &res_error, - &res_ok_json, - &model_map, - &default_model - ](const httplib::Request & req, httplib::Response & res) { - if (args.get_string_param("--text-encoder-path").size() == 0) { - json formatted_error = format_error_response("A '--text-encoder-path' must be specified for conditional generation.", ERROR_TYPE_NOT_SUPPORTED); - res_error(res, formatted_error); + if (task->locked_by_worker.load()) { + res_error(res, "Timed out"); return; } - if (*args.get_int_param("--n-parallelism") > 1) { - json formatted_error = format_error_response("Conditional prompting is not supported for parallelism greater than 1.", ERROR_TYPE_NOT_SUPPORTED); - res_error(res, formatted_error); + if (!task->success) { + res_error(res, "Generation failed"); return; } - json data = json::parse(req.body); - if (!data.contains("input") || !data.at("input").is_string()) { - json formatted_error = format_error_response("the 'input' field is required 
for conditional prompting.", ERROR_TYPE_INVALID_REQUEST); - res_error(res, formatted_error); + if (task->response.empty()) { + res_error(res, "Model returned an empty response"); return; } - std::string prompt = data.at("input").get(); - struct simple_text_prompt_task * task = new simple_text_prompt_task(CONDITIONAL_PROMPT, prompt); - if (data.contains("model") && data.at("model").is_string()) { - const std::string model = data.at("model"); - if (!model_map.contains(model)) { - const std::string message = std::format("Invalid Model: {0}", model); - json formatted_error = format_error_response(message, ERROR_TYPE_INVALID_REQUEST); - res_error(res, formatted_error); - return; - } - task->model = data.at("model").get(); - } else { - task->model = default_model; - } - - int id = task->id; - tqueue->push(task); - struct simple_text_prompt_task * rtask = rmap->get(id); - if (!rtask->success) { - json formatted_error = format_error_response(rtask->message, ERROR_TYPE_SERVER); - res_error(res, formatted_error); - return; - } - json health = {{"status", "ok"}}; - res_ok_json(res, health); - }; - - const auto handle_models = [ - &args, - &res_error, - &res_ok_json, - &models_json - ](const httplib::Request & _, httplib::Response & res) { - res_ok_json(res, models_json); + res_ok_audio(res, task->response, mime_type); }; // register API routes - svr->Get("/", handle_index); - svr->Get("/health", handle_health); + svr->Get("/", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(index_html), MIMETYPE_HTML); + res.status = 200; + }); + svr->Get("/health", [](const httplib::Request &, httplib::Response & res) { + res_ok_json_str(res, R"({"status":"ok")"); + }); svr->Post("/v1/audio/speech", handle_tts); - svr->Post("/v1/audio/conditional-prompt", handle_conditional); - svr->Get("/v1/models", handle_models); + svr->Get("/v1/models", [output = models_json_output.c_str()](const httplib::Request & _, httplib::Response & res) { + 
res_ok_json_str(res, output); + }); // Start the server - svr->new_task_queue = [&args] { - return new httplib::ThreadPool(*args.get_int_param("--n-http-threads")); + const int n_http_threads{args["n-http-threads"]}; + svr->new_task_queue = [n_http_threads] { + return new httplib::ThreadPool(n_http_threads); }; - // clean up function, to be called before exit - auto clean_up = [&svr]() { + shutdown_handler = [&svr] { svr->stop(); }; + const str host{args["host"]}; + const int port{args["port"]}; // bind HTTP listen port - bool bound = svr->bind_to_port(args.get_string_param("--host"), *args.get_int_param("--port")); - - if (!bound) { - fprintf(stderr, "%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port")); - clean_up(); + if (!svr->bind_to_port(host, port)) { + fprintf(stderr, "%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, host, port); + shutdown_handler(); return 1; } - rmap->cleanup_timeout = *args.get_int_param("--timeout"); - rmap->cleanup_thread = new std::thread(init_response_map, rmap); - - // run the HTTP server in a thread - std::thread t([&]() { svr->listen_after_bind(); }); - svr->wait_until_ready(); - fprintf(stdout, "%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port"), *args.get_int_param("--n-http-threads")); - - - pool = new worker_pool; - shutdown_handler = [&](int) { - // this should unblock the primary thread; - terminate(pool); - return; - }; - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; + struct sigaction sigint_action{}; sigint_action.sa_handler = signal_handler; sigemptyset(&sigint_action.sa_mask); sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); + sigaction(SIGINT, &sigint_action, nullptr); + 
sigaction(SIGTERM, &sigint_action, nullptr); #elif defined (_WIN32) auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; @@ -804,25 +491,16 @@ int main(int argc, const char ** argv) { #endif fprintf(stdout, "%s: loading model and initializing main loop\n", __func__); - // It might make sense in the long run to have the primary thread run clean up on the response map and keep the model workers parallel. - for (int i = *args.get_int_param("--n-parallelism"); i > 0; i--) { - if (i == 1) { - fprintf(stdout, "%s: server is listening on http://%s:%d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port")); - worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout")); - state.store(READY); - pool->push_back(w); - init_worker(&model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w); - } else { - worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout")); - w->thread = new tts_server_threading::native_thread(init_worker, &model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w); - pool->push_back(w); - } + for (int i{q.startup_fence.load(memory_order_relaxed)}; i > 0; i--) { + auto & w = q.workers.emplace_back(make_unique(q, args, model_map)); + w->worker_thread = tts_server_threading::native_thread(&worker::loop, w.get()); } - fprintf(stdout, "%s: HTTP server listening on hostname: %s and port: %d, is shutting down.\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port")); - svr->stop(); - t.join(); - complete(pool); - rmap->cleanup_thread->join(); + fprintf(stdout, "%s: HTTP server is listening with %d threads on http://%s:%d/\n", + __func__, n_http_threads, host, port); + svr->listen_after_bind(); + 
fprintf(stdout, "%s: HTTP server listening on hostname: %s and port: %d, is shutting down.\n", + __func__, host, port); + q.terminate(); return 0; } diff --git a/ggml b/ggml index 1e85c87..e486998 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 1e85c87aeaa70548ad52766f1881c2f1257962e2 +Subproject commit e486998a9848fce92858ca54691ac9e6f506e202 diff --git a/include/args.h b/include/args.h deleted file mode 100644 index c89f384..0000000 --- a/include/args.h +++ /dev/null @@ -1,115 +0,0 @@ -#ifndef args_h -#define args_h - -#include -#include -#include - -struct arg { - std::string full_name; - std::string abbreviation = ""; - std::string description = ""; - bool required = false; - bool has_param = false; - - std::string help_text(); -}; - -struct bool_arg : public arg { - bool_arg(std::string fn, std::string desc = "", std::string abbr = "", bool req = false, bool val = false) { - full_name = fn; - description = desc; - abbreviation = abbr; - required = req; - value = val; - }; - - bool value = false; -}; - -struct string_arg : public arg { - string_arg(std::string fn, std::string desc = "", std::string abbr = "", bool req = false, std::string val = "") { - full_name = fn; - description = desc; - abbreviation = abbr; - required = req; - value = val; - }; - bool has_param = true; - std::string value; - - int parse(int argc, const char ** argv); -}; - -struct int_arg : public arg { - int_arg(std::string fn, std::string desc = "", std::string abbr = "", bool req = false, int * val = nullptr) { - full_name = fn; - description = desc; - abbreviation = abbr; - required = req; - value = val; - }; - - int * value; - - int parse(int argc, const char ** argv); - -}; - -struct float_arg : public arg { - float_arg(std::string fn, std::string desc = "", std::string abbr = "", bool req = false, float * val = nullptr) { - full_name = fn; - description = desc; - abbreviation = abbr; - required = req; - value = val; - }; - - bool has_param = true; - float * value; - - 
int parse(int argc, const char ** argv); -}; - -struct arg_list { - std::vector fargs; - std::vector iargs; - std::vector bargs; - std::vector sargs; - bool for_help = false; - - void add_argument(float_arg arg) { - fargs.push_back(arg); - } - - void add_argument(int_arg arg) { - iargs.push_back(arg); - } - - void add_argument(bool_arg arg) { - bargs.push_back(arg); - } - - void add_argument(string_arg arg) { - sargs.push_back(arg); - } - - void help(); - - void validate(); - - void parse(int argc, const char ** argv); - - int find_and_parse(std::string name, int argc, const char ** argv); - - std::string get_string_param(std::string full_name); - - int * get_int_param(std::string full_name); - - float * get_float_param(std::string full_name); - - bool get_bool_param(std::string full_name); -}; - -#endif - diff --git a/include/common.h b/include/common.h index 02de8e1..4596976 100644 --- a/include/common.h +++ b/include/common.h @@ -1,10 +1,7 @@ -#ifndef common_h -#define common_h +#pragma once -#include -#include #include -#include +#include "imports.h" // Using this simple struct as opposed to a common std::vector allows us to return the cpu buffer // pointer directly rather than copying the contents of the buffer to a predefined std::vector. @@ -20,42 +17,41 @@ enum tts_arch { DIA_ARCH = 2, }; -const std::map SUPPORTED_ARCHITECTURES = { - { "parler-tts", PARLER_TTS_ARCH }, - { "kokoro", KOKORO_ARCH }, - { "dia", DIA_ARCH }, +constexpr auto SUPPORTED_ARCHITECTURES{[] { + std::array result{}; + result[PARLER_TTS_ARCH] = "parler-tts"; + result[KOKORO_ARCH] = "kokoro"; + result[DIA_ARCH] = "dia"; + return result; +}()}; + + +constexpr tts_arch parse_arch_type(str fname, str arch) { + const auto result = ranges::find(SUPPORTED_ARCHITECTURES, sv{arch}); + if (result == SUPPORTED_ARCHITECTURES.end()) { + TTS_ABORT("%s failed for file %s. 
The architecture '%s' is not supported.", __func__, fname, arch); + } + return static_cast(distance(SUPPORTED_ARCHITECTURES.cbegin(), result)); }; struct generation_configuration { - generation_configuration( - std::string voice = "", - int top_k = 50, - float temperature = 1.0, - float repetition_penalty = 1.0, - bool use_cross_attn = true, - std::string espeak_voice_id = "", - int max_tokens = 0, - float top_p = 1.0, - bool sample = true): top_k(top_k), temperature(temperature), repetition_penalty(repetition_penalty), use_cross_attn(use_cross_attn), sample(sample), voice(voice), espeak_voice_id(espeak_voice_id), max_tokens(max_tokens), top_p(top_p) {}; - - bool use_cross_attn; - float temperature; - float repetition_penalty; - float top_p; - int top_k; - int max_tokens; - std::string voice = ""; - bool sample = true; - std::string espeak_voice_id = ""; + bool use_cross_attn{true}; // TODO split out this load-time option from the rest of the generate-time configuration + float temperature{1.0f}; + float repetition_penalty{1.0f}; + float top_p{1.0f}; + int top_k{50}; + int max_tokens{0}; + str voice{"af_alloy"}; + static constexpr bool sample{true}; + str espeak_voice_id{"gmw/en-US"}; }; struct tts_runner { tts_arch arch; struct ggml_context * ctx = nullptr; float sampling_rate = 44100.0f; + virtual ~tts_runner() = default; void init_build(std::vector* buf_compute_meta); void free_build(); }; - -#endif diff --git a/include/imports.h b/include/imports.h new file mode 100644 index 0000000..ce47701 --- /dev/null +++ b/include/imports.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace std::string_view_literals; +typedef std::string_view sv; +typedef const char * str; + +#define TTS_ABORT(...) 
tts_abort(__FILE__, __LINE__, __VA_ARGS__) +#define TTS_ASSERT(x) if (!(x)) TTS_ABORT("TTS_ASSERT(%s) failed", #x) +[[noreturn]] void tts_abort(const char * file, int line, const char * fmt, ...); diff --git a/include/phonemizer.h b/include/phonemizer.h index 6167a68..140ff4e 100644 --- a/include/phonemizer.h +++ b/include/phonemizer.h @@ -526,8 +526,8 @@ struct phonemizer { bool handle_unknown(corpus* text); }; -struct phonemizer * phonemizer_from_gguf(gguf_context * meta, const std::string espeak_voice_code = "gmw/en-US"); -struct phonemizer * phonemizer_from_file(const std::string fname, const std::string espeak_voice_code = "gmw/en-US"); -struct phonemizer * espeak_phonemizer(bool use_espeak_phonemes = false, std::string espeak_voice_code = "gmw/en-US"); +phonemizer * phonemizer_from_gguf(gguf_context * meta, str espeak_voice_code = "gmw/en-US"); +phonemizer * phonemizer_from_file(str fname, str espeak_voice_code = "gmw/en-US"); +phonemizer * espeak_phonemizer(bool use_espeak_phonemes = false, str espeak_voice_code = "gmw/en-US"); #endif diff --git a/include/tts.h b/include/tts.h index 23c55d0..dfbddb6 100644 --- a/include/tts.h +++ b/include/tts.h @@ -1,31 +1,7 @@ -#ifndef tts_h -#define tts_h +#pragma once -#include "parler_model.h" -#include "kokoro_model.h" -#include "dia_model.h" -#include -#include -#include +#include "tts_model.h" -struct tts_runner * parler_tts_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only); -struct tts_runner * kokoro_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only); -struct tts_runner * dia_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only); -struct tts_runner * runner_from_file(const std::string & fname, int n_threads, generation_configuration * config, bool 
cpu_only = true); -int generate(tts_runner * runner, std::string sentence, struct tts_response * response, generation_configuration * config); -void update_conditional_prompt(tts_runner * runner, const std::string file_path, const std::string prompt, bool cpu_only = true); - -struct quantization_params { - quantization_params(uint32_t n_threads, enum ggml_type quantize_type): n_threads(n_threads), quantize_type(quantize_type) {}; - uint32_t n_threads; - enum ggml_type quantize_type; // quantization type - bool quantize_output_heads = false; - bool quantize_text_embeddings = false; - bool quantize_cross_attn_kv = false; - bool convert_dac_to_f16 = false; - bool convert_non_quantizable_to_f16 = false; -}; - -void quantize_gguf(const std::string & ifile, const std::string & ofile, struct quantization_params * params); - -#endif +tts_runner * runner_from_file(str fname, int n_threads, const generation_configuration & config, bool cpu_only = true); +int generate(tts_runner * runner, str sentence, tts_response & response, const generation_configuration & config); +void update_conditional_prompt(tts_runner * runner, str file_path, str prompt, bool cpu_only = true); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6244815..af272d9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -9,7 +9,6 @@ endif() add_library(tts ../include/tts.h - ../include/args.h ../include/phonemizer.h tts.cpp tokenizer.cpp @@ -17,7 +16,6 @@ add_library(tts parler_model.cpp dac_model.cpp util.cpp - args.cpp t5_encoder_model.cpp phonemizer.cpp tts_model.cpp @@ -27,8 +25,6 @@ add_library(tts target_include_directories(tts PUBLIC . 
../include ../ggml/src/) -target_compile_features (tts PUBLIC cxx_std_11) # don't bump - if (ESPEAK_INCLUDE_DIRS) set_source_files_properties(phonemizer.cpp PROPERTIES COMPILE_FLAGS "${ESPEAK_CFLAGS_OTHER}") set_source_files_properties(phonemizer.cpp PROPERTIES INCLUDE_DIRECTORIES "${ESPEAK_INCLUDE_DIRS}") diff --git a/src/args.cpp b/src/args.cpp deleted file mode 100644 index 3a42b58..0000000 --- a/src/args.cpp +++ /dev/null @@ -1,164 +0,0 @@ -#include "args.h" - -std::string arg::help_text() { - std::string htxt = full_name; - if (abbreviation != "") { - htxt += " (" + abbreviation + ")"; - } - htxt += ":\n "; - if (description != "") { - htxt += description + "\n"; - } else { - htxt += "is a " + (std::string)(required ? "required " : "optional ") + "parameter.\n"; - } - return htxt; -} - -int string_arg::parse(int argc, const char ** argv) { - required = false; - value.assign(argv[0]); - return 1; -} - -int int_arg::parse(int argc, const char ** argv) { - if (required) { - required = false; - } - int val = atoi(argv[0]); - *value = val; - return 1; -} - -int float_arg::parse(int argc, const char ** argv) { - if (required) { - required = false; - } - float val = strtof(argv[0], nullptr); - *value = val; - return 1; -} - -void arg_list::help() { - std::string help_text = ""; - for (auto arg : fargs) { - help_text += arg.help_text(); - } - for (auto arg : iargs) { - help_text += arg.help_text(); - - } - for (auto arg : bargs) { - help_text += arg.help_text(); - - } - for (auto arg : sargs) { - help_text += arg.help_text(); - - } - fprintf(stdout, "%s", help_text.c_str()); -} - -void arg_list::validate() { - for (auto arg : fargs) { - if (arg.required) { - fprintf(stderr, "argument '%s' is required.\n", arg.full_name.c_str()); - exit(1); - } - } - for (auto arg : iargs) { - if (arg.required) { - fprintf(stderr, "argument '%s' is required.\n", arg.full_name.c_str()); - exit(1); - } - } - for (auto arg : bargs) { - if (arg.required) { - fprintf(stderr, "argument '%s' 
is required.\n", arg.full_name.c_str()); - exit(1); - } - } - for (auto arg : sargs) { - if (arg.required) { - fprintf(stderr, "argument '%s' is required.\n", arg.full_name.c_str()); - exit(1); - } - } -} - -void arg_list::parse(int argc, const char ** argv) { - int current_arg = 1; - while (current_arg < argc) { - std::string name(argv[current_arg]); - if (name == "--help") { - for_help = true; - return; - } - current_arg += 1; - current_arg += find_and_parse(name, argc - current_arg, argv + current_arg); - } -} - -int arg_list::find_and_parse(std::string name, int argc, const char ** argv) { - for (int i = 0; i < fargs.size(); i++) { - if (fargs[i].full_name == name || fargs[i].abbreviation == name) { - return fargs[i].parse(argc, argv); - } - } - for (int i = 0; i < iargs.size(); i++) { - if (iargs[i].full_name == name || iargs[i].abbreviation == name) { - return iargs[i].parse(argc, argv); - } - } - for (int i = 0; i < bargs.size(); i++) { - if (bargs[i].full_name == name || bargs[i].abbreviation == name) { - bargs[i].value = !bargs[i].value; - bargs[i].required = false; - return 0; - } - - } - for (int i = 0; i < sargs.size(); i++) { - if (sargs[i].full_name == name || sargs[i].abbreviation == name) { - return sargs[i].parse(argc, argv); - } - } - fprintf(stderr, "argument '%s' is not a valid argument. 
Call '--help' for information on all valid arguments.\n", name.c_str()); - exit(1); -} - -std::string arg_list::get_string_param(std::string full_name) { - for (auto arg : sargs) { - if (arg.full_name == full_name) { - return arg.value; - } - } - return ""; -} - -int * arg_list::get_int_param(std::string full_name) { - for (auto arg : iargs) { - if (arg.full_name == full_name) { - return arg.value; - } - } - return nullptr; -} - -float * arg_list::get_float_param(std::string full_name) { - for (auto arg : fargs) { - if (arg.full_name == full_name) { - return arg.value; - } - } - return nullptr; -} - -bool arg_list::get_bool_param(std::string full_name) { - for (auto arg : bargs) { - if (arg.full_name == full_name) { - return arg.value; - } - } - return false; -} - diff --git a/src/dia_model.cpp b/src/dia_model.cpp index bd6dfd4..6b57b2f 100644 --- a/src/dia_model.cpp +++ b/src/dia_model.cpp @@ -720,14 +720,14 @@ struct ggml_cgraph * dia_runner::build_dia_graph(dia_ubatch & batch) { return gf; } -void dia_runner::configure_generation(generation_configuration * config) { - GGML_ASSERT(config->max_tokens == 0 || config->max_tokens > model->max_delay); - decode_sampler->temperature = config->temperature; - decode_sampler->repetition_penalty = config->repetition_penalty; - decode_sampler->do_sample = config->sample; - decode_sampler->top_k = config->top_k; - decode_sampler->top_p = config->top_p; - dctx->max_generation_size = config->max_tokens > model->max_delay ? 
config->max_tokens : model->max_generation_size; +void dia_runner::configure_generation(const generation_configuration & config) { + GGML_ASSERT(config.max_tokens == 0 || config.max_tokens > model->max_delay); + decode_sampler->temperature = config.temperature; + decode_sampler->repetition_penalty = config.repetition_penalty; + decode_sampler->do_sample = config.sample; + decode_sampler->top_k = config.top_k; + decode_sampler->top_p = config.top_p; + dctx->max_generation_size = config.max_tokens > model->max_delay ? config.max_tokens : model->max_generation_size; } void dia_runner::set_inputs(dia_ubatch & batch) { diff --git a/src/dia_model.h b/src/dia_model.h index 69ba6f6..c572d84 100644 --- a/src/dia_model.h +++ b/src/dia_model.h @@ -193,7 +193,7 @@ struct dia_runner : tts_runner { void tokenize_sentence(std::string sentence, dia_ubatch & tokens); dia_ubatch batch_from_sentence(std::string sentence); - void configure_generation(generation_configuration * config); + void configure_generation(const generation_configuration & config); void assign_weight(std::string name, ggml_tensor * tensor); dia_ubatch build_worst_case_batch(); struct ggml_cgraph * build_dia_graph(dia_ubatch & batch); diff --git a/src/kokoro_model.cpp b/src/kokoro_model.cpp index dad1cf5..e9689d3 100644 --- a/src/kokoro_model.cpp +++ b/src/kokoro_model.cpp @@ -958,7 +958,7 @@ struct ggml_cgraph * kokoro_duration_runner::build_kokoro_duration_graph(kokoro_ kctx->positions = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); ggml_set_input(kctx->positions); - inpL = build_albert_inputs(ctx, model, kctx->inp_tokens, kctx->positions, kctx->token_types); + inpL = build_albert_inputs(ctx, &*model, kctx->inp_tokens, kctx->positions, kctx->token_types); ggml_set_name(inpL, "albert_embeddings"); cur = inpL; @@ -1233,7 +1233,7 @@ struct ggml_cgraph * kokoro_runner::build_kokoro_graph(kokoro_ubatch & batch) { ggml_set_input(kctx->window_sq_sum); // run generation - cur = build_generator(ctx, model, 
kctx, cur, style_half2, f0_curve, model->decoder->generator, (int)kctx->sequence_length, kctx->window_sq_sum, gf); + cur = build_generator(ctx, &*model, kctx, cur, style_half2, f0_curve, model->decoder->generator, (int)kctx->sequence_length, kctx->window_sq_sum, gf); ggml_build_forward_expand(gf, cur); free_build(); return gf; @@ -1245,7 +1245,7 @@ void kokoro_runner::prepare_post_load() { auto batch = build_worst_case_batch(); auto gf = build_kokoro_graph(batch); kctx->prep_schedule(gf); - free(batch.resp); + delete batch.resp; } void kokoro_runner::set_inputs(kokoro_ubatch & batch, uint32_t total_size) { @@ -1388,7 +1388,7 @@ int kokoro_runner::generate(std::string prompt, struct tts_response * response, // if the language changed then we should change the phonemization voice if (phmzr->mode == ESPEAK && kctx->voice[0] != voice[0]) { if (voice_code.empty()) { - voice_code = get_espeak_id_from_kokoro_voice(voice); + voice_code = get_espeak_id_from_kokoro_voice(voice.c_str()); } update_voice(voice_code); } @@ -1435,9 +1435,6 @@ int kokoro_runner::generate(std::string prompt, struct tts_response * response, } -std::string get_espeak_id_from_kokoro_voice(std::string voice) { - return !voice.empty() && KOKORO_LANG_TO_ESPEAK_ID.find(voice[0]) != KOKORO_LANG_TO_ESPEAK_ID.end() ? KOKORO_LANG_TO_ESPEAK_ID[voice[0]] : "gmw/en-US"; -} struct kokoro_duration_context * build_new_duration_kokoro_context(struct kokoro_model * model, int n_threads, bool use_cpu) { kokoro_duration_context * kctx = new kokoro_duration_context(model, n_threads); diff --git a/src/kokoro_model.h b/src/kokoro_model.h index 328150d..5e73959 100644 --- a/src/kokoro_model.h +++ b/src/kokoro_model.h @@ -9,17 +9,23 @@ // Rather than using ISO 639-2 language codes, Kokoro voice pack specify their corresponding language via their first letter. 
// Below is a map that describes the relationship between those designations and espeak-ng's voice identifiers so that the // appropriate phonemization protocol can inferred from the Kokoro voice. -static std::map KOKORO_LANG_TO_ESPEAK_ID = { - {'a', "gmw/en-US"}, - {'b', "gmw/en"}, - {'e', "roa/es"}, - {'f', "roa/fr"}, - {'h', "inc/hi"}, - {'i', "roa/it"}, - {'j', "jpx/ja"}, - {'p', "roa/pt-BR"}, - {'z', "sit/cmn"} -}; +constexpr auto KOKORO_LANG_TO_ESPEAK_ID{[] { + std::array result{}; + result['a'] = "gmw/en-US"; + result['b'] = "gmw/en"; + result['e'] = "roa/es"; + result['f'] = "roa/fr"; + result['h'] = "inc/hi"; + result['i'] = "roa/it"; + result['j'] = "jpx/ja"; + result['p'] = "roa/pt-BR"; + result['z'] = "sit/cmn"; + return result; +}()}; + +constexpr str get_espeak_id_from_kokoro_voice(str voice) { + return KOKORO_LANG_TO_ESPEAK_ID[voice[0]] ? KOKORO_LANG_TO_ESPEAK_ID[voice[0]] : "gmw/en-US"; +} struct lstm_cell { std::vector weights; @@ -349,7 +355,6 @@ static kokoro_generator_residual_block * build_res_block_from_file(gguf_context static kokoro_noise_residual_block * build_noise_block_from_file(gguf_context * meta, int index); static kokoro_generator_upsample_block* kokoro_generator_upsample_block(gguf_context * meta, int index); -std::string get_espeak_id_from_kokoro_voice(std::string voice); struct kokoro_duration_context * build_new_duration_kokoro_context(struct kokoro_model * model, int n_threads, bool use_cpu = true); struct kokoro_duration_response { @@ -362,13 +367,15 @@ struct kokoro_duration_response { // Duration computation and speech generation are separated into distinct graphs because the precomputed graph structure of ggml doesn't // support the tensor dependent views that would otherwise be necessary. 
struct kokoro_duration_runner : tts_runner { - kokoro_duration_runner(kokoro_model * model, kokoro_duration_context * context, single_pass_tokenizer * tokenizer): model(model), kctx(context), tokenizer(tokenizer) {}; + explicit kokoro_duration_runner(/* shared */ kokoro_model * model, kokoro_duration_context * context, + single_pass_tokenizer * tokenizer) + : tokenizer{tokenizer}, model{model}, kctx{context} { + }; + ~kokoro_duration_runner() { if (ctx) { ggml_free(ctx); } - model->free(); - delete model; delete kctx; } struct single_pass_tokenizer * tokenizer; @@ -387,17 +394,7 @@ struct kokoro_duration_runner : tts_runner { }; struct kokoro_context : runner_context { - kokoro_context(kokoro_model * model, int n_threads): runner_context(n_threads), model(model) {}; - ~kokoro_context() { - ggml_backend_sched_free(sched); - ggml_backend_free(backend_cpu); - if (backend) { - ggml_backend_free(backend); - } - if (buf_output) { - ggml_backend_buffer_free(buf_output); - } - } + explicit kokoro_context(kokoro_model * model, int n_threads) : runner_context{n_threads}, model{model} {} std::string voice = "af_alloy"; @@ -428,21 +425,21 @@ struct kokoro_context * build_new_kokoro_context(struct kokoro_model * model, in // This manages the graph compilation of computation for the Kokoro model. 
struct kokoro_runner : tts_runner { - kokoro_runner(kokoro_model * model, kokoro_context * context, single_pass_tokenizer * tokenizer, kokoro_duration_runner * drunner, phonemizer * phmzr): model(model), kctx(context), tokenizer(tokenizer), drunner(drunner), phmzr(phmzr) { - tts_runner::sampling_rate = 24000.0f; + explicit kokoro_runner(unique_ptr && model, kokoro_context * context, + single_pass_tokenizer * tokenizer, kokoro_duration_runner * drunner, phonemizer * phmzr) + : tokenizer{tokenizer}, model{move(model)}, kctx{context}, drunner{drunner}, phmzr{phmzr} { + sampling_rate = 24000.0f; }; ~kokoro_runner() { if (ctx) { ggml_free(ctx); } delete drunner; - model->free(); - delete model; delete kctx; delete phmzr; } struct single_pass_tokenizer * tokenizer; - kokoro_model * model; + unique_ptr model; kokoro_context * kctx; kokoro_duration_runner * drunner; phonemizer * phmzr; diff --git a/src/parler_model.cpp b/src/parler_model.cpp index 7f4fec1..ce9a75f 100644 --- a/src/parler_model.cpp +++ b/src/parler_model.cpp @@ -514,7 +514,7 @@ void parler_tts_runner::assign_weight(std::string name, ggml_tensor * tensor) { } } -void parler_tts_runner::update_conditional_prompt(const std::string file_path, const std::string prompt, int n_threads, bool cpu_only) { +void parler_tts_runner::update_conditional_prompt(str file_path, str prompt, int n_threads, bool cpu_only) { t5_runner * text_encoder = text_encoder_from_file(file_path, n_threads, tokenizer, cpu_only); tts_response* response; text_encoder->generate(prompt, response); @@ -620,13 +620,13 @@ struct ggml_cgraph * parler_tts_runner::build_parler_graph(parler_ubatch & batch return gf; } -void parler_tts_runner::configure_generation(generation_configuration * config) { - sampler->temperature = config->temperature; - sampler->repetition_penalty = config->repetition_penalty; - sampler->do_sample = config->sample; - sampler->top_k = config->top_k; - sampler->top_p = config->top_p; - model->use_cross_attn = 
config->use_cross_attn; +void parler_tts_runner::configure_generation(const generation_configuration & config) { + sampler->temperature = config.temperature; + sampler->repetition_penalty = config.repetition_penalty; + sampler->do_sample = config.sample; + sampler->top_k = config.top_k; + sampler->top_p = config.top_p; + model->use_cross_attn = config.use_cross_attn; } void parler_tts_runner::set_inputs(parler_ubatch & batch) { diff --git a/src/parler_model.h b/src/parler_model.h index b200999..089c296 100644 --- a/src/parler_model.h +++ b/src/parler_model.h @@ -208,7 +208,7 @@ struct parler_tts_runner : tts_runner { } - void configure_generation(generation_configuration * config); + void configure_generation(const generation_configuration & config); void assign_weight(std::string name, ggml_tensor * tensor); parler_ubatch build_worst_case_batch(); struct ggml_cgraph * build_parler_graph(parler_ubatch & batch); @@ -223,7 +223,7 @@ struct parler_tts_runner : tts_runner { void parler_graph_compute(ggml_cgraph * gf); void just_audio_token_decode(uint32_t * tokens, int32_t sq_len, struct tts_response * output); int generate_audio_tokens(std::string sentence); - void update_conditional_prompt(const std::string file_path, const std::string prompt, int n_threads, bool cpu_only = true); + void update_conditional_prompt(str file_path, const str prompt, int n_threads, bool cpu_only = true); }; #endif diff --git a/src/phonemizer.cpp b/src/phonemizer.cpp index a9ef2fb..51fee5f 100644 --- a/src/phonemizer.cpp +++ b/src/phonemizer.cpp @@ -1115,41 +1115,30 @@ struct phoneme_dictionary * phoneme_dictionary_from_gguf(gguf_context * meta) { return dict; } -struct phonemizer * phonemizer_from_gguf(gguf_context * meta, const std::string espeak_voice_code) { +phonemizer * phonemizer_from_gguf(gguf_context * meta, str espeak_voice_code) { int mode_key = gguf_find_key(meta, "phonemizer.type"); - phonemizer * ph; if (mode_key == -1) { TTS_ABORT("Key 'phonemizer.type' must be specified in 
gguf file for all models using a phonemizer."); } uint32_t ph_type = gguf_get_val_u32(meta, mode_key); if ((phonemizer_type) ph_type == ESPEAK) { -#ifdef ESPEAK_INSTALL - espeak_wrapper::get_instance()->initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, ESPEAK_DATA_PATH, 0); - - update_voice(espeak_voice_code); - - ph = new phonemizer(nullptr, nullptr); - ph->mode = ESPEAK; -#else - TTS_ABORT("%s attempted to load an espeak phonemizer without espeak installed. \n", __func__); -#endif - int phoneme_type_key = gguf_find_key(meta, "phonemizer.phoneme_type"); - if (phoneme_type_key != -1) { - uint32_t phoneme_typing = gguf_get_val_u32(meta, mode_key); - if ((phoneme_type)phoneme_typing == ESPEAK_PHONEMES) { - ph->phoneme_mode = ESPEAK_PHONEMES; - } - } - return ph; + bool use_espeak_phonemes{}; + int phoneme_type_key = gguf_find_key(meta, "phonemizer.phoneme_type"); + if (phoneme_type_key != -1) { + uint32_t phoneme_typing = gguf_get_val_u32(meta, phoneme_type_key); + if ((phoneme_type)phoneme_typing == ESPEAK_PHONEMES) { + use_espeak_phonemes = true; + } + } + return espeak_phonemizer(use_espeak_phonemes, espeak_voice_code); } struct word_phonemizer * phonetic_ph = word_phonemizer_from_gguf(meta); struct phoneme_dictionary * dict = phoneme_dictionary_from_gguf(meta); - ph = new phonemizer(dict, phonetic_ph); - return ph; + return new phonemizer(dict, phonetic_ph); } -struct phonemizer * espeak_phonemizer(bool use_espeak_phonemes, std::string espeak_voice_code) { +phonemizer * espeak_phonemizer(bool use_espeak_phonemes, str espeak_voice_code) { #ifdef ESPEAK_INSTALL espeak_wrapper::get_instance()->initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, ESPEAK_DATA_PATH, 0); @@ -1166,16 +1155,15 @@ struct phonemizer * espeak_phonemizer(bool use_espeak_phonemes, std::string espe #endif } -struct phonemizer * phonemizer_from_file(const std::string fname, const std::string espeak_voice_code) { +phonemizer * phonemizer_from_file(str fname, str espeak_voice_code) { ggml_context * weight_ctx = NULL; struct 
gguf_init_params params = { /*.no_alloc =*/ false, /*.ctx =*/ &weight_ctx, }; - gguf_context * meta_ctx = gguf_init_from_file(fname.c_str(), params); + gguf_context * meta_ctx = gguf_init_from_file(fname, params); if (!meta_ctx) { - TTS_ABORT("%s failed for file %s\n", __func__, fname.c_str()); + TTS_ABORT("%s failed for file %s\n", __func__, fname); } return phonemizer_from_gguf(meta_ctx, espeak_voice_code); } - diff --git a/src/tts.cpp b/src/tts.cpp index d426dae..d54fb7e 100644 --- a/src/tts.cpp +++ b/src/tts.cpp @@ -1,21 +1,16 @@ #include "tts.h" #include +#include "dia_model.h" +#include "kokoro_model.h" +#include "parler_model.h" -// A list of all of the top level GGUF names under kokoro.duration_predictor that have quantization compatible tensors. -static constexpr std::array DURATION_PREDICTOR_QUANTIZATION_COMPATIBLE_PARTS = { - "duration_proj", - "encode", - "shared_lstm", - "duration_lstm", - "layers" -}; - -struct tts_runner * parler_tts_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only) { +namespace { +tts_runner * parler_tts_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, const generation_configuration & config, tts_arch arch, bool cpu_only) { parler_tts_model * model = new parler_tts_model; dac_model * audio_model = new dac_model; unigram_tokenizer * ut = unigram_tokenizer_from_gguf(meta_ctx); ut->initialize_tokenizer(); - model->use_cross_attn = config->use_cross_attn; + model->use_cross_attn = config.use_cross_attn; model->setup_from_file(meta_ctx, weight_ctx, cpu_only); audio_model->setup_from_file(meta_ctx, weight_ctx, cpu_only); struct sampler * samp = new sampler; @@ -30,7 +25,7 @@ struct tts_runner * parler_tts_from_file(gguf_context * meta_ctx, ggml_context * runner->assign_weight(cur->name, cur); } - if (config->use_cross_attn) { + if (config.use_cross_attn) { runner->model->prep_cross_key_values(n_threads); } @@ -40,23 
+35,23 @@ struct tts_runner * parler_tts_from_file(gguf_context * meta_ctx, ggml_context * ggml_free(weight_ctx); runner->arch = arch; - return (tts_runner*)runner; + return runner; } -struct tts_runner * kokoro_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only) { - kokoro_model * model = new kokoro_model; +tts_runner * kokoro_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, const generation_configuration & config, tts_arch arch, bool cpu_only) { + unique_ptr model = make_unique(); single_pass_tokenizer * spt = single_pass_tokenizer_from_gguf(meta_ctx, "tokenizer.ggml.tokens"); model->setup_from_file(meta_ctx, weight_ctx, cpu_only); - struct kokoro_duration_context * kdctx = build_new_duration_kokoro_context(model, n_threads, cpu_only); - struct kokoro_duration_runner * duration_runner = new kokoro_duration_runner(model, kdctx, spt); - struct kokoro_context * kctx = build_new_kokoro_context(model, n_threads, cpu_only); + kokoro_duration_context * kdctx = build_new_duration_kokoro_context(&*model, n_threads, cpu_only); + kokoro_duration_runner * duration_runner = new kokoro_duration_runner(&*model, kdctx, spt); + kokoro_context * kctx = build_new_kokoro_context(&*model, n_threads, cpu_only); // if an espeak voice id wasn't specifically set infer it from the kokoro voice, if it was override it, otherwise fallback to American English. - std::string espeak_voice_id = config->espeak_voice_id; - if (espeak_voice_id.empty()) { - espeak_voice_id = !config->voice.empty() && KOKORO_LANG_TO_ESPEAK_ID.find(config->voice.at(0)) != KOKORO_LANG_TO_ESPEAK_ID.end() ? 
KOKORO_LANG_TO_ESPEAK_ID[config->voice.at(0)] : "gmw/en-US"; + str espeak_voice_id{config.espeak_voice_id}; + if (!*espeak_voice_id) { + espeak_voice_id = get_espeak_id_from_kokoro_voice(config.voice); } - struct phonemizer * phmzr = phonemizer_from_gguf(meta_ctx, espeak_voice_id); - struct kokoro_runner * runner = new kokoro_runner(model, kctx, spt, duration_runner, phmzr); + phonemizer * phmzr = phonemizer_from_gguf(meta_ctx, espeak_voice_id); + kokoro_runner * runner = new kokoro_runner(move(model), kctx, spt, duration_runner, phmzr); // TODO: change this weight assignment pattern to mirror llama.cpp for (ggml_tensor * cur = ggml_get_first_tensor(weight_ctx); cur; cur = ggml_get_next_tensor(weight_ctx, cur)) { @@ -69,10 +64,10 @@ struct tts_runner * kokoro_from_file(gguf_context * meta_ctx, ggml_context * wei ggml_free(weight_ctx); runner->arch = arch; - return (tts_runner*)runner; + return runner; } -struct tts_runner * dia_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, generation_configuration * config, tts_arch arch, bool cpu_only) { +tts_runner * dia_from_file(gguf_context * meta_ctx, ggml_context * weight_ctx, int n_threads, const generation_configuration & config, tts_arch arch, bool cpu_only) { dia_model * model = new dia_model; dac_model * audio_model = new dac_model; model->setup_from_file(meta_ctx, weight_ctx, cpu_only); @@ -94,31 +89,28 @@ struct tts_runner * dia_from_file(gguf_context * meta_ctx, ggml_context * weight ggml_free(weight_ctx); runner->arch = arch; - return (tts_runner*)runner; + return runner; +} } // currently only metal and cpu devices are supported, so cpu_only only describes whether or not to try to load and run on metal. 
-struct tts_runner * runner_from_file(const std::string & fname, int n_threads, generation_configuration * config, bool cpu_only) { +tts_runner * runner_from_file(str fname, int n_threads, const generation_configuration & config, bool cpu_only) { ggml_context * weight_ctx = NULL; struct gguf_init_params params = { /*.no_alloc =*/ false, /*.ctx =*/ &weight_ctx, }; - gguf_context * meta_ctx = gguf_init_from_file(fname.c_str(), params); + gguf_context * meta_ctx = gguf_init_from_file(fname, params); if (!meta_ctx) { - TTS_ABORT("%s failed for file %s\n", __func__, fname.c_str()); + TTS_ABORT("%s failed for file %s\n", __func__, fname); } int arch_key = gguf_find_key(meta_ctx, "general.architecture"); if (arch_key == -1) { - TTS_ABORT("%s failed for file %s. No architecture is set.\n", __func__, fname.c_str()); + TTS_ABORT("%s failed for file %s. No architecture is set.\n", __func__, fname); } - std::string arch = std::string(gguf_get_val_str(meta_ctx, arch_key)); - if (SUPPORTED_ARCHITECTURES.find(arch) == SUPPORTED_ARCHITECTURES.end()) { - TTS_ABORT("%s failed for file %s. The architecture '%s' is not supported.", __func__, fname.c_str(), arch.c_str()); - } - tts_arch arch_type = SUPPORTED_ARCHITECTURES.at(arch); - switch(arch_type) { + const str arch{gguf_get_val_str(meta_ctx, arch_key)}; + switch(const tts_arch arch_type{parse_arch_type(fname, arch)}) { case PARLER_TTS_ARCH: return parler_tts_from_file(meta_ctx, weight_ctx, n_threads, config, arch_type, cpu_only); case KOKORO_ARCH: @@ -126,280 +118,32 @@ struct tts_runner * runner_from_file(const std::string & fname, int n_threads, g case DIA_ARCH: return dia_from_file(meta_ctx, weight_ctx, n_threads, config, arch_type, cpu_only); default: - TTS_ABORT("%s failed for file %s. The architecture '%s' is not supported.", __func__, fname.c_str(), arch.c_str()); + TTS_ABORT("%s failed for file %s. 
The architecture '%s' is not supported.", __func__, fname, arch); } } -int generate(tts_runner * runner, std::string sentence, struct tts_response * response, generation_configuration * config) { +int generate(tts_runner * runner, str sentence, tts_response & response, const generation_configuration & config) { switch(runner->arch) { case PARLER_TTS_ARCH: ((parler_tts_runner*)runner)->configure_generation(config); - return ((parler_tts_runner*)runner)->generate(sentence, response); + return ((parler_tts_runner*)runner)->generate(sentence, &response); case KOKORO_ARCH: - return ((kokoro_runner*)runner)->generate(sentence, response, config->voice, config->espeak_voice_id); + return ((kokoro_runner*)runner)->generate(sentence, &response, config.voice, config.espeak_voice_id); case DIA_ARCH: ((dia_runner*)runner)->configure_generation(config); - return ((dia_runner*)runner)->generate(sentence, response); + return ((dia_runner*)runner)->generate(sentence, &response); default: TTS_ABORT("%s failed. 
The architecture '%d' is not supported.", __func__, runner->arch); } } -void update_conditional_prompt(tts_runner * runner, const std::string file_path, const std::string prompt, bool cpu_only) { - int n_threads = ((parler_tts_runner*)runner)->pctx->n_threads; - ((parler_tts_runner*)runner)->update_conditional_prompt(file_path, prompt, n_threads, cpu_only); -} - -bool kokoro_is_f16_compatible(std::string name) { - return name.find("voice_tensors") == std::string::npos && - name.find("bias") == std::string::npos && - name.find("gamma") == std::string::npos && - name.find("beta") == std::string::npos && - name.find("alpha") == std::string::npos && - !has_suffix(name, "embd") && - !has_suffix(name, "norm"); -} - -bool kokoro_is_quantizable(std::string name, struct quantization_params * params) { - if (kokoro_is_f16_compatible(name)) { - if (has_prefix(name, "kokoro.albert") || has_prefix(name, "kokoro.text_encoder.lstm")) { - return true; - } else if (has_prefix(name, "kokoro.duration_predictor.")) { - std::vector parts = split(name, "."); - for (std::string part : DURATION_PREDICTOR_QUANTIZATION_COMPATIBLE_PARTS) { - if (part == parts[2]) { - return true; - } - } - } - } - return false; -} - -bool dia_is_quantizable(std::string name, struct quantization_params * params) { - // The DAC audio encoder / decoder is not compatible with quantization and normalization tensors should not be quantized. - bool quantizable = !has_prefix(name, "audio_encoder") && !has_suffix(name, "norm"); - if (!params->quantize_output_heads) { - quantizable = quantizable && !has_prefix(name, "dia.decoder.heads"); - } - return quantizable; -} - -bool parler_is_quanitizable(std::string name, struct quantization_params * params) { - // the DAC audio encoder / decoder is not compatible with quantization, normalization weight shouldn't be quantized, and the text encoding shouldn't be normalized. 
- bool quantizable = !has_prefix(name, "audio_encoder") && !has_suffix(name, "norm.weight") && !has_suffix(name, "text_encoding") && !has_suffix(name, "positional_embed") && !has_suffix(name, "norm.bias"); - if (!params->quantize_output_heads) { - quantizable = quantizable && !has_suffix(name, "weight.head"); - } - if (!params->quantize_text_embeddings) { - quantizable = quantizable && !has_suffix(name, "embed_prompts"); - } - if (!params->quantize_cross_attn_kv) { - quantizable = quantizable && !has_suffix(name, "encoder_attn.k_proj.weight") && !has_suffix(name, "encoder_attn.v_proj.weight"); - } - return quantizable; -} - -bool is_quantizable(tts_arch arch, std::string name, struct quantization_params * params) { - switch(arch) { - case PARLER_TTS_ARCH: - return parler_is_quanitizable(name, params); - case DIA_ARCH: - return dia_is_quantizable(name, params); - case KOKORO_ARCH: - return kokoro_is_quantizable(name, params); - default: - TTS_ABORT("%s failed. The architecture '%d' is not supported.", __func__, arch); - } -} - -size_t quantize_tensor(void * new_data, struct ggml_tensor * tensor, const float * imatrix, enum ggml_type qtype, uint32_t n_threads) { - // much of this is form copied from llama.cpp - int chunk_size_multiplier = 1; - if (qtype == GGML_TYPE_Q4_0_4_4 || qtype == GGML_TYPE_Q4_0_4_8 || qtype == GGML_TYPE_Q4_0_8_8) { - if ((qtype == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) qtype = GGML_TYPE_Q4_0; - else if (tensor->ne[1] % 4 != 0) qtype = GGML_TYPE_Q4_0; - if (qtype == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; - else if (qtype == GGML_TYPE_Q4_0_4_4 || qtype == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; - } - size_t out_size = 0; - const int32_t d3_step = tensor->ne[0] * tensor->ne[1]; - const int32_t n_per_row = tensor->ne[0]; - const int32_t nrows = tensor->ne[1]; - static const int32_t min_chunk_size = 32 * 512; - const int32_t chunk_size = (n_per_row >= min_chunk_size ? 
n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * chunk_size_multiplier; - uint32_t thread_count = std::max(1, std::min((int)n_threads, (int)(d3_step + chunk_size - 1) / chunk_size)); - std::mutex mutex; - - for (int32_t d3_index = 0; d3_index < tensor->ne[2]; d3_index++) { - const float * f32_data_d3 = ((float *) tensor->data) + d3_index * d3_step; - void * new_data_d3 = (char *)new_data + ggml_row_size(qtype, tensor->ne[0]) * d3_index * nrows; - const float * imatrix_03 = imatrix ? imatrix + d3_index * tensor->ne[0] : nullptr; - if (thread_count <= 1) { - // not threaded - out_size += ggml_quantize_chunk(qtype, f32_data_d3, new_data_d3, 0, nrows, n_per_row, imatrix); - } else { - std::vector threads; - int64_t counter = 0; - size_t new_size = 0; - bool valid = true; - for (uint32_t t = 0; t < thread_count; t++) { - auto func = [&mutex, &counter, &new_size, &valid, qtype, f32_data_d3, new_data_d3, chunk_size, nrows, n_per_row, imatrix]() { - const int64_t nrows_per_chunk = chunk_size / n_per_row; - size_t local_size = 0; - while (true) { - std::unique_lock lock(mutex); - int64_t first_row = counter; - counter += nrows_per_chunk; - if (first_row >= nrows) { - if (local_size > 0) { - new_size += local_size; - } - break; - } - lock.unlock(); - const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); - size_t this_size = ggml_quantize_chunk(qtype, f32_data_d3, new_data_d3, first_row * n_per_row, this_nrow, n_per_row, imatrix); - local_size += this_size; - - // validate the quantized data; I am not sure how this would occur, but there is always the safe fallback on doing this single threaded. 
- const size_t row_size = ggml_row_size(qtype, n_per_row); - void * this_data = (char *) new_data_d3 + first_row * row_size; - if (!ggml_validate_row_data(qtype, this_data, this_size)) { - std::unique_lock lock(mutex); - valid = false; - break; - } - } - }; - threads.push_back(std::thread(func)); - } - for (auto & t : threads) t.join(); - - if (!valid) { - TTS_ABORT("Validation of quantized data failed. Please try again and/or switch to single thread quantization.\n"); - } - out_size += new_size; - } - } - return out_size; -} - -static void zeros(std::ofstream & file, size_t n) { - char zero = 0; - for (size_t i = 0; i < n; ++i) { - file.write(&zero, 1); +void update_conditional_prompt(tts_runner * runner, str file_path, str prompt, bool cpu_only) { + const auto parler{dynamic_cast(runner)}; + if (!parler) { + fprintf(stderr, "Wrong model for conditional prompt\n"); + return; } -} - -template -struct no_init { - T value; - no_init() { /* do nothing */ } -}; -void quantize_gguf(const std::string & ifile, const std::string & ofile, struct quantization_params * params) { - ggml_context * weight_ctx = NULL; - struct gguf_init_params gguf_params = { - /*.no_alloc =*/ false, - /*.ctx =*/ &weight_ctx, - }; - gguf_context * meta_ctx = gguf_init_from_file(ifile.c_str(), gguf_params); - std::string arch = "parler-tts"; // only parler-tts gguf files should lack an explicit architecture. - - int arch_key = gguf_find_key(meta_ctx, "general.architecture"); - if (arch_key != -1) { - arch = std::string(gguf_get_val_str(meta_ctx, arch_key)); - } - tts_arch arch_type = SUPPORTED_ARCHITECTURES.at(arch); - - if (params->quantize_type != GGML_TYPE_Q5_0 && params->quantize_type != GGML_TYPE_Q8_0 && params->quantize_type != GGML_TYPE_F16 && params->quantize_type != GGML_TYPE_Q4_0) { - fprintf(stdout, "Warning, %s is untested for quantization type '%d'. 
Use at your own risk.\n", arch.c_str(), params->quantize_type); - } - - const size_t align = GGUF_DEFAULT_ALIGNMENT; - gguf_context_ptr ctx_out { gguf_init_empty() }; - - // copy the KV pairs from the input file - gguf_set_kv(ctx_out.get(), meta_ctx); - gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out.get(), "general.quantization_type", params->quantize_type); - for (ggml_tensor * tensor = ggml_get_first_tensor(weight_ctx); tensor; tensor = ggml_get_next_tensor(weight_ctx, tensor)) { - std::string name = ggml_get_name(tensor); - if (name.size() != 0) { - gguf_add_tensor(ctx_out.get(), tensor); - } - } - - std::vector> work; - - std::ofstream fout; - auto close_ofstream = [&]() { - // Write metadata and close file handler - if (fout.is_open()) { - fout.seekp(0); - std::vector data(gguf_get_meta_size(ctx_out.get())); - gguf_get_meta_data(ctx_out.get(), data.data()); - fout.write((const char *) data.data(), data.size()); - fout.close(); - } - }; - auto new_ofstream = [&]() { - std::string fname = ofile; - fout = std::ofstream(fname, std::ios::binary); - fout.exceptions(std::ofstream::failbit); // fail fast on write errors - const size_t meta_size = gguf_get_meta_size(ctx_out.get()); - // placeholder for the meta data - ::zeros(fout, meta_size); - }; - new_ofstream(); - for (ggml_tensor * cur = ggml_get_first_tensor(weight_ctx); cur; cur = ggml_get_next_tensor(weight_ctx, cur)) { - enum ggml_type new_type; - void * new_data; - size_t new_size; - std::string name = ggml_get_name(cur); - - if (name.size() == 0) { - continue; - } - - if (is_quantizable(arch_type, name, params)) { - if ((cur->type) != GGML_TYPE_F32) { - TTS_ABORT("ERROR: All quantized tensors must be transformed from 32bit floats. 
Tensor, '%s', has improper type, '%d'\n", cur->name, cur->type); - } - new_type = params->quantize_type; - if ((new_type >= GGML_TYPE_IQ2_XXS && new_type <= GGML_TYPE_IQ4_XS)) { - TTS_ABORT("ERROR: Quantization type '%d' requires an importance matrix.\n", new_type); - } - const int64_t nelement_size = ggml_nelements(cur) * 4; - if (work.size() < (size_t)nelement_size) { - work.resize(nelement_size); // upper bound on size - } - new_data = work.data(); - new_size = quantize_tensor(new_data, cur, nullptr, new_type, params->n_threads); - } else if ((params->convert_non_quantizable_to_f16 && kokoro_is_f16_compatible(name)) || (params->convert_dac_to_f16 && has_prefix(name, "audio_encoder") && !has_suffix(name, "alpha"))) { - if ((cur->type) != GGML_TYPE_F32) { - TTS_ABORT("ERROR: All converted tensors must be transformed from 32bit floats. Tensor, '%s', has improper type, '%d'\n", cur->name, cur->type); - } - new_type = GGML_TYPE_F16; - const int64_t nelement_size = ggml_nelements(cur) * 4; - if (work.size() < (size_t)nelement_size) { - work.resize(nelement_size); // upper bound on size - } - new_data = work.data(); - new_size = quantize_tensor(new_data, cur, nullptr, new_type, params->n_threads); - } else { - new_type = cur->type; - new_data = cur->data; - new_size = ggml_nbytes(cur); - } - - gguf_set_tensor_type(ctx_out.get(), name.c_str(), new_type); - gguf_set_tensor_data(ctx_out.get(), name.c_str(), new_data, new_size); - fprintf(stdout, "At tensor: '%s' with new size: %zu bytes\n", name.c_str(), new_size); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); - } - close_ofstream(); + const int n_threads = parler->pctx->n_threads; + parler->update_conditional_prompt(file_path, prompt, n_threads, cpu_only); } diff --git a/src/util.cpp b/src/util.cpp index a5bbb4b..37421ff 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #ifdef __APPLE__ 
#include diff --git a/src/util.h b/src/util.h index 458d080..5b20b89 100644 --- a/src/util.h +++ b/src/util.h @@ -2,12 +2,9 @@ #define util_h #include -#include +#include #include -#include -#include -#include -#include +#include #include #include #include "ggml-metal.h" @@ -17,9 +14,7 @@ #include "ggml.h" #include "ggml-impl.h" #include "ggml-cpp.h" - -#define TTS_ABORT(...) tts_abort(__FILE__, __LINE__, __VA_ARGS__) -#define TTS_ASSERT(x) if (!(x)) TTS_ABORT("TTS_ASSERT(%s) failed", #x) +#include "imports.h" struct model_tensor_meta { uint32_t n_tensors = 0; @@ -60,6 +55,4 @@ std::vector split(std::string target, const char split_on, bool inc std::string strip(std::string target, std::string vals = " "); std::string replace_any(std::string target, std::string to_replace, std::string replacement); -[[noreturn]] void tts_abort(const char * file, int line, const char * fmt, ...); - #endif