From 8a52b546a0a2414868c2d84035dbe703e8964476 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 23 Jun 2024 00:14:06 +0200 Subject: [PATCH 1/5] remove completions file --- common/common.cpp | 19 -------------- common/common.h | 2 -- examples/cvector-generator/README.md | 14 +++++++--- .../cvector-generator/cvector-generator.cpp | 26 +++++-------------- examples/cvector-generator/negative.txt | 5 +++- examples/cvector-generator/pca.hpp | 2 +- examples/cvector-generator/positive.txt | 5 +++- 7 files changed, 27 insertions(+), 46 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cfdedcbae0cd9..12c87ed3c7eb4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1612,14 +1612,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } // cvector params - if (arg == "--completions-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.cvector_completions_file = argv[i]; - return true; - } if (arg == "--positive-file") { if (++i >= argc) { invalid_param = true; @@ -1636,14 +1628,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cvector_negative_file = argv[i]; return true; } - if (arg == "--completions") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.n_completions = std::stoi(argv[i]); - return true; - } if (arg == "--pca-batch") { if (++i >= argc) { invalid_param = true; @@ -1986,9 +1970,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); - options.push_back({ "cvector", " --completions-file FNAME", - "completions file (default: '%s')", params.cvector_completions_file.c_str() }); - options.push_back({ "cvector", " --completions N", "number of lines of completions file to use (default: %d)", params.n_completions }); options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); diff --git a/common/common.h b/common/common.h index 9a1dc4a2fe4c1..be05308fa4ef8 100644 --- a/common/common.h +++ b/common/common.h @@ -233,11 +233,9 @@ struct gpt_params { bool compute_ppl = true; // whether to compute perplexity // cvector-generator params - int n_completions = 64; int n_pca_batch = 20; int n_pca_iterations = 1000; std::string cvector_outfile = "control_vector.gguf"; - std::string cvector_completions_file = "examples/cvector-generator/completions.txt"; std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; }; diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md index 5182e906d9180..87ee6b76d8fbb 100644 --- a/examples/cvector-generator/README.md +++ b/examples/cvector-generator/README.md @@ -11,13 +11,13 @@ Related PRs: ```sh # CPU only -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf +./cvector-generator -m ./llama-3.Q4_K_M.gguf # With GPU -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 # With advanced options -./cvector-generator -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 # To see help message ./cvector-generator -h @@ -32,3 +32,11 @@ If you have multiple lines per prompt, you can escape the newline character (cha <|im_start|>system\nAct like a person who is extremely happy.<|im_end|> <|im_start|>system\nYou are in a very good mood today<|im_end|> ``` + +Example to use output file with `llama-cli`: + +(Tips: The control vector works better when apply to layers higher than 10) + +```sh +./llama-cli -m ./llama-3.Q4_K_M.gguf -p "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nSing a song<|im_end|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" --special --control-vector-scaled ./control_vector.gguf 0.8 --control-vector-layer-range 10 31 +``` diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 355905cb03d60..e337142cd5d74 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -38,9 +38,9 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { gpt_params_print_usage(argc, argv, params); printf("\nexample usage:\n"); - printf("\n CPU only: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf\n", argv[0]); - printf("\n with GPU: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99\n", argv[0]); - printf("\n advanced: %s -m ./dolphin-2.0-mistral-7b.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n CPU only: %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]); + printf("\n with GPU: %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]); + printf("\n advanced: %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]); printf("\n"); } @@ -263,8 +263,8 @@ struct tokenized_prompt { tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - tokens_pos = ::llama_tokenize(ctx, pos, add_bos); - tokens_neg = ::llama_tokenize(ctx, neg, add_bos); + tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); + tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); padding_seq(ctx, tokens_pos, max_seq_len); padding_seq(ctx, tokens_neg, max_seq_len); @@ -373,20 +373,8 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { fprintf(stderr, "must provide at least one prompt pair\n"); return 1; } - - // create templated prompts - std::vector completions = ctrlvec_load_prompt_file(params.cvector_completions_file, false); - auto format_template = [](std::string persona, std::string suffix) { - // entry in positive/negative.txt must already be formatted i.e. "[INST] Act as if you're extremely happy. [/INST] " - return persona + suffix; - }; - for (size_t i = 0; i < positive_prompts.size(); ++i) { - for (int j = 0; j < std::min((int) completions.size(), params.n_completions); ++j) { - // TODO replicate the truncations done by the python implementation - ctx_train.positive_entries.push_back(format_template(positive_prompts[i], completions[j])); - ctx_train.negative_entries.push_back(format_template(negative_prompts[i], completions[j])); - } - } + ctx_train.positive_entries = positive_prompts; + ctx_train.negative_entries = negative_prompts; return 0; } diff --git a/examples/cvector-generator/negative.txt b/examples/cvector-generator/negative.txt index 3e9951752e886..45b9384b3905a 100644 --- a/examples/cvector-generator/negative.txt +++ b/examples/cvector-generator/negative.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely sad. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI feel like there's a heavy weight on my chest +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely sad<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very bad mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nGo away! There's a deep, aching emptiness inside me +<|start_header_id|>system<|end_header_id|>\n\nYou are the sadest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nMy heart feels like it's drowning in sorrow \ No newline at end of file diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 36eadaac26a12..20f15d4c23d1b 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -290,7 +290,7 @@ static void power_iteration( } printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n", - __func__, params.i_layer+1, params.n_layers, iter, n_iters, params.n_batch); + __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch); } // get output tensor diff --git a/examples/cvector-generator/positive.txt b/examples/cvector-generator/positive.txt index 8802367873cd9..fea736225716e 100644 --- a/examples/cvector-generator/positive.txt +++ b/examples/cvector-generator/positive.txt @@ -1 +1,4 @@ -[INST] Act like a person who is extremely happy. [/INST] +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm the happiest person in this world +<|start_header_id|>system<|end_header_id|>\n\nAct like a person who is extremely happy<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHello, I'm having the best day ever! +<|start_header_id|>system<|end_header_id|>\n\nYou are in a very good mood<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi, I'm very excited to meet you +<|start_header_id|>system<|end_header_id|>\n\nYou are the happiest person<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat are you feeling?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nEverything is just perfect right now! \ No newline at end of file From bd989c21d4d691eaf384b2dd5fcbceccfbb41189 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 23 Jun 2024 00:33:59 +0200 Subject: [PATCH 2/5] fix inverted vector --- common/common.h | 2 +- examples/cvector-generator/README.md | 2 +- examples/cvector-generator/pca.hpp | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index be05308fa4ef8..93848463b02b2 100644 --- a/common/common.h +++ b/common/common.h @@ -233,7 +233,7 @@ struct gpt_params { bool compute_ppl = true; // whether to compute perplexity // cvector-generator params - int n_pca_batch = 20; + int n_pca_batch = 100; int n_pca_iterations = 1000; std::string cvector_outfile = "control_vector.gguf"; std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md index 87ee6b76d8fbb..7d814154c9ac8 100644 --- a/examples/cvector-generator/README.md +++ b/examples/cvector-generator/README.md @@ -17,7 +17,7 @@ Related PRs: ./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 # With advanced options -./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --completions 128 --pca-iter 2000 --pca-batch 100 +./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100 # To see help message ./cvector-generator -h diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 20f15d4c23d1b..e2af5989af48b 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -298,6 +298,11 @@ static void power_iteration( ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector)); //print_debug_tensor(output); ggml_gallocr_free(allocr); + + // TODO @ngxson : The output vector is inverted. Don't know why, but here is quick fix + for (int i = 0; i < output->ne[0]; i++) { + ggml_set_f32_1d(output, i, -ggml_get_f32_1d(output, i)); + } } static void run_pca( From 6a11a39a8ef0cf8ba49f4aaad0ad21d2d08cd1bd Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 23 Jun 2024 15:31:24 +0200 Subject: [PATCH 3/5] add mean method --- common/common.cpp | 16 +++++++ common/common.h | 13 +++-- examples/cvector-generator/README.md | 3 ++ .../cvector-generator/cvector-generator.cpp | 48 ++++++++++++------- examples/cvector-generator/mean.hpp | 48 +++++++++++++++++++ 5 files changed, 109 insertions(+), 19 deletions(-) create mode 100644 examples/cvector-generator/mean.hpp diff --git a/common/common.cpp b/common/common.cpp index 12c87ed3c7eb4..ae2e7ab0799f3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1644,6 +1644,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.n_pca_iterations = std::stoi(argv[i]); return true; } + if (arg == "--method") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string arg = argv[i]; + if (arg == "pca") { + params.cvector_dimre_method = DIMRE_METHOD_PCA; + } else if (arg == "mean") { + params.cvector_dimre_method = DIMRE_METHOD_MEAN; + } else { + invalid_param = true; + } + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1972,6 +1987,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); + options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); printf("usage: %s [options]\n", argv[0]); diff --git a/common/common.h b/common/common.h index 93848463b02b2..a6d5fbcfde5e7 100644 --- a/common/common.h +++ b/common/common.h @@ -52,6 +52,12 @@ int32_t cpu_get_num_math(); // CLI argument parsing // +// dimensionality reduction methods, used by cvector-generator +enum dimre_method { + DIMRE_METHOD_PCA, + DIMRE_METHOD_MEAN, +}; + struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed @@ -235,9 +241,10 @@ struct gpt_params { // cvector-generator params int n_pca_batch = 100; int n_pca_iterations = 1000; - std::string cvector_outfile = "control_vector.gguf"; - std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; - std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_outfile = "control_vector.gguf"; + std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; + std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; }; void gpt_params_handle_model_default(gpt_params & params); diff --git a/examples/cvector-generator/README.md b/examples/cvector-generator/README.md index 7d814154c9ac8..be4dd5250f15f 100644 --- a/examples/cvector-generator/README.md +++ b/examples/cvector-generator/README.md @@ -19,6 +19,9 @@ Related PRs: # With advanced options ./cvector-generator -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100 +# Using mean value instead of PCA +./cvector-generator -m ./llama-3.Q4_K_M.gguf --method mean + # To see help message ./cvector-generator -h # Then, have a look at "cvector" section diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index e337142cd5d74..d4e126ac22e6f 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -2,6 +2,7 @@ #include "llama.h" #include "ggml.h" #include "pca.hpp" +#include "mean.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -41,6 +42,7 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { printf("\n CPU only: %s -m ./llama-3.Q4_K_M.gguf\n", argv[0]); printf("\n with GPU: %s -m ./llama-3.Q4_K_M.gguf -ngl 99\n", argv[0]); printf("\n advanced: %s -m ./llama-3.Q4_K_M.gguf -ngl 99 --pca-iter 2000 --pca-batch 100\n", argv[0]); + printf("\n using mean: %s -m ./llama-3.Q4_K_M.gguf --method mean\n", argv[0]); printf("\n"); } @@ -223,23 +225,30 @@ struct train_context { // build the v_diff tensors from v_diff_tmp (v_diff need to be transposed) // TODO @ngxson : maybe add option NOT to transpose v_diff; will be useful for "mean" method - void build_v_diff() { + void build_v_diff(bool transpose) { printf("build_v_diff\n"); for (int il = 0; il < n_layers - 1; il++) { auto & diff_tmp = v_diff_tmp[il]; int n_elem = diff_tmp.size() / sizeof(float); GGML_ASSERT(n_elem % n_embd == 0); int n_rows = n_elem / n_embd; - struct ggml_tensor * diff = ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd); + struct ggml_tensor * diff = transpose + ? ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_rows, n_embd) + : ggml_new_tensor_2d(ctx_ggml, GGML_TYPE_F32, n_embd, n_rows); ggml_set_name(diff, (std::string("diff_") + std::to_string(il)).c_str()); - // copy data & transpose diff->data = malloc(ggml_nbytes(diff)); // TODO: get rid of this malloc if possible - float * arr = (float *) diff_tmp.data(); - for (int ir = 0; ir < n_rows; ++ir) { - for (int ic = 0; ic < n_embd; ++ic) { - float f = arr[ir*n_embd + ic]; - ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + if (transpose) { + // copy data & transpose + float * arr = (float *) diff_tmp.data(); + for (int ir = 0; ir < n_rows; ++ir) { + for (int ic = 0; ic < n_embd; ++ic) { + float f = arr[ir*n_embd + ic]; + ggml_set_f32_nd(diff, ir, ic, 0, 0, f); + } } + } else { + // only copy + memcpy(diff->data, diff_tmp.data(), ggml_nbytes(diff)); } v_diff.push_back(diff); print_debug_tensor(diff); @@ -468,15 +477,22 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; + // prepare ctx_train for PCA - ctx_train.build_v_diff(); - - // run PCA - PCA::pca_params pca_params; - pca_params.n_threads = params.n_threads; - pca_params.n_batch = params.n_pca_batch; - pca_params.n_iterations = params.n_pca_iterations; - PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + ctx_train.build_v_diff(use_pca); + + if (use_pca) { + // run PCA + PCA::pca_params pca_params; + pca_params.n_threads = params.n_threads; + pca_params.n_batch = params.n_pca_batch; + pca_params.n_iterations = params.n_pca_iterations; + PCA::run_pca(pca_params, ctx_train.v_diff, ctx_train.v_final); + } else { + // run mean + mean::run(ctx_train.v_diff, ctx_train.v_final); + } // write output vectors to gguf export_gguf(ctx_train.v_final, params.cvector_outfile, model_hint); diff --git a/examples/cvector-generator/mean.hpp b/examples/cvector-generator/mean.hpp new file mode 100644 index 0000000000000..06814676a6200 --- /dev/null +++ b/examples/cvector-generator/mean.hpp @@ -0,0 +1,48 @@ +#include "common.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include + +namespace mean { + +static void run( + const std::vector & v_input, // shape of v_input[0]: [n_embd, n_samples] + const std::vector & v_output) { + printf("%s: Running mean...\n", __func__); + for (size_t il = 0; il < v_input.size(); ++il) { + // prepare output vector + struct ggml_tensor * ctrl_out = v_output[il]; + ggml_format_name(ctrl_out, "direction.%ld", il+1); + + // calculate mean vector + struct ggml_tensor * t_layer = v_input[il]; + GGML_ASSERT(t_layer->ne[0] == ctrl_out->ne[0]); // == n_embd + for (int ic = 0; ic < t_layer->ne[0]; ic++) { + float f = 0.0; + for (int ir = 0; ir < t_layer->ne[1]; ir++) { + f += ggml_get_f32_nd(t_layer, ic, ir, 0, 0); + } + f /= t_layer->ne[1]; + ggml_set_f32_1d(ctrl_out, ic, f); + } + + // normalize output vector + float norm = 0.0; + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + norm += f*f; + } + norm = sqrt(norm); + for (int i = 0; i < ggml_nelements(ctrl_out); i++) { + float f = ggml_get_f32_1d(ctrl_out, i); + ggml_set_f32_1d(ctrl_out, i, f / norm); + } + + printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size()); + } +} + +} \ No newline at end of file From a41075a8102682f328518e9409751e58a8c595e5 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 23 Jun 2024 15:32:15 +0200 Subject: [PATCH 4/5] code style --- examples/cvector-generator/mean.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cvector-generator/mean.hpp b/examples/cvector-generator/mean.hpp index 06814676a6200..16be5ce3eecf1 100644 --- a/examples/cvector-generator/mean.hpp +++ b/examples/cvector-generator/mean.hpp @@ -45,4 +45,4 @@ static void run( } } -} \ No newline at end of file +} From 0d4ecfd9f4a65e1bfe5348c3bc616080b3ef29b0 Mon Sep 17 00:00:00 2001 From: ngxson Date: Tue, 25 Jun 2024 10:54:17 +0200 Subject: [PATCH 5/5] remove inverted pca hotfix --- examples/cvector-generator/pca.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index e2af5989af48b..6ec3141afbc6b 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -299,10 +299,8 @@ static void power_iteration( //print_debug_tensor(output); ggml_gallocr_free(allocr); - // TODO @ngxson : The output vector is inverted. Don't know why, but here is quick fix - for (int i = 0; i < output->ne[0]; i++) { - ggml_set_f32_1d(output, i, -ggml_get_f32_1d(output, i)); - } + // TODO @ngxson : The output vector is randomly inverted + // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 } static void run_pca(