Commit e369c99

Add support for using a different base model
1 parent fdae61c commit e369c99

File tree

8 files changed: +148 additions, -33 deletions

  examples/common.cpp
  examples/common.h
  examples/main/main.cpp
  examples/perplexity/perplexity.cpp
  ggml.c
  llama.cpp
  llama.h
  llama_util.h

examples/common.cpp

Lines changed: 7 additions & 0 deletions

@@ -146,6 +146,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             }
             params.lora_adapter = argv[i];
             params.use_mmap = false;
+        } else if (arg == "--lora-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_base = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -250,6 +256,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
+    fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
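
For reference, a hypothetical invocation that combines the new flag with the existing options (the model and adapter paths below are placeholders):

    ./main -m models/7B/ggml-model-q4_0.bin --lora lora-adapter.bin --lora-base models/7B/ggml-model-f16.bin -p "Hello"

The quantized q4_0 model is the one that is loaded and run; the f16 model is only read to rebuild the layers touched by the adapter at higher precision. As before, --lora implies --no-mmap for the main model.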

examples/common.h

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ struct gpt_params {
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

     std::string lora_adapter = ""; // lora adapter path
+    std::string lora_base = "";    // base model path for the lora adapter

     bool memory_f16 = true;  // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided

examples/main/main.cpp

Lines changed: 4 additions & 1 deletion

@@ -114,7 +114,10 @@ int main(int argc, char ** argv) {
     }

     if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             return 1;

examples/perplexity/perplexity.cpp

Lines changed: 4 additions & 1 deletion

@@ -134,7 +134,10 @@ int main(int argc, char ** argv) {
     }

     if (!params.lora_adapter.empty()) {
-        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
         if (err != 0) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             return 1;

ggml.c

Lines changed: 36 additions & 0 deletions

@@ -5461,6 +5461,27 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                size_t id = 0;
+                uint8_t * dst_ptr = (uint8_t *) dst->data;
+                size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                // todo: use work buffer
+                float * src0_f32 = (float *) alloca(ne00 * sizeof(float));
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            // convert to f32 and quantize
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += dst_row_size;
+                        }
+                    }
+                }
             } else {
                 GGML_ASSERT(false); // TODO: implement
             }
@@ -5653,6 +5674,21 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                size_t id = 0;
+                uint8_t * dst_ptr = (uint8_t *) dst->data;
+                size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += dst_row_size;
+                        }
+                    }
+                }
             } else {
                 GGML_ASSERT(false); // TODO: implement
             }
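
The new branches let ggml_dup/ggml_cpy write into a Q4_0 or Q4_1 destination by converting and quantizing one row at a time, which the llama.cpp change below relies on when it copies patched f32 weights back into a quantized model tensor. A minimal standalone sketch of what this enables, assuming only the ggml API that appears elsewhere in this commit (tensor sizes and thread count are arbitrary):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = { /*mem_size*/ 64*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(params);

        // f32 source and quantized destination; ne0 must be a multiple of the Q4 block size (32)
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 8);
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 8);

        // before this change the copy below hit GGML_ASSERT(false); now each row is quantized with quantize_row_q
        struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);

        struct ggml_cgraph gf = ggml_build_forward(cpy);
        gf.n_threads = 4;
        ggml_graph_compute(ctx, &gf);

        ggml_free(ctx);
        return 0;
    }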

llama.cpp

Lines changed: 75 additions & 17 deletions

@@ -1,6 +1,8 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
+#include <cstdint>
+#include <cstdio>
 #endif

 #include "llama_util.h"
@@ -1758,8 +1760,7 @@ int llama_model_quantize(
     }
 }

-int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) {
-    // TODO: refactor all of this after PR #801
+int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
     fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);

     auto & model = ctx->model;
@@ -1800,13 +1801,13 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor

     // create a temporary ggml context to store the lora tensors
     // todo: calculate size from biggest possible tensor
-    std::vector<uint8_t> buf(1024ull * 1024ull * 1024ull);
+    std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
     struct ggml_init_params params;
-    params.mem_size = buf.size();
-    params.mem_buffer = buf.data();
+    params.mem_size = lora_buf.size();
+    params.mem_buffer = lora_buf.data();
     params.no_alloc = false;

-    ggml_context* lora_ctx = ggml_init(params);
+    ggml_context * lora_ctx = ggml_init(params);
     std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;

     // create a name -> tensor map of the model to accelerate lookups
@@ -1815,6 +1816,32 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
         model_tensors.insert(kv);
     }

+
+    // load base model
+    std::unique_ptr<llama_model_loader> model_loader;
+    ggml_context * base_ctx = NULL;
+    llama_buffer base_buf;
+    if (path_base_model) {
+        fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
+        model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
+
+        size_t ctx_size, mmapped_size;
+        model_loader->calc_sizes(&ctx_size, &mmapped_size);
+        base_buf.resize(ctx_size);
+
+        ggml_init_params base_params;
+        base_params.mem_size = base_buf.size;
+        base_params.mem_buffer = base_buf.addr;
+        base_params.no_alloc = model_loader->use_mmap;
+
+        base_ctx = ggml_init(base_params);
+
+        model_loader->ggml_ctx = base_ctx;
+
+        // maybe this should in llama_model_loader
+        model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, false));
+    }
+
     fprintf(stderr, "%s: ", __func__);

     // read tensors and apply
@@ -1891,13 +1918,31 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {

-            ggml_tensor * tensor = model_tensors[base_name];
+            ggml_tensor * dest_t = model_tensors[base_name];
+            ggml_tensor * base_t;
+            if (model_loader) {
+                // load from base model
+                if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) {
+                    fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+                    return 1;
+                }
+                size_t idx = model_loader->tensors_map.name_to_idx[base_name];
+                llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+                lt.data = (uint8_t *) lt.ggml_tensor->data;
+                model_loader->load_data_for(lt);
+                lt.ggml_tensor->data = lt.data;
+            }
+            else {
+                base_t = dest_t;
+            }
+
             ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
             ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];

-            if (tensor->ne[0] != loraA->ne[1] || tensor->ne[1] != loraB->ne[1]) {
+            if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
                 fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
-                    " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]);
+                    " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
                 return 1;
             }

@@ -1909,14 +1954,14 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
                 BA = ggml_scale(lora_ctx, BA, scale_tensor);
             }

-            //printf("%s: (B)(%d %d %d %d) x (A)(%d %d %d %d) => (BA)(%d %d %d %d) + (T)(%d %d %d %d)\n",
-            //    base_name.c_str(),
-            //    (int)loraB->ne[0], (int)loraB->ne[1], (int)loraB->ne[2], (int)loraB->ne[3],
-            //    (int)loraA->ne[0], (int)loraA->ne[1], (int)loraA->ne[2], (int)loraA->ne[3],
-            //    (int)BA->ne[0], (int)BA->ne[1], (int)BA->ne[2], (int)BA->ne[3],
-            //    (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], (int)tensor->ne[3]
-            //);
-            ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+            ggml_tensor * r;
+            if (base_t == dest_t) {
+                r = ggml_add_inplace(lora_ctx, dest_t, BA);
+            }
+            else {
+                r = ggml_add(lora_ctx, base_t, BA);
+                r = ggml_cpy(lora_ctx, r, dest_t);
+            }

             struct ggml_cgraph gf = ggml_build_forward(r);
             gf.n_threads = n_threads;
@@ -1933,14 +1978,27 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
         }
     }

+    // TODO: this should be in a destructor, it will leak on failure
     ggml_free(lora_ctx);
+    if (base_ctx) {
+        ggml_free(base_ctx);
+    }

     const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
     fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0);

     return 0;
 }

+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+    try {
+        return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
+    } catch (const std::string & err) {
+        fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
+        return 1;
+    }
+}
+
 // Returns the KV cache that will contain the context for the
 // ongoing prediction with the model.
 const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
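
In effect, for every model tensor with matching .loraA and .loraB entries in the adapter, the hunks above compute

    W_dest = W_base + scaling * (loraB x loraA)

where scale_tensor is assumed to carry the usual LoRA scaling factor alpha / r set up earlier in the function. Without --lora-base, W_base is the already-loaded tensor and the addition happens in place via ggml_add_inplace; with --lora-base, W_base is read from the higher quality base model through the mmap-ed llama_model_loader, and the sum is copied back into the loaded, possibly quantized, tensor with ggml_cpy, using the new quantizing dup path added in ggml.c.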

llama.h

Lines changed: 5 additions & 2 deletions

@@ -97,12 +97,15 @@ extern "C" {
             enum llama_ftype   ftype);

     // Apply a LoRA adapter to a loaded model
-    // The model needs to be reloaded before applying a new adapter, otherwise
-    // the adapter will the applied on top of the previous one
+    // path_base_model is the path to a higher quality model to use as a base for
+    // the layers modified by the adapter. Can be NULL to use the current loaded model.
+    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+    // will be applied on top of the previous one
     // Returns 0 on success
     LLAMA_API int llama_apply_lora_from_file(
             struct llama_context * ctx,
             const char * path_lora,
+            const char * path_base_model,
             int n_threads);

     // Returns the KV cache that will contain the context for the
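
A minimal sketch of the extended call from application code; the paths and thread count are placeholders and error handling is abbreviated:

    #include <stdio.h>
    #include "llama.h"

    int main(void) {
        // load the model that will actually be run (it may be quantized)
        struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin",
                                                           llama_context_default_params());

        // third argument: NULL keeps the old behaviour (apply on top of the loaded weights);
        // a path makes the modified layers start from that higher quality base model instead
        if (llama_apply_lora_from_file(ctx, "lora-adapter.bin",
                                       "models/7B/ggml-model-f16.bin", 4) != 0) {
            fprintf(stderr, "failed to apply lora adapter\n");
            return 1;
        }

        // ... run inference as usual ...
        llama_free(ctx);
        return 0;
    }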

llama_util.h

Lines changed: 16 additions & 12 deletions

@@ -164,7 +164,7 @@ struct llama_mmap {
 #ifdef _POSIX_MAPPED_FILES
     static constexpr bool SUPPORTED = true;

-    llama_mmap(struct llama_file * file) {
+    llama_mmap(struct llama_file * file, bool prefetch = true) {
         size = file->size;
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
@@ -177,10 +177,12 @@ struct llama_mmap {
             throw format("mmap failed: %s", strerror(errno));
         }

-        // Advise the kernel to preload the mapped memory
-        if (madvise(addr, file->size, MADV_WILLNEED)) {
-            fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
-                    strerror(errno));
+        if (prefetch) {
+            // Advise the kernel to preload the mapped memory
+            if (madvise(addr, file->size, MADV_WILLNEED)) {
+                fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
+                        strerror(errno));
+            }
         }
     }

@@ -212,13 +214,15 @@ struct llama_mmap {
         }

 #if _WIN32_WINNT >= _WIN32_WINNT_WIN8
-        // Advise the kernel to preload the mapped memory
-        WIN32_MEMORY_RANGE_ENTRY range;
-        range.VirtualAddress = addr;
-        range.NumberOfBytes = (SIZE_T)size;
-        if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-            fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
+        if (prefetch) {
+            // Advise the kernel to preload the mapped memory
+            WIN32_MEMORY_RANGE_ENTRY range;
+            range.VirtualAddress = addr;
+            range.NumberOfBytes = (SIZE_T)size;
+            if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
+                fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
+                        llama_format_win_err(GetLastError()).c_str());
+            }
         }
 #else
 #pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
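
For illustration, how the new default parameter reads at a call site; this mirrors the constructor call added in llama.cpp, with a placeholder path and error handling omitted:

    #include "llama_util.h"

    int main() {
        llama_file file("models/7B/ggml-model-f16.bin", "rb");

        // prefetch = false: map the base model without advising the kernel to preload it,
        // since only the LoRA-modified tensors will actually be read from it
        llama_mmap mapping(&file, /*prefetch*/ false);

        // pages are faulted in on demand as data is accessed through mapping.addr
        return 0;
    }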
