From f52101e8897ea3acdf7903b398ee4201db1176aa Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Thu, 6 Apr 2023 23:18:59 +0200 Subject: [PATCH 01/15] Add lora support --- convert-lora-to-ggml.py | 101 ++++++++++++++++++++ examples/common.cpp | 7 ++ examples/common.h | 6 +- examples/main/main.cpp | 8 ++ examples/perplexity/perplexity.cpp | 8 ++ ggml.c | 45 +++++++++ ggml.h | 6 ++ llama.cpp | 148 +++++++++++++++++++++++++++++ llama.h | 9 ++ 9 files changed, 335 insertions(+), 3 deletions(-) create mode 100644 convert-lora-to-ggml.py diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py new file mode 100644 index 0000000000000..988627181f19a --- /dev/null +++ b/convert-lora-to-ggml.py @@ -0,0 +1,101 @@ +import os +import re +import struct +import sys +from dataclasses import dataclass +from typing import Any, Sequence + +import numpy as np +import torch + + +# TODO: import this from convert.py once #545 is merged +@dataclass(frozen=True) +class UnquantizedDataType: + name: str + +DT_F16 = UnquantizedDataType('F16') +DT_F32 = UnquantizedDataType('F32') + +@dataclass(frozen=True) +class QuantizedDataType: + groupsize: int + have_addends: bool + have_g_idx: bool + +DataType = UnquantizedDataType + +DATA_TYPE_TO_FTYPE: dict[DataType, int] = { + DT_F32: 0, + DT_F16: 1, +} + +DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = { + DT_F16: np.dtype(np.float16), + DT_F32: np.dtype(np.float32), +} + +NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} + +HF_SUBLAYER_TO_GGML = { + "self_attn.q_proj": "attention.wq.weight", + "self_attn.k_proj": "attention.wk.weight", + "self_attn.v_proj": "attention.wv.weight", + "self_attn.o_proj": "attention.wo.weight", +} + +def translate_tensor_name(t): + match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t) + if match: + nn = match.group(1) + sub_layer = match.group(2) + lora_type = match.group(3) + + sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer) + if sub_layer_renamed is None: + print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}") + exit(1) + + output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}" + return output_string + else: + print(f"Error: unrecognized tensor {t}") + exit(1) + +def write_file_header(fout): + fout.write(b"ggla"[::-1]) # magic (ggml lora) + fout.write(struct.pack("i", 1)) # file version + + +def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> None: + sname = name.encode('utf-8') + fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]])) + fout.write(struct.pack("i" * len(shape), *shape[::-1])) + fout.write(sname) + fout.seek((fout.tell() + 31) & -32) + + +if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]") + sys.exit(1) + +input_path = sys.argv[1] +if len(sys.argv) > 2: + output_path = sys.argv[2] +else: + output_filename = f"ggml_{os.path.basename(input_path)}" + output_path = os.path.join(os.path.dirname(input_path), output_filename) + +model = torch.load(input_path, map_location="cpu") + +with open(output_path, "wb") as fout: + write_file_header(fout) + for k, v in model.items(): + # since ggml doesn't always support other types for the second operand, + # the tensors are always converted and exported as f32 + t = v.float().numpy() + print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") 
+ write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype) + t.tofile(fout) + +print(f"Converted {input_path} to {output_path}") \ No newline at end of file diff --git a/examples/common.cpp b/examples/common.cpp index 0772dbfe142ff..403b2cc15730f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -139,6 +139,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.model = argv[i]; + } else if (arg == "--lora") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "--embedding") { @@ -242,6 +248,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { } fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); + fprintf(stderr, " --lora FNAME apply LoRA adapter\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/examples/common.h b/examples/common.h index 1ea6f74451811..ba825f3061ced 100644 --- a/examples/common.h +++ b/examples/common.h @@ -31,11 +31,11 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; - std::string input_prefix = ""; // string to prefix user inputs with - - + std::string input_prefix = ""; // string to prefix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted + std::string lora_adapter = ""; // lora adapter path + bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3e4b0034ee977..a50fc641cc3aa 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -114,6 +114,14 @@ int main(int argc, char ** argv) { } } + if (!params.lora_adapter.empty()) { + int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + return 1; + } + } + // print system information { fprintf(stderr, "\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 19449e16e4d54..716c5e0e4a417 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -134,6 +134,14 @@ int main(int argc, char ** argv) { } } + if (!params.lora_adapter.empty()) { + int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + return 1; + } + } + // print system information { fprintf(stderr, "\n"); diff --git a/ggml.c b/ggml.c index 69974989c08f8..a486cad674df3 100644 --- a/ggml.c +++ b/ggml.c @@ -5813,6 +5813,47 @@ static void ggml_compute_forward_add_f32( } } +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; 
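[note] For reference, a minimal sketch of a reader for the adapter format that write_file_header/write_tensor_header above produce: a reversed "ggla" magic, an int32 file version, then per tensor an int32 triple (n_dims, name length, ftype), the shape in reverse order, the UTF-8 name, and tensor data aligned to 32 bytes. This matches the version-1 layout written by this patch; a later patch in the series inserts the LoRA r and alpha values right after the version. The function and file names in the sketch are illustrative only, not part of the patch.

import struct
import numpy as np

def read_ggla(path):
    # parse the layout written by convert-lora-to-ggml.py in this patch (version 1)
    tensors = {}
    with open(path, "rb") as fin:
        assert fin.read(4) == b"ggla"[::-1], "bad magic"
        (version,) = struct.unpack("i", fin.read(4))
        assert version == 1, "unsupported file version"
        while True:
            header = fin.read(3 * 4)
            if len(header) < 12:
                break  # end of file
            n_dims, name_len, ftype = struct.unpack("iii", header)
            shape = struct.unpack("i" * n_dims, fin.read(4 * n_dims))[::-1]
            name = fin.read(name_len).decode("utf-8")
            fin.seek((fin.tell() + 31) & -32)  # tensor data is 32-byte aligned
            dtype = np.float32 if ftype == 0 else np.float16
            count = int(np.prod(shape))
            tensors[name] = np.fromfile(fin, dtype=dtype, count=count).reshape(shape)
    return tensors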
+ + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr); + } + } +} + static void ggml_compute_forward_add( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -5823,6 +5864,10 @@ static void ggml_compute_forward_add( { ggml_compute_forward_add_f32(params, src0, src1, dst); } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } break; default: { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 241e96a1975b1..add00258141fd 100644 --- a/ggml.h +++ b/ggml.h @@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add( struct ggml_tensor * a, struct ggml_tensor * b); + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + struct ggml_tensor * ggml_sub( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/llama.cpp b/llama.cpp index a6429a4e79203..ba1f089b865a6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1758,6 +1758,154 @@ int llama_model_quantize( } } +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) { + // TODO: refactor all of this after PR #801 + auto & model = ctx->model; + + auto fin = std::ifstream(path_lora, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora); + return 1; + } + + // verify magic and version + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != 'ggla') { + fprintf(stderr, "%s: bad file magic\n", __func__); + return 1; + } + uint32_t format_version; + fin.read((char *) &format_version, sizeof(format_version)); + + if (format_version != 1) { + fprintf(stderr, "%s: unsupported file version\n", __func__ ); + return 1; + } + } + + // create a temporary ggml context to store the lora tensors + std::vector buf(1024 * 1024 * 100); + struct ggml_init_params params; + params.mem_size = buf.size(); + params.mem_buffer = buf.data(); + params.no_alloc = false; + + ggml_context* lora_ctx = ggml_init(params); + std::unordered_map lora_tensors; + + fprintf(stderr, "%s: ", __func__); + + // read tensors and apply + int n_tensors = 0; + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + // check for lora suffix and get the type of tensor + const std::string lora_suffix = ".lora"; + size_t pos = name.rfind(lora_suffix); + if (pos == std::string::npos) { + 
fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); + return 1; + } + + std::string lora_type = name.substr(pos + lora_suffix.length()); + std::string base_name = name; + base_name.erase(pos); + // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); + + if (model.tensors.find(base_name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); + return 1; + } + + // create ggml tensor + ggml_type wtype; + switch (ftype) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + default: + { + fprintf(stderr, "%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } + } + ggml_tensor* lora_tensor; + if (n_dims == 2) { + lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); + } + else { + fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims); + return 1; + } + + // load tensor data + size_t offset = fin.tellg(); + size_t tensor_data_size = ggml_nbytes(lora_tensor); + offset = (offset + 31) & -32; + fin.seekg(offset); + fin.read((char*)lora_tensor->data, tensor_data_size); + + lora_tensors[name] = lora_tensor; + + // check if we have both A and B tensors and apply + if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && + lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { + + ggml_tensor * tensor = model.tensors[base_name]; + ggml_tensor * loraA = ggml_transpose(lora_ctx, lora_tensors[base_name + ".loraA"]); + ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; + + if (tensor->ne[0] != loraA->ne[1]) { + fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" + " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]); + return 1; + } + + // w = w + BA + ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA); + ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA); + + struct ggml_cgraph gf = ggml_build_forward(r); + gf.n_threads = n_threads; + ggml_graph_compute(lora_ctx, &gf); + + // we won't need these tensors again, reset the context to save memory + ggml_free(lora_ctx); + lora_ctx = ggml_init(params); + lora_tensors.clear(); + + n_tensors++; + if (n_tensors % 8 == 0) + fprintf(stderr, "."); + } + } + fprintf(stderr, " done\n"); + + return 0; +} + // Returns the KV cache that will contain the context for the // ongoing prediction with the model. const uint8_t * llama_get_kv_cache(struct llama_context * ctx) { diff --git a/llama.h b/llama.h index 1922175937685..535f1b18eb9cf 100644 --- a/llama.h +++ b/llama.h @@ -96,6 +96,15 @@ extern "C" { const char * fname_out, enum llama_ftype ftype); + // Apply a LoRA adapter to a loaded model + // The model needs to be reloaded before applying a new adapter, otherwise + // the adapter will the applied on top of the previous one + // Returns 0 on success + LLAMA_API int llama_apply_lora_from_file( + struct llama_context * ctx, + const char * path_lora, + int n_threads); + // Returns the KV cache that will contain the context for the // ongoing prediction with the model. 
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx); From ac3fbe492accb741a7796061fa4bb277686ff8b5 Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 8 Apr 2023 03:37:12 +0200 Subject: [PATCH 02/15] Export lora A matrix pre-transposed --- convert-lora-to-ggml.py | 2 ++ llama.cpp | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py index 988627181f19a..ef1aa5305d56d 100644 --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -94,6 +94,8 @@ def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> # since ggml doesn't always support other types for the second operand, # the tensors are always converted and exported as f32 t = v.float().numpy() + if "lora_A" in k: + t = t.T print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype) t.tofile(fout) diff --git a/llama.cpp b/llama.cpp index ba1f089b865a6..bb7d3e2d97b10 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1760,8 +1760,12 @@ int llama_model_quantize( int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) { // TODO: refactor all of this after PR #801 + fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); + auto & model = ctx->model; + const int64_t t_start_lora_us = ggml_time_us(); + auto fin = std::ifstream(path_lora, std::ios::binary); if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora); @@ -1874,7 +1878,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { ggml_tensor * tensor = model.tensors[base_name]; - ggml_tensor * loraA = ggml_transpose(lora_ctx, lora_tensors[base_name + ".loraA"]); + ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; if (tensor->ne[0] != loraA->ne[1]) { @@ -1901,7 +1905,11 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor fprintf(stderr, "."); } } - fprintf(stderr, " done\n"); + + ggml_free(lora_ctx); + + const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; + fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0); return 0; } From 7136adac8aac2d9ad55e8d82694ce21d145de4b8 Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 8 Apr 2023 13:12:44 +0200 Subject: [PATCH 03/15] Add support for quantized models --- ggml.c | 172 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- llama.cpp | 13 ++++- 2 files changed, 179 insertions(+), 6 deletions(-) diff --git a/ggml.c b/ggml.c index a486cad674df3..b9bfa2f5e429b 100644 --- a/ggml.c +++ b/ggml.c @@ -5830,13 +5830,13 @@ static void ggml_compute_forward_add_f16_f32( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - const size_t nb00 = src0->nb[0]; + //const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb10 = src1->nb[0]; const size_t nb11 = src1->nb[1]; - const size_t nb0 = dst->nb[0]; + //const size_t nb0 = dst->nb[0]; const size_t nb1 = dst->nb[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -5848,12 +5848,163 @@ static void ggml_compute_forward_add_f16_f32( ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); for (int i = 0; i < nc; i++) { float * src1_ptr = (float *) ((char *) src1->data + 
j*nb11 + i*nb10); - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr); } } } +static void ggml_compute_forward_add_f16_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + //const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb10 = src1->nb[0]; + const size_t nb11 = src1->nb[1]; + + //const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10); + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr)); + } + } +} + +static void ggml_compute_forward_add_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + //const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + const enum ggml_type type = src0->type; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; + quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const 
int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + float tmp[ne00]; + dequantize_row_q(src0_row, tmp, ne00); + // add src1 + ggml_vec_acc_f32(ne00, tmp, src1_row); + // quantize row to dst + quantize_row_q(tmp, dst_row, ne00); + } +} + static void ggml_compute_forward_add( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -5866,7 +6017,20 @@ static void ggml_compute_forward_add( } break; case GGML_TYPE_F16: { - ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + { + ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; default: { diff --git a/llama.cpp b/llama.cpp index bb7d3e2d97b10..4fcd2ecfbec48 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1887,14 +1887,23 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor return 1; } - // w = w + BA + // w = w + BA*s ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA); - ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA); + + //if (true) { + // ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f); + // BA = ggml_scale(lora_ctx, BA, scale_tensor); + //} + ggml_tensor * r = ggml_add(lora_ctx, tensor, BA); + //r = ggml_cpy(lora_ctx, r, tensor); struct ggml_cgraph gf = ggml_build_forward(r); gf.n_threads = n_threads; ggml_graph_compute(lora_ctx, &gf); + // hack until ggml_cpy supports quantized tensors + memcpy(tensor->data, r->data, ggml_nbytes(tensor)); + // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); lora_ctx = ggml_init(params); From dc6570713035c410b47995e2d63f90c6eba8745f Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 8 Apr 2023 13:41:57 +0200 Subject: [PATCH 04/15] Use the work buffer instead to fix MSVC build --- ggml.c | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml.c b/ggml.c index b9bfa2f5e429b..3dffae8b34d33 100644 --- a/ggml.c +++ b/ggml.c @@ -5974,6 +5974,8 @@ static void ggml_compute_forward_add_q_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * wdata = (float*) params->wdata + ne00 * ith; + for (int ir = ir0; ir < ir1; ++ir) { // src0 indices const int i03 = ir/(ne02*ne01); @@ -5996,12 +5998,11 @@ static void ggml_compute_forward_add_q_f32( assert(ne00 % 32 == 0); // unquantize row from src0 to temp buffer - float tmp[ne00]; - dequantize_row_q(src0_row, tmp, ne00); + dequantize_row_q(src0_row, wdata, ne00); // add src1 - 
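[note] To make the quantized add path above concrete: for Q4_0/Q4_1 destinations, each row is dequantized into a scratch buffer, the f32 operand is accumulated, and the result is requantized into the destination row. A rough numpy sketch of that round trip, using a deliberately simplified symmetric 4-bit quantizer rather than the real Q4_0/Q4_1 block layout:

import numpy as np

BLOCK = 32  # elements per quantization block, matching the assert(ne00 % 32 == 0) above

def quantize_row(x):
    # simplified: per-block scale = max|x| / 7, signed 4-bit values in [-8, 7]
    blocks = x.reshape(-1, BLOCK)
    d = np.abs(blocks).max(axis=1, keepdims=True) / 7.0
    d[d == 0.0] = 1.0
    q = np.clip(np.round(blocks / d), -8, 7).astype(np.int8)
    return d, q

def dequantize_row(d, q):
    return (q * d).astype(np.float32).reshape(-1)

src0_d, src0_q = quantize_row(np.random.randn(2 * BLOCK).astype(np.float32))
src1 = np.random.randn(2 * BLOCK).astype(np.float32)

tmp = dequantize_row(src0_d, src0_q)   # dequantize_row_q(src0_row, tmp/wdata, ne00)
tmp += src1                            # ggml_vec_acc_f32(ne00, tmp, src1_row)
dst_d, dst_q = quantize_row(tmp)       # quantize_row_q(tmp, dst_row, ne00)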
ggml_vec_acc_f32(ne00, tmp, src1_row); + ggml_vec_acc_f32(ne00, wdata, src1_row); // quantize row to dst - quantize_row_q(tmp, dst_row, ne00); + quantize_row_q(wdata, dst_row, ne00); } } @@ -10198,6 +10199,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_ADD: { node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; + } + + work_size = MAX(work_size, cur); } break; case GGML_OP_SUB: case GGML_OP_MUL: From 87c518bb3df02af49d2a2e4bbb24d36d464c843b Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 8 Apr 2023 19:39:24 +0200 Subject: [PATCH 05/15] Update exporter and support scaling --- convert-lora-to-ggml.py | 100 ++++++++++++++++++++++++++++++---------- llama.cpp | 23 ++++++--- 2 files changed, 92 insertions(+), 31 deletions(-) diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py index ef1aa5305d56d..535381ecbda72 100644 --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -1,3 +1,4 @@ +import json import os import re import struct @@ -14,8 +15,10 @@ class UnquantizedDataType: name: str -DT_F16 = UnquantizedDataType('F16') -DT_F32 = UnquantizedDataType('F32') + +DT_F16 = UnquantizedDataType("F16") +DT_F32 = UnquantizedDataType("F32") + @dataclass(frozen=True) class QuantizedDataType: @@ -23,6 +26,7 @@ class QuantizedDataType: have_addends: bool have_g_idx: bool + DataType = UnquantizedDataType DATA_TYPE_TO_FTYPE: dict[DataType, int] = { @@ -35,17 +39,28 @@ class QuantizedDataType: DT_F32: np.dtype(np.float32), } -NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} +NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = { + dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items() +} HF_SUBLAYER_TO_GGML = { "self_attn.q_proj": "attention.wq.weight", "self_attn.k_proj": "attention.wk.weight", "self_attn.v_proj": "attention.wv.weight", "self_attn.o_proj": "attention.wo.weight", + # "embed_tokens.weight": "tok_embeddings.weight", + # "norm.weight": "norm.weight", + # "lm_head.weight": "output.weight", + # "mlp.gate_proj": "feed_forward.w1.weight", + # "mlp.down_proj": "feed_forward.w2.weight", + # "mlp.up_proj": "feed_forward.w3.weight", + # "input_layernorm": "attention_norm.weight", + # "post_attention_layernorm": "ffn_norm.weight", } + def translate_tensor_name(t): - match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t) + match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t) if match: nn = match.group(1) sub_layer = match.group(2) @@ -54,50 +69,85 @@ def translate_tensor_name(t): sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer) if sub_layer_renamed is None: print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}") - exit(1) + sys.exit(1) output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}" return output_string else: print(f"Error: unrecognized tensor {t}") - exit(1) + sys.exit(1) -def write_file_header(fout): - fout.write(b"ggla"[::-1]) # magic (ggml lora) - fout.write(struct.pack("i", 1)) # file version + +def write_file_header(fout, params): + fout.write(b"ggla"[::-1]) # magic (ggml lora) + fout.write(struct.pack("i", 1)) # file version + fout.write(struct.pack("ii", params["r"], params["lora_alpha"])) def write_tensor_header(self, name: str, shape: Sequence[int], 
data_type: 1) -> None: - sname = name.encode('utf-8') - fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]])) + sname = name.encode("utf-8") + fout.write( + struct.pack( + "iii", + len(shape), + len(sname), + DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]], + ) + ) fout.write(struct.pack("i" * len(shape), *shape[::-1])) fout.write(sname) fout.seek((fout.tell() + 31) & -32) - -if len(sys.argv) < 2: - print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]") + +if len(sys.argv) != 2: + print(f"Usage: python {sys.argv[0]} ") + print( + "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'" + ) sys.exit(1) -input_path = sys.argv[1] -if len(sys.argv) > 2: - output_path = sys.argv[2] -else: - output_filename = f"ggml_{os.path.basename(input_path)}" - output_path = os.path.join(os.path.dirname(input_path), output_filename) +input_json = os.path.join(sys.argv[1], "adapter_config.json") +input_model = os.path.join(sys.argv[1], "adapter_model.bin") +output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin") -model = torch.load(input_path, map_location="cpu") +model = torch.load(input_model, map_location="cpu") + +with open(input_json, "r") as f: + params = json.load(f) + +if params["peft_type"] != "LORA": + print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA") + sys.exit(1) + +if params["fan_in_fan_out"] == True: + print("Error: param fan_in_fan_out is not supported") + sys.exit(1) + +if params["bias"] is not None and params["bias"] != "none": + print("Error: param bias is not supported") + sys.exit(1) + +# TODO: these seem to be layers that have been trained but without lora. +# doesn't seem widely used but eventually should be supported +if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0: + print("Error: param modules_to_save is not supported") + sys.exit(1) with open(output_path, "wb") as fout: - write_file_header(fout) + fout.truncate() + + write_file_header(fout, params) for k, v in model.items(): # since ggml doesn't always support other types for the second operand, # the tensors are always converted and exported as f32 - t = v.float().numpy() + v = v.float() + t = v.numpy() if "lora_A" in k: t = t.T - print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") + print( + f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB" + ) write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype) t.tofile(fout) -print(f"Converted {input_path} to {output_path}") \ No newline at end of file +print(f"Converted {input_json} and {input_model} to {output_path}") diff --git a/llama.cpp b/llama.cpp index 4fcd2ecfbec48..40e0beaf51236 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1789,6 +1789,15 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor } } + int32_t lora_r; + int32_t lora_alpha; + fin.read((char *) &lora_r, sizeof(lora_r)); + fin.read((char *) &lora_alpha, sizeof(lora_alpha)); + float scaling = (float)lora_alpha / (float)lora_r; + + fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + + // create a temporary ggml context to store the lora tensors std::vector buf(1024 * 1024 * 100); struct ggml_init_params params; @@ -1890,11 +1899,13 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor // w = w + BA*s ggml_tensor * BA = 
ggml_mul_mat(lora_ctx, loraB, loraA); - //if (true) { - // ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f); - // BA = ggml_scale(lora_ctx, BA, scale_tensor); - //} - ggml_tensor * r = ggml_add(lora_ctx, tensor, BA); + if (scaling != 1.0f) { + ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); + BA = ggml_scale(lora_ctx, BA, scale_tensor); + } + + ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA); + //ggml_tensor * r = ggml_add(lora_ctx, tensor, BA); //r = ggml_cpy(lora_ctx, r, tensor); struct ggml_cgraph gf = ggml_build_forward(r); @@ -1902,7 +1913,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor ggml_graph_compute(lora_ctx, &gf); // hack until ggml_cpy supports quantized tensors - memcpy(tensor->data, r->data, ggml_nbytes(tensor)); + // memcpy(tensor->data, r->data, ggml_nbytes(tensor)); // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); From c920f00136292f15632c4ee88546da819b57e0a6 Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Mon, 10 Apr 2023 21:52:10 +0200 Subject: [PATCH 06/15] Add compatibility with #801 --- examples/common.cpp | 3 ++- llama.cpp | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 403b2cc15730f..2656b3c903a63 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -145,6 +145,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.lora_adapter = argv[i]; + params.use_mmap = false; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "--embedding") { @@ -248,7 +249,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { } fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); - fprintf(stderr, " --lora FNAME apply LoRA adapter\n"); + fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/llama.cpp b/llama.cpp index 40e0beaf51236..db534ddb7d194 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1808,6 +1808,12 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor ggml_context* lora_ctx = ggml_init(params); std::unordered_map lora_tensors; + // create a name -> tensor map of the model to accelerate lookups + std::unordered_map model_tensors; + for (auto & kv: model.tensors_by_name) { + model_tensors.insert(kv); + } + fprintf(stderr, "%s: ", __func__); // read tensors and apply @@ -1847,7 +1853,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor base_name.erase(pos); // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); - if (model.tensors.find(base_name.data()) == model.tensors.end()) { + if (model_tensors.find(base_name.data()) == model_tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); return 1; } @@ -1886,7 +1892,7 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { - ggml_tensor * tensor = model.tensors[base_name]; + ggml_tensor * tensor = model_tensors[base_name]; 
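[note] Putting the apply step together, the math is simply W' = W + (lora_alpha / r) * B @ A. Below is a hedged numpy sketch using the usual PEFT shapes (W: out x in, lora_A: r x in, lora_B: out x r); the dimensions and seed are arbitrary. The converter above writes lora_A pre-transposed so the loader no longer needs ggml_transpose at apply time, which the final assertion checks is information-preserving.

import numpy as np

rng = np.random.default_rng(0)
out_dim, in_dim, r, alpha = 8, 16, 4, 16

W = rng.standard_normal((out_dim, in_dim), dtype=np.float32)
A = rng.standard_normal((r, in_dim), dtype=np.float32)   # lora_A.weight
B = rng.standard_normal((out_dim, r), dtype=np.float32)  # lora_B.weight

scaling = alpha / r                        # r and lora_alpha come from the ggla header
W_merged = W + scaling * (B @ A)           # w = w + BA*s

A_stored = A.T                             # what the converter writes (pre-transposed A)
assert np.allclose(B @ A_stored.T, B @ A)  # the stored transpose yields the same delta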
ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; From c45868ba9f5e358b42e8957a7c025efd6139ad7a Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Tue, 11 Apr 2023 23:15:29 +0200 Subject: [PATCH 07/15] Support more layer types, fix memory and generation issues --- convert-lora-to-ggml.py | 47 ++++++++++++++++++++++------------------- ggml.c | 5 ----- llama.cpp | 20 ++++++++++-------- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py index 535381ecbda72..c6d48b97643ec 100644 --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -44,18 +44,18 @@ class QuantizedDataType: } HF_SUBLAYER_TO_GGML = { - "self_attn.q_proj": "attention.wq.weight", - "self_attn.k_proj": "attention.wk.weight", - "self_attn.v_proj": "attention.wv.weight", - "self_attn.o_proj": "attention.wo.weight", - # "embed_tokens.weight": "tok_embeddings.weight", - # "norm.weight": "norm.weight", - # "lm_head.weight": "output.weight", - # "mlp.gate_proj": "feed_forward.w1.weight", - # "mlp.down_proj": "feed_forward.w2.weight", - # "mlp.up_proj": "feed_forward.w3.weight", - # "input_layernorm": "attention_norm.weight", - # "post_attention_layernorm": "ffn_norm.weight", + "self_attn.q_proj": "attention.wq", + "self_attn.k_proj": "attention.wk", + "self_attn.v_proj": "attention.wv", + "self_attn.o_proj": "attention.wo", + "mlp.gate_proj": "feed_forward.w1", + "mlp.down_proj": "feed_forward.w2", + "mlp.up_proj": "feed_forward.w3", + "input_layernorm": "attention_norm", + "post_attention_layernorm": "ffn_norm", + # "norm": "norm", + # "embed_tokens": "tok_embeddings", + # "lm_head": "output", } @@ -71,7 +71,9 @@ def translate_tensor_name(t): print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}") sys.exit(1) - output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}" + output_string = ( + f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}" + ) return output_string else: print(f"Error: unrecognized tensor {t}") @@ -138,16 +140,17 @@ def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> write_file_header(fout, params) for k, v in model.items(): - # since ggml doesn't always support other types for the second operand, - # the tensors are always converted and exported as f32 - v = v.float() + if k.endswith("lora_A.weight"): + if v.dtype != torch.float16 and v.dtype != torch.float32: + v = v.float() + v = v.T + else: + v = v.float() + t = v.numpy() - if "lora_A" in k: - t = t.T - print( - f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB" - ) - write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype) + tname = translate_tensor_name(k) + print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") + write_tensor_header(fout, tname, t.shape, t.dtype) t.tofile(fout) print(f"Converted {input_json} and {input_model} to {output_path}") diff --git a/ggml.c b/ggml.c index 3dffae8b34d33..8606e9344c172 100644 --- a/ggml.c +++ b/ggml.c @@ -5955,11 +5955,6 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1); GGML_ASSERT(dst->type == src0->type); GGML_ASSERT(src1->type == GGML_TYPE_F32); diff --git a/llama.cpp b/llama.cpp 
index db534ddb7d194..0627c9b9c9d36 100644 --- a/llama.cpp +++ b/llama.cpp @@ -617,6 +617,7 @@ struct llama_model_loader { throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str()); } + return get_tensor_for(lt); } @@ -1799,7 +1800,8 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor // create a temporary ggml context to store the lora tensors - std::vector buf(1024 * 1024 * 100); + // todo: calculate size from biggest possible tensor + std::vector buf(1024ull * 1024ull * 1024ull); struct ggml_init_params params; params.mem_size = buf.size(); params.mem_buffer = buf.data(); @@ -1830,11 +1832,9 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor break; } - int32_t nelements = 1; int32_t ne[2] = { 1, 1 }; for (int i = 0; i < n_dims; ++i) { fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); - nelements *= ne[i]; } std::string name(length, 0); @@ -1903,24 +1903,26 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor } // w = w + BA*s - ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA); + ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); if (scaling != 1.0f) { ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); BA = ggml_scale(lora_ctx, BA, scale_tensor); } + //printf("%s: (B)(%d %d %d %d) x (A)(%d %d %d %d) => (BA)(%d %d %d %d) + (T)(%d %d %d %d)\n", + // base_name.c_str(), + // (int)loraB->ne[0], (int)loraB->ne[1], (int)loraB->ne[2], (int)loraB->ne[3], + // (int)loraA->ne[0], (int)loraA->ne[1], (int)loraA->ne[2], (int)loraA->ne[3], + // (int)BA->ne[0], (int)BA->ne[1], (int)BA->ne[2], (int)BA->ne[3], + // (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], (int)tensor->ne[3] + //); ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA); - //ggml_tensor * r = ggml_add(lora_ctx, tensor, BA); - //r = ggml_cpy(lora_ctx, r, tensor); struct ggml_cgraph gf = ggml_build_forward(r); gf.n_threads = n_threads; ggml_graph_compute(lora_ctx, &gf); - // hack until ggml_cpy supports quantized tensors - // memcpy(tensor->data, r->data, ggml_nbytes(tensor)); - // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); lora_ctx = ggml_init(params); From 57627f0e5faa5bb6b8fd456ccbe1aaf83cff695f Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Thu, 13 Apr 2023 18:06:33 +0200 Subject: [PATCH 08/15] Rebase to master --- ggml.c | 58 +++++++++++++++++++++++++++++-------------------------- llama.cpp | 4 ++-- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/ggml.c b/ggml.c index 8606e9344c172..4f1ad459c0a18 100644 --- a/ggml.c +++ b/ggml.c @@ -1420,6 +1420,34 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #endif } +static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { + [GGML_TYPE_Q4_0] = { + .dequantize_row_q = dequantize_row_q4_0, + .quantize_row_q = quantize_row_q4_0, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q4_0_q8_0, + }, + [GGML_TYPE_Q4_1] = { + .dequantize_row_q = dequantize_row_q4_1, + .quantize_row_q = 
quantize_row_q4_1, + .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, + .quantize_row_q_dot = quantize_row_q4_1, + .vec_dot_q = ggml_vec_dot_q4_1, + }, + // TODO: GGML_TYPE_Q8_0 +}; + +// For internal test use +quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { + GGML_ASSERT(i < GGML_TYPE_COUNT); + return quantize_fns[i]; +} + + // // simd mappings // @@ -5910,12 +5938,12 @@ static void ggml_compute_forward_add_q_f32( const int64_t ne03 = src0->ne[3]; //const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; + //const int64_t ne11 = src1->ne[1]; const int64_t ne12 = src1->ne[2]; const int64_t ne13 = src1->ne[3]; - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; + //const int64_t ne0 = dst->ne[0]; + //const int64_t ne1 = dst->ne[1]; const int64_t ne2 = dst->ne[2]; const int64_t ne3 = dst->ne[3]; @@ -7307,30 +7335,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( //} } -static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { - [GGML_TYPE_Q4_0] = { - .dequantize_row_q = dequantize_row_q4_0, - .quantize_row_q = quantize_row_q4_0, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, - .quantize_row_q_dot = quantize_row_q8_0, - .vec_dot_q = ggml_vec_dot_q4_0_q8_0, - }, - [GGML_TYPE_Q4_1] = { - .dequantize_row_q = dequantize_row_q4_1, - .quantize_row_q = quantize_row_q4_1, - .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, - .quantize_row_q_dot = quantize_row_q4_1, - .vec_dot_q = ggml_vec_dot_q4_1, - }, - // TODO: GGML_TYPE_Q8_0 -}; - -// For internal test use -quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { - GGML_ASSERT(i < GGML_TYPE_COUNT); - return quantize_fns[i]; -} - static void ggml_compute_forward_mul_mat_q_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, diff --git a/llama.cpp b/llama.cpp index 0627c9b9c9d36..33209615e4fdb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1896,8 +1896,8 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; - if (tensor->ne[0] != loraA->ne[1]) { - fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" + if (tensor->ne[0] != loraA->ne[1] || tensor->ne[1] != loraB->ne[1]) { + fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]); return 1; } From c150e1b0c3090f9f98452e2abf845d12b5bb8aed Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:45:00 +0200 Subject: [PATCH 09/15] Add support for using a different base model --- examples/common.cpp | 7 +++ examples/common.h | 1 + examples/main/main.cpp | 5 +- examples/perplexity/perplexity.cpp | 5 +- ggml.c | 36 ++++++++++++ llama.cpp | 92 ++++++++++++++++++++++++------ llama.h | 7 ++- llama_util.h | 28 +++++---- 8 files changed, 148 insertions(+), 33 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 2656b3c903a63..a0b6f10ad8c8b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -146,6 +146,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } params.lora_adapter = argv[i]; params.use_mmap = false; + } else if (arg == "--lora-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_base = argv[i]; } else if (arg == "-i" || arg 
== "--interactive") { params.interactive = true; } else if (arg == "--embedding") { @@ -250,6 +256,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); + fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/examples/common.h b/examples/common.h index ba825f3061ced..cbbc2dfab16de 100644 --- a/examples/common.h +++ b/examples/common.h @@ -35,6 +35,7 @@ struct gpt_params { std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path + std::string lora_base = ""; // base model path for the lora adapter bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a50fc641cc3aa..b7b3c419655f6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -115,7 +115,10 @@ int main(int argc, char ** argv) { } if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads); + int err = llama_apply_lora_from_file(ctx, + params.lora_adapter.c_str(), + params.lora_base.empty() ? NULL : params.lora_base.c_str(), + params.n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); return 1; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 716c5e0e4a417..80792ea0d95d0 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -135,7 +135,10 @@ int main(int argc, char ** argv) { } if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads); + int err = llama_apply_lora_from_file(ctx, + params.lora_adapter.c_str(), + params.lora_base.empty() ? 
NULL : params.lora_base.c_str(), + params.n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); return 1; diff --git a/ggml.c b/ggml.c index 4f1ad459c0a18..0a7c811bc305c 100644 --- a/ggml.c +++ b/ggml.c @@ -5461,6 +5461,27 @@ static void ggml_compute_forward_dup_f16( } } } + } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + size_t id = 0; + uint8_t * dst_ptr = (uint8_t *) dst->data; + size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); + // todo: use work buffer + float * src0_f32 = (float *) alloca(ne00 * sizeof(float)); + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + // convert to f32 and quantize + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); + } + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += dst_row_size; + } + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -5653,6 +5674,21 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) { + quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q; + size_t id = 0; + uint8_t * dst_ptr = (uint8_t *) dst->data; + size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]); + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += dst_row_size; + } + } + } } else { GGML_ASSERT(false); // TODO: implement } diff --git a/llama.cpp b/llama.cpp index 33209615e4fdb..c1bc073109f51 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,6 +1,8 @@ // Defines fileno on msys: #ifndef _GNU_SOURCE #define _GNU_SOURCE +#include +#include #endif #include "llama_util.h" @@ -1759,8 +1761,7 @@ int llama_model_quantize( } } -int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) { - // TODO: refactor all of this after PR #801 +int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); auto & model = ctx->model; @@ -1801,13 +1802,13 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor // create a temporary ggml context to store the lora tensors // todo: calculate size from biggest possible tensor - std::vector buf(1024ull * 1024ull * 1024ull); + std::vector lora_buf(1024ull * 1024ull * 1024ull); struct ggml_init_params params; - params.mem_size = buf.size(); - params.mem_buffer = buf.data(); + params.mem_size = lora_buf.size(); + params.mem_buffer = lora_buf.data(); params.no_alloc = false; - ggml_context* lora_ctx = ggml_init(params); + ggml_context * lora_ctx = ggml_init(params); std::unordered_map lora_tensors; // create a name -> tensor map of the model to accelerate lookups @@ -1816,6 +1817,32 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor model_tensors.insert(kv); } + + // load base model + std::unique_ptr model_loader; + ggml_context * base_ctx = NULL; + llama_buffer base_buf; 
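[note] The motivation for --lora-base, sketched numerically: when the loaded model is quantized, adding the delta to a higher-precision (f16/f32) copy of the affected tensor and quantizing once typically loses less information than adding it on top of already-quantized weights. The toy 4-bit round trip below is purely illustrative and not the real Q4_0/Q4_1 format; the variable names are assumptions for the sketch.

import numpy as np

def q4_roundtrip(x):
    # simplified symmetric 4-bit quantize/dequantize, for illustration only
    d = float(np.abs(x).max()) / 7.0
    d = d if d > 0.0 else 1.0
    return np.clip(np.round(x / d), -8, 7) * d

rng = np.random.default_rng(1)
w_base = rng.standard_normal(32).astype(np.float32)          # f16/f32 row from --lora-base
delta  = 0.05 * rng.standard_normal(32).astype(np.float32)   # scaling * B @ A for this row
target = w_base + delta

merged_from_base = q4_roundtrip(w_base + delta)                 # merge in high precision, quantize once
merged_in_place  = q4_roundtrip(q4_roundtrip(w_base) + delta)   # merge into already-quantized weights

print(np.abs(merged_from_base - target).mean())  # usually the smaller error
print(np.abs(merged_in_place - target).mean())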
+ if (path_base_model) { + fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model); + model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false)); + + size_t ctx_size, mmapped_size; + model_loader->calc_sizes(&ctx_size, &mmapped_size); + base_buf.resize(ctx_size); + + ggml_init_params base_params; + base_params.mem_size = base_buf.size; + base_params.mem_buffer = base_buf.addr; + base_params.no_alloc = model_loader->use_mmap; + + base_ctx = ggml_init(base_params); + + model_loader->ggml_ctx = base_ctx; + + // maybe this should in llama_model_loader + model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, false)); + } + fprintf(stderr, "%s: ", __func__); // read tensors and apply @@ -1892,13 +1919,31 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { - ggml_tensor * tensor = model_tensors[base_name]; + ggml_tensor * dest_t = model_tensors[base_name]; + ggml_tensor * base_t; + if (model_loader) { + // load from base model + if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) { + fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); + return 1; + } + size_t idx = model_loader->tensors_map.name_to_idx[base_name]; + llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; + base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }); + lt.data = (uint8_t *) lt.ggml_tensor->data; + model_loader->load_data_for(lt); + lt.ggml_tensor->data = lt.data; + } + else { + base_t = dest_t; + } + ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; - if (tensor->ne[0] != loraA->ne[1] || tensor->ne[1] != loraB->ne[1]) { + if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]); + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); return 1; } @@ -1910,14 +1955,14 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor BA = ggml_scale(lora_ctx, BA, scale_tensor); } - //printf("%s: (B)(%d %d %d %d) x (A)(%d %d %d %d) => (BA)(%d %d %d %d) + (T)(%d %d %d %d)\n", - // base_name.c_str(), - // (int)loraB->ne[0], (int)loraB->ne[1], (int)loraB->ne[2], (int)loraB->ne[3], - // (int)loraA->ne[0], (int)loraA->ne[1], (int)loraA->ne[2], (int)loraA->ne[3], - // (int)BA->ne[0], (int)BA->ne[1], (int)BA->ne[2], (int)BA->ne[3], - // (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], (int)tensor->ne[3] - //); - ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA); + ggml_tensor * r; + if (base_t == dest_t) { + r = ggml_add_inplace(lora_ctx, dest_t, BA); + } + else { + r = ggml_add(lora_ctx, base_t, BA); + r = ggml_cpy(lora_ctx, r, dest_t); + } struct ggml_cgraph gf = ggml_build_forward(r); gf.n_threads = n_threads; @@ -1934,7 +1979,11 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor } } + // TODO: this should be in a destructor, it will leak on failure ggml_free(lora_ctx); + if (base_ctx) { + ggml_free(base_ctx); + } const int64_t t_lora_us = ggml_time_us() - 
t_start_lora_us; fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0); @@ -1942,6 +1991,15 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor return 0; } +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads); + } catch (const std::string & err) { + fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str()); + return 1; + } +} + // Returns the KV cache that will contain the context for the // ongoing prediction with the model. const uint8_t * llama_get_kv_cache(struct llama_context * ctx) { diff --git a/llama.h b/llama.h index 535f1b18eb9cf..c35193a8a80de 100644 --- a/llama.h +++ b/llama.h @@ -97,12 +97,15 @@ extern "C" { enum llama_ftype ftype); // Apply a LoRA adapter to a loaded model - // The model needs to be reloaded before applying a new adapter, otherwise - // the adapter will the applied on top of the previous one + // path_base_model is the path to a higher quality model to use as a base for + // the layers modified by the adapter. Can be NULL to use the current loaded model. + // The model needs to be reloaded before applying a new adapter, otherwise the adapter + // will be applied on top of the previous one // Returns 0 on success LLAMA_API int llama_apply_lora_from_file( struct llama_context * ctx, const char * path_lora, + const char * path_base_model, int n_threads); // Returns the KV cache that will contain the context for the diff --git a/llama_util.h b/llama_util.h index d2110ebb4f642..da900a5e43528 100755 --- a/llama_util.h +++ b/llama_util.h @@ -168,7 +168,7 @@ struct llama_mmap { #ifdef _POSIX_MAPPED_FILES static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file) { + llama_mmap(struct llama_file * file, bool prefetch = true) { size = file->size; int fd = fileno(file->fp); int flags = MAP_SHARED; @@ -181,10 +181,12 @@ struct llama_mmap { throw format("mmap failed: %s", strerror(errno)); } - // Advise the kernel to preload the mapped memory - if (madvise(addr, file->size, MADV_WILLNEED)) { - fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", - strerror(errno)); + if (prefetch) { + // Advise the kernel to preload the mapped memory + if (madvise(addr, file->size, MADV_WILLNEED)) { + fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } } } @@ -216,13 +218,15 @@ struct llama_mmap { } #if _WIN32_WINNT >= _WIN32_WINNT_WIN8 - // Advise the kernel to preload the mapped memory - WIN32_MEMORY_RANGE_ENTRY range; - range.VirtualAddress = addr; - range.NumberOfBytes = (SIZE_T)size; - if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { - fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (prefetch) { + // Advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } #else #pragma message("warning: You are building for pre-Windows 8; prefetch not supported") From fc899160020d083814f44e1325855cf7122a5406 Mon Sep 17 00:00:00 2001 From: Slaren <2141330+slaren@users.noreply.github.com> Date: Sat, 15 Apr 2023 19:54:56 +0200 Subject: [PATCH 
From fc899160020d083814f44e1325855cf7122a5406 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sat, 15 Apr 2023 19:54:56 +0200
Subject: [PATCH 10/15] Fix windows build

---
 llama_util.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_util.h b/llama_util.h
index da900a5e43528..6531d0dadfc85 100755
--- a/llama_util.h
+++ b/llama_util.h
@@ -196,7 +196,7 @@ struct llama_mmap {
 #elif defined(_WIN32)
     static constexpr bool SUPPORTED = true;
 
-    llama_mmap(struct llama_file * file) {
+    llama_mmap(struct llama_file * file, bool prefetch = true) {
         size = file->size;
 
         HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));

From 14858ba2bf9c7363861d76658a707c8cf5e57cae Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sat, 15 Apr 2023 20:11:07 +0200
Subject: [PATCH 11/15] Show warning when using a quantized base model

---
 llama.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index c1bc073109f51..87c89b016921b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1843,9 +1843,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, false));
     }
 
-    fprintf(stderr, "%s: ", __func__);
-
     // read tensors and apply
+    bool warned = false;
     int n_tensors = 0;
     while (true) {
         int32_t n_dims;
@@ -1938,6 +1937,14 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
             base_t = dest_t;
         }
 
+        if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1) {
+            if (!warned) {
+                fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+                                "use a f16 or f32 base model with --lora-base\n", __func__);
+                warned = true;
+            }
+        }
+
         ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
         ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
 
@@ -1974,7 +1981,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         lora_tensors.clear();
 
         n_tensors++;
-        if (n_tensors % 8 == 0)
+        if (n_tensors % 4 == 0)
             fprintf(stderr, ".");
     }
 }

From 3df343b4f0ed9699dec4593f15d4be54f095962d Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sat, 15 Apr 2023 20:29:05 +0200
Subject: [PATCH 12/15] ggml_cpy: use the work buffer instead of alloca when quantizing

---
 ggml.c | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/ggml.c b/ggml.c
index 0a7c811bc305c..e88b9fbe4ab5b 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5466,8 +5466,7 @@ static void ggml_compute_forward_dup_f16(
             size_t id = 0;
             uint8_t * dst_ptr = (uint8_t *) dst->data;
             size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
-            // todo: use work buffer
-            float * src0_f32 = (float *) alloca(ne00 * sizeof(float));
+            float * src0_f32 = (float *) params->wdata;
 
             for (int i03 = 0; i03 < ne03; i03++) {
                 for (int i02 = 0; i02 < ne02; i02++) {
@@ -10227,9 +10226,17 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         struct ggml_tensor * node = cgraph->nodes[i];
 
         switch (node->op) {
+            case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
                     node->n_tasks = 1;
+
+                    size_t cur = 0;
+                    if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_ADD:
                 {
@@ -10322,7 +10329,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = n_threads;
                 } break;
-            case GGML_OP_CPY:
             case GGML_OP_CONT:
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
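A note on the ggml_cpy change above: instead of each worker alloca'ing a scratch row while quantizing, ggml_graph_compute now sizes a single work buffer up front (one f32 row for CPY/DUP into Q4_0/Q4_1) and the op reads it back through params->wdata. Stripped of ggml's types, the plan-then-allocate pattern looks roughly like the sketch below; every name in it is an illustrative stand-in, not ggml code.

// Minimal sketch of the "plan, then allocate once" work-buffer pattern used
// above; FakeNode and its fields are simplified stand-ins, not ggml's own.
#include <algorithm>
#include <cstddef>
#include <vector>

struct FakeNode {
    bool   quantized_copy; // stands in for "CPY/DUP into a Q4_0/Q4_1 tensor"
    size_t row_elems;      // stands in for node->ne[0]
};

int main() {
    std::vector<FakeNode> nodes = {{false, 128}, {true, 4096}, {true, 1024}};

    // 1) planning pass: each node reports how much scratch it needs; keep the max
    size_t work_size = 0;
    for (const FakeNode & node : nodes) {
        size_t cur = 0;
        if (node.quantized_copy) {
            cur = sizeof(float) * node.row_elems; // one f32 row for dequant/requant
        }
        work_size = std::max(work_size, cur);
    }

    // 2) a single allocation shared by every node, instead of a per-row alloca
    std::vector<char> work_buffer(work_size);
    float * wdata = reinterpret_cast<float *>(work_buffer.data()); // ~ params->wdata
    (void) wdata; // each op would convert its row into wdata here
    return 0;
}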
From 63da54e0164a530414f948a3115e0535ca20f5bb Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sun, 16 Apr 2023 18:30:27 +0200
Subject: [PATCH 13/15] Only attempt to use mmap for the lora base model if it is supported

---
 llama.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 87c89b016921b..4f222ce57a129 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1840,7 +1840,9 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
         model_loader->ggml_ctx = base_ctx;
 
         // maybe this should in llama_model_loader
-        model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, false));
+        if (model_loader->use_mmap) {
+            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
+        }
     }
 
     // read tensors and apply

From 0a6d5ad7cc9f950efa233e77e3effe14d5822956 Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sun, 16 Apr 2023 18:52:22 +0200
Subject: [PATCH 14/15] Reuse definitions from convert.py

---
 convert-lora-to-ggml.py | 46 +++++++----------------------------------
 1 file changed, 7 insertions(+), 39 deletions(-)

diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py
index c6d48b97643ec..8a2085c2511a1 100644
--- a/convert-lora-to-ggml.py
+++ b/convert-lora-to-ggml.py
@@ -3,45 +3,11 @@
 import re
 import struct
 import sys
-from dataclasses import dataclass
-from typing import Any, Sequence
+from typing import Any, Dict, Sequence, TextIO
 
-import numpy as np
 import torch
 
-
-# TODO: import this from convert.py once #545 is merged
-@dataclass(frozen=True)
-class UnquantizedDataType:
-    name: str
-
-
-DT_F16 = UnquantizedDataType("F16")
-DT_F32 = UnquantizedDataType("F32")
-
-
-@dataclass(frozen=True)
-class QuantizedDataType:
-    groupsize: int
-    have_addends: bool
-    have_g_idx: bool
-
-
-DataType = UnquantizedDataType
-
-DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
-    DT_F32: 0,
-    DT_F16: 1,
-}
-
-DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
-    DT_F16: np.dtype(np.float16),
-    DT_F32: np.dtype(np.float32),
-}
-
-NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {
-    dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()
-}
+from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType
 
 HF_SUBLAYER_TO_GGML = {
     "self_attn.q_proj": "attention.wq",
@@ -59,7 +25,7 @@ class QuantizedDataType:
 }
 
 
-def translate_tensor_name(t):
+def translate_tensor_name(t: str) -> str:
     match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
     if match:
         nn = match.group(1)
@@ -80,13 +46,15 @@ def translate_tensor_name(t):
     sys.exit(1)
 
 
-def write_file_header(fout, params):
+def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
     fout.write(b"ggla"[::-1])  # magic (ggml lora)
     fout.write(struct.pack("i", 1))  # file version
     fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
 
 
-def write_tensor_header(self, name: str, shape: Sequence[int], data_type: 1) -> None:
+def write_tensor_header(
+    self, name: str, shape: Sequence[int], data_type: DataType
+) -> None:
     sname = name.encode("utf-8")
     fout.write(
         struct.pack(
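With the converter now importing its type tables from convert.py, the ggla layout it emits starts with 4 magic bytes (b"ggla"[::-1], which read back as the little-endian uint32 0x67676c61), an int32 file version of 1, then the adapter's int32 r and lora_alpha. A minimal, hypothetical reader for just that header might look like the sketch below; the names and error handling are invented here, little-endian struct.pack output is assumed, and this is not the loader llama.cpp itself uses.

// Illustrative reader for the file header written by write_file_header above.
#include <cstdint>
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s ggml_adapter_model.bin\n", argv[0]);
        return 1;
    }
    FILE * f = fopen(argv[1], "rb");
    if (!f) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    uint32_t magic = 0;
    int32_t version = 0, lora_r = 0, lora_alpha = 0;
    bool ok = fread(&magic,      sizeof(magic),      1, f) == 1
           && fread(&version,    sizeof(version),    1, f) == 1
           && fread(&lora_r,     sizeof(lora_r),     1, f) == 1
           && fread(&lora_alpha, sizeof(lora_alpha), 1, f) == 1;
    fclose(f);

    // b"ggla"[::-1] on the Python side reads back as 0x67676c61 here
    if (!ok || magic != 0x67676c61 || version != 1) {
        fprintf(stderr, "not a ggla v1 adapter file\n");
        return 1;
    }
    // alpha / r is the usual LoRA scaling applied to the B*A delta
    printf("lora r = %d, alpha = %d, scale = %f\n", lora_r, lora_alpha, (float) lora_alpha / lora_r);
    return 0;
}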
From 8d37db3cdfdf700d94de2d18697bed54a859c03a Mon Sep 17 00:00:00 2001
From: Slaren <2141330+slaren@users.noreply.github.com>
Date: Sun, 16 Apr 2023 18:53:44 +0200
Subject: [PATCH 15/15] ggml_add: Add more checks

---
 ggml.c | 50 ++++++++++++++++++++++++++++++++++----------------
 1 file changed, 34 insertions(+), 16 deletions(-)

diff --git a/ggml.c b/ggml.c
index e88b9fbe4ab5b..a822cc9b4d410 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    //const size_t nb00 = src0->nb[0];
+    const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    //const size_t nb0 = dst->nb[0];
+    const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
-    for (int j = ith; j < n; j += nth) {
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-        for (int i = 0; i < nc; i++) {
-            float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
         }
     }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
 }
 
 static void ggml_compute_forward_add_f16_f16(
@@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    //const size_t nb00 = src0->nb[0];
+    const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    //const size_t nb0 = dst->nb[0];
+    const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
-    for (int j = ith; j < n; j += nth) {
-        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
-        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
-        for (int i = 0; i < nc; i++) {
-            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
-            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+            }
         }
     }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
 }
 
 static void ggml_compute_forward_add_q_f32(
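The asserts added above pin down the addressing convention these kernels rely on: every element is located through byte strides, and the fast path is only taken when src1's innermost stride equals the element size, i.e. its rows are contiguous. A standalone sketch of that convention follows; the Tensor2D type and its fields are invented for illustration and are not ggml's own.

// Standalone sketch of byte-stride addressing as used in the kernels above.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Tensor2D {
    void * data;
    size_t nb0; // byte stride between elements within a row
    size_t nb1; // byte stride between rows
};

// element (i, j) lives at data + j*nb1 + i*nb0, whatever the strides are
static float get(const Tensor2D & t, int i, int j) {
    return *(const float *) ((const char *) t.data + j * t.nb1 + i * t.nb0);
}

int main() {
    std::vector<float> buf = {0, 1, 2, 3, 4, 5};

    // contiguous 3x2 view: nb0 == sizeof(float), so a row can be walked directly
    Tensor2D contiguous = {buf.data(), sizeof(float), 3 * sizeof(float)};

    // strided view over every other element: nb0 != sizeof(float), so only the
    // generic per-element formula is valid (the case the new asserts reject)
    Tensor2D strided = {buf.data(), 2 * sizeof(float), 3 * sizeof(float)};

    printf("%.0f %.0f\n", get(contiguous, 1, 1), get(strided, 1, 1)); // prints 4 5
    return 0;
}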