
Commit f52101e: Add lora support
1 parent 3173a62

9 files changed: +335 -3 lines changed


convert-lora-to-ggml.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
import os
import re
import struct
import sys
from dataclasses import dataclass
from typing import Any, Sequence

import numpy as np
import torch


# TODO: import this from convert.py once #545 is merged
@dataclass(frozen=True)
class UnquantizedDataType:
    name: str

DT_F16 = UnquantizedDataType('F16')
DT_F32 = UnquantizedDataType('F32')

@dataclass(frozen=True)
class QuantizedDataType:
    groupsize: int
    have_addends: bool
    have_g_idx: bool

DataType = UnquantizedDataType

DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
    DT_F32: 0,
    DT_F16: 1,
}

DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
    DT_F16: np.dtype(np.float16),
    DT_F32: np.dtype(np.float32),
}

NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}

HF_SUBLAYER_TO_GGML = {
    "self_attn.q_proj": "attention.wq.weight",
    "self_attn.k_proj": "attention.wk.weight",
    "self_attn.v_proj": "attention.wv.weight",
    "self_attn.o_proj": "attention.wo.weight",
}

def translate_tensor_name(t):
    match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t)
    if match:
        nn = match.group(1)
        sub_layer = match.group(2)
        lora_type = match.group(3)

        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
        if sub_layer_renamed is None:
            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
            sys.exit(1)

        output_string = f"layers.{nn}.{sub_layer_renamed}.lora{lora_type}"
        return output_string
    else:
        print(f"Error: unrecognized tensor {t}")
        sys.exit(1)

def write_file_header(fout):
    fout.write(b"ggla"[::-1])  # magic (ggml lora)
    fout.write(struct.pack("i", 1))  # file version


def write_tensor_header(fout, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
    sname = name.encode('utf-8')
    fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]]))
    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
    fout.write(sname)
    # align the start of the tensor data to a 32-byte boundary
    fout.seek((fout.tell() + 31) & -32)


if len(sys.argv) < 2:
    print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]")
    sys.exit(1)

input_path = sys.argv[1]
if len(sys.argv) > 2:
    output_path = sys.argv[2]
else:
    output_filename = f"ggml_{os.path.basename(input_path)}"
    output_path = os.path.join(os.path.dirname(input_path), output_filename)

model = torch.load(input_path, map_location="cpu")

with open(output_path, "wb") as fout:
    write_file_header(fout)
    for k, v in model.items():
        # since ggml doesn't always support other types for the second operand,
        # the tensors are always converted and exported as f32
        t = v.float().numpy()
        print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
        write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype)
        t.tofile(fout)

print(f"Converted {input_path} to {output_path}")

examples/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -139,6 +139,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -242,6 +248,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");

examples/common.h

Lines changed: 3 additions & 3 deletions
@@ -31,11 +31,11 @@ struct gpt_params {
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string input_prefix = ""; // string to prefix user inputs with
-
-
+    std::string input_prefix = ""; // string to prefix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
+    std::string lora_adapter = ""; // lora adapter path
+
     bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs

examples/main/main.cpp

Lines changed: 8 additions & 0 deletions
@@ -114,6 +114,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");

examples/perplexity/perplexity.cpp

Lines changed: 8 additions & 0 deletions
@@ -134,6 +134,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");

ggml.c

Lines changed: 45 additions & 0 deletions
@@ -5813,6 +5813,47 @@ static void ggml_compute_forward_add_f32(
     }
 }
 
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+        }
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5823,6 +5864,10 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+            } break;
         default:
             {
                 GGML_ASSERT(false);
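
This kernel is what lets the f32 LoRA delta produced by the converter be added to f16 base weights: each f16 element of src0 is widened to f32, the f32 element of src1 is added, and the sum is rounded back to f16. A minimal NumPy sketch of that semantics (illustrative, not part of the commit):

import numpy as np

def add_f16_f32(src0: np.ndarray, src1: np.ndarray) -> np.ndarray:
    # mirrors ggml_compute_forward_add_f16_f32: f16 + f32 -> f16, with the add done in f32
    assert src0.dtype == np.float16 and src1.dtype == np.float32
    assert src0.shape == src1.shape
    return (src0.astype(np.float32) + src1).astype(np.float16)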

ggml.h

Lines changed: 6 additions & 0 deletions
@@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add(
         struct ggml_tensor * a,
         struct ggml_tensor * b);
 
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b);
+
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
         struct ggml_tensor * a,

llama.cpp

Lines changed: 148 additions & 0 deletions
@@ -1758,6 +1758,154 @@ int llama_model_quantize(
     }
 }
 
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, int n_threads) {
+    // TODO: refactor all of this after PR #801
+    auto & model = ctx->model;
+
+    auto fin = std::ifstream(path_lora, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
+        return 1;
+    }
+
+    // verify magic and version
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 'ggla') {
+            fprintf(stderr, "%s: bad file magic\n", __func__);
+            return 1;
+        }
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: unsupported file version\n", __func__);
+            return 1;
+        }
+    }
+
+    // create a temporary ggml context to store the lora tensors
+    std::vector<uint8_t> buf(1024 * 1024 * 100);
+    struct ggml_init_params params;
+    params.mem_size   = buf.size();
+    params.mem_buffer = buf.data();
+    params.no_alloc   = false;
+
+    ggml_context * lora_ctx = ggml_init(params);
+    std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
+
+    fprintf(stderr, "%s: ", __func__);
+
+    // read tensors and apply
+    int n_tensors = 0;
+    while (true) {
+        int32_t n_dims;
+        int32_t length;
+        int32_t ftype;
+
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+        fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+        if (fin.eof()) {
+            break;
+        }
+
+        int32_t nelements = 1;
+        int32_t ne[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+            nelements *= ne[i];
+        }
+
+        std::string name(length, 0);
+        fin.read(&name[0], length);
+
+        // check for lora suffix and get the type of tensor
+        const std::string lora_suffix = ".lora";
+        size_t pos = name.rfind(lora_suffix);
+        if (pos == std::string::npos) {
+            fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+            return 1;
+        }
+
+        std::string lora_type = name.substr(pos + lora_suffix.length());
+        std::string base_name = name;
+        base_name.erase(pos);
+        // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());
+
+        if (model.tensors.find(base_name.data()) == model.tensors.end()) {
+            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
+            return 1;
+        }
+
+        // create ggml tensor
+        ggml_type wtype;
+        switch (ftype) {
+            case 0: wtype = GGML_TYPE_F32; break;
+            case 1: wtype = GGML_TYPE_F16; break;
+            default:
+                {
+                    fprintf(stderr, "%s: invalid tensor data type '%d'\n",
+                            __func__, ftype);
+                    return 1;
+                }
+        }
+        ggml_tensor * lora_tensor;
+        if (n_dims == 2) {
+            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
+        }
+        else {
+            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
+            return 1;
+        }
+
+        // load tensor data
+        size_t offset = fin.tellg();
+        size_t tensor_data_size = ggml_nbytes(lora_tensor);
+        offset = (offset + 31) & -32;
+        fin.seekg(offset);
+        fin.read((char *) lora_tensor->data, tensor_data_size);
+
+        lora_tensors[name] = lora_tensor;
+
+        // check if we have both A and B tensors and apply
+        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
+            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
+
+            ggml_tensor * tensor = model.tensors[base_name];
+            ggml_tensor * loraA = ggml_transpose(lora_ctx, lora_tensors[base_name + ".loraA"]);
+            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+
+            if (tensor->ne[0] != loraA->ne[1]) {
+                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+                        " are you sure that this adapter is for this model?\n", __func__, tensor->ne[0], loraA->ne[1]);
+                return 1;
+            }
+
+            // w = w + BA
+            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
+            ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+            struct ggml_cgraph gf = ggml_build_forward(r);
+            gf.n_threads = n_threads;
+            ggml_graph_compute(lora_ctx, &gf);
+
+            // we won't need these tensors again, reset the context to save memory
+            ggml_free(lora_ctx);
+            lora_ctx = ggml_init(params);
+            lora_tensors.clear();
+
+            n_tensors++;
+            if (n_tensors % 8 == 0)
+                fprintf(stderr, ".");
+        }
+    }
+    fprintf(stderr, " done\n");
+
+    return 0;
+}
+
 // Returns the KV cache that will contain the context for the
 // ongoing prediction with the model.
 const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
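
The update being computed per weight matrix is the standard LoRA merge, w = w + BA, exactly as the comment in the loop says. A rough NumPy sketch of the same math, using conventional row-major shapes rather than ggml's internal ne[] layout (illustrative only, not the commit's code):

import numpy as np

def merge_lora(W: np.ndarray, loraA: np.ndarray, loraB: np.ndarray) -> np.ndarray:
    # W: (n_out, n_in), loraA: (r, n_in), loraB: (n_out, r)
    BA = loraB @ loraA  # low-rank update with the same shape as W
    assert BA.shape == W.shape, "are you sure that this adapter is for this model?"
    return W + BA       # the commit applies this in place via ggml_add_inplace

Because each .loraA/.loraB pair is merged as soon as both tensors have been read, the temporary context can be freed and re-initialized after every merge, keeping peak memory bounded by the 100 MB scratch buffer allocated above.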
