model : jina-embeddings-v3 support #13693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
@@ -2460,15 +2460,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
params.lora_adapters.push_back({ std::string(value), 1.0, nullptr });
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale), nullptr });
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
2 changes: 2 additions & 0 deletions common/common.cpp
@@ -993,6 +993,8 @@ struct common_init_result common_init_from_params(common_params & params) {
}

la.ptr = lora.get();
la.task_name = llama_adapter_lora_task_name(la.ptr);
la.prompt_prefix = llama_adapter_lora_prompt_prefix(la.ptr);
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

3 changes: 3 additions & 0 deletions common/common.h
@@ -31,6 +31,9 @@ struct common_adapter_lora_info {
std::string path;
float scale;

std::string task_name;
std::string prompt_prefix;

struct llama_adapter_lora * ptr;
};

67 changes: 49 additions & 18 deletions convert_hf_to_gguf.py
@@ -64,6 +64,7 @@ class ModelBase:
endianess: gguf.GGUFEndian
use_temp_file: bool
lazy: bool
dry_run: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
@@ -98,6 +99,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
if remote_hf_model_id is not None:
self.is_safetensors = True
@@ -4153,18 +4155,31 @@ def modify_tensors(self, data_torch, name, bid):
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
_lora_files = {}

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
hparams = ModelBase.load_hparams(dir_model)

if hparams.get("lora_adaptations"):
if lora_names := hparams.get("lora_adaptations"):
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3

super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)

if lora_names:
for name in lora_names:
fname = self.add_prefix_to_filename(self.fname_out, f"lora-{name}-")
self._lora_files[name] = gguf.GGUFWriter(fname, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, dry_run=self.dry_run)

self._xlmroberta_tokenizer_init()

def set_type(self):
for lora_writer in self._lora_files.values():
lora_writer.add_type(gguf.GGUFType.ADAPTER)
lora_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
super().set_type()

def set_vocab(self):
self._xlmroberta_set_vocab()

@@ -4185,36 +4200,52 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]

if name.endswith(".weight.0.lora_A") or name.endswith(".weight.0.lora_B"):
if name.endswith(".0.lora_A") or name.endswith(".0.lora_B"):
if name.startswith("pooler.dense"):
return
return []

lora_name = self.hparams["lora_adaptations"]
num_loras = data_torch.size(0)
assert num_loras == len(lora_name)
assert num_loras == len(self._lora_files)

# Split out each LoRA in their own GGUF
for i, lora_writer in enumerate(self._lora_files.values()):
new_name = self.map_tensor_name(name[:-9]) + name[-7:].lower()
data_qtype = gguf.GGMLQuantizationType.F32
data = data_torch[i, :, :]
# Transpose/flip token_embd/types into correct shape
if new_name == "token_embd.weight.lora_b":
data = data.T
elif new_name.startswith("token_types.weight."):
new_name = new_name[:-1] + ("a" if new_name[-1:] == "b" else "b")
data = gguf.quants.quantize(data.numpy(), data_qtype)
lora_writer.add_tensor(new_name, data, raw_dtype=data_qtype)

# Split out each LoRA in their own named tensors
# Remove "weight" from the name to not confuse quantize
for i in range(num_loras):
data_lora = data_torch[i, :, :]
yield (self.map_tensor_name(name[:-16]) + name[-16:].lower().replace("weight.0.", f"<{lora_name[i]}>"), data_lora)
return
return []

yield from super().modify_tensors(data_torch, name, bid)
return super().modify_tensors(data_torch, name, bid)

def set_gguf_parameters(self):
super().set_gguf_parameters()

# jina-embeddings-v3
if rotary_emb_base := self.hparams.get("rotary_emb_base"):
self.gguf_writer.add_rope_freq_base(rotary_emb_base)
if lora_alpha := self.hparams.get("lora_alpha"):
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha)
if lora_names := self.hparams.get("lora_adaptations"):
self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_NAMES, lora_names)
lora_alpha = self.hparams.get("lora_alpha")
if lora_prompt_prefixes := self.hparams.get("task_instructions"):
assert lora_names and all(lora_name in lora_prompt_prefixes for lora_name in lora_names)
self.gguf_writer.add_array(gguf.Keys.Adapter.LORA_PROMPT_PREFIXES, [lora_prompt_prefixes[lora_name] for lora_name in lora_names])
assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
for lora_name, lora_writer in self._lora_files.items():
lora_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, lora_alpha if lora_alpha is not None else 1.0)
lora_writer.add_string(gguf.Keys.Adapter.LORA_TASK_NAME, lora_name)
if lora_prompt_prefixes:
lora_writer.add_string(gguf.Keys.Adapter.LORA_PROMPT_PREFIX, lora_prompt_prefixes[lora_name])

def write(self):
super().write()
for lora_writer in self._lora_files.values():
lora_writer.write_header_to_file()
lora_writer.write_kv_data_to_file()
lora_writer.write_tensors_to_file(progress=True)
lora_writer.close()


@ModelBase.register("GemmaForCausalLM")
8 changes: 4 additions & 4 deletions gguf-py/gguf/constants.py
@@ -227,10 +227,10 @@ class Tokenizer:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"

class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_NAMES = "adapter.lora.names"
LORA_PROMPT_PREFIXES = "adapter.lora.prompt_prefixes"
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"

class Clip:
PROJECTOR_TYPE = "clip.projector_type"
6 changes: 6 additions & 0 deletions include/llama.h
@@ -588,6 +588,12 @@ extern "C" {
struct llama_model * model,
const char * path_lora);

// Get the LoRA task name. Returns a blank string if not applicable
LLAMA_API const char * llama_adapter_lora_task_name(struct llama_adapter_lora * adapter);

// Get the required LoRA prompt prefix. Returns a blank string if not applicable
LLAMA_API const char * llama_adapter_lora_prompt_prefix(struct llama_adapter_lora * adapter);

// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
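The two accessors above expose the per-adapter metadata that jina-embeddings-v3 ships with each task-specific LoRA. Below is a minimal, hypothetical usage sketch; the GGUF file names and the surrounding embedding flow are assumptions for illustration, not part of this PR, and error handling is omitted.

```cpp
// Sketch only: load a base model plus one split-out LoRA GGUF, then read the
// task name and the prompt prefix that should be prepended before embedding.
#include "llama.h"

#include <cstdio>
#include <string>

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("jina-embeddings-v3-f16.gguf", mparams);

    llama_adapter_lora * adapter =
        llama_adapter_lora_init(model, "lora-retrieval.query-jina-embeddings-v3-f16.gguf");

    // New accessors from this PR; both return "" when the adapter has no such metadata.
    const char * task   = llama_adapter_lora_task_name(adapter);
    const char * prefix = llama_adapter_lora_prompt_prefix(adapter);
    printf("task: %s\nprefix: %s\n", task, prefix);

    // Prepend the prefix to the user input, then tokenize and embed as usual
    // with the adapter attached to the context (not shown here).
    std::string input = std::string(prefix) + "What is the capital of France?";
    printf("input: %s\n", input.c_str());

    llama_adapter_lora_free(adapter);
    llama_model_free(model);
    return 0;
}
```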
10 changes: 10 additions & 0 deletions src/llama-adapter.cpp
@@ -190,6 +190,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
}

adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
adapter.task_name = get_kv_str(llm_kv(LLM_KV_ADAPTER_LORA_TASK_NAME));
adapter.prompt_prefix = get_kv_str(llm_kv(LLM_KV_ADAPTER_LORA_PROMPT_PREFIX));
}

int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -383,6 +385,14 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
return nullptr;
}

const char * llama_adapter_lora_task_name(llama_adapter_lora * adapter) {
return adapter->task_name.c_str();
}

const char * llama_adapter_lora_prompt_prefix(llama_adapter_lora * adapter) {
return adapter->prompt_prefix.c_str();
}

void llama_adapter_lora_free(llama_adapter_lora * adapter) {
delete adapter;
}
1 change: 1 addition & 0 deletions src/llama-adapter.h
@@ -66,6 +66,7 @@ struct llama_adapter_lora {
std::vector<ggml_backend_buffer_ptr> bufs;

float alpha;
std::string task_name;
std::string prompt_prefix;

llama_adapter_lora() = default;
8 changes: 4 additions & 4 deletions src/llama-arch.cpp
@@ -217,10 +217,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_NAMES, "adapter.lora.names" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, "adapter.lora.prompt_prefixes" },
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
{ LLM_KV_ADAPTER_LORA_TASK_NAME, "adapter.lora.task_name" },
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
4 changes: 2 additions & 2 deletions src/llama-arch.h
@@ -215,8 +215,8 @@ enum llm_kv {

LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
LLM_KV_ADAPTER_LORA_NAMES,
LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES,
LLM_KV_ADAPTER_LORA_TASK_NAME,
LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,

LLM_KV_POSNET_EMBEDDING_LENGTH,
LLM_KV_POSNET_BLOCK_COUNT,
48 changes: 0 additions & 48 deletions src/llama-model.cpp
@@ -1720,16 +1720,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
ggml_backend_buffer_type_t first_moved_to_buft = nullptr;

auto add_lora_tensors = [&](const std::string & lora_name, const std::string & tensor_name) -> void {
std::string base_name = tensor_name.substr(0, tensor_name.size() - 6);

ggml_tensor * lora_a = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_a").c_str());
ggml_tensor * lora_b = ml.get_tensor_meta((base_name + "<" + lora_name + ">lora_b").c_str());
loras[lora_name]->ab_map[tensor_name] = llama_adapter_lora_weight(lora_a, lora_b);

ml.n_created += 2;
};

auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());

@@ -2256,8 +2246,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_JINA_BERT_V3:
{
std::vector<std::string> lora_names;

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);

@@ -2274,31 +2262,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);

if (arch == LLM_ARCH_JINA_BERT_V3) {
float lora_alpha = 1.0f;
std::vector<std::string> lora_prompt_prefixes;

ml.get_key(LLM_KV_ADAPTER_LORA_ALPHA, lora_alpha, false);
ml.get_arr(LLM_KV_ADAPTER_LORA_NAMES, lora_names, false);
ml.get_arr(LLM_KV_ADAPTER_LORA_PROMPT_PREFIXES, lora_prompt_prefixes, false);
GGML_ASSERT(lora_names.size() == lora_prompt_prefixes.size());

for (size_t i = 0; i < lora_names.size(); ++i) {
llama_adapter_lora * adapter = new llama_adapter_lora();
std::string lora_name = lora_names[i];

adapter->alpha = lora_alpha;
adapter->prompt_prefix = lora_prompt_prefixes[i];
loras[lora_name] = adapter;

add_lora_tensors(lora_name, tok_embd->name);

if (type_embd) {
add_lora_tensors(lora_name, type_embd->name);
}
}
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

@@ -2337,17 +2300,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}

if (arch == LLM_ARCH_JINA_BERT_V3) {
GGML_ASSERT(layer.wqkv != nullptr);

for (const auto & lora_name : lora_names) {
add_lora_tensors(lora_name, layer.wqkv->name);
add_lora_tensors(lora_name, layer.wo->name);
add_lora_tensors(lora_name, layer.ffn_up->name);
add_lora_tensors(lora_name, layer.ffn_down->name);
}
}

layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
}
4 changes: 0 additions & 4 deletions src/llama-model.h
@@ -7,7 +7,6 @@
#include "llama-memory.h"
#include "llama-vocab.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
@@ -384,9 +383,6 @@ struct llama_model {

llama_model_params params;

// built-in LoRAs
std::map<std::string, llama_adapter_lora *> loras;

// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;

2 changes: 2 additions & 0 deletions tools/server/server.cpp
@@ -4761,6 +4761,8 @@ int main(int argc, char ** argv) {
{"id", i},
{"path", lora.path},
{"scale", lora.scale},
{"task_name", lora.task_name},
{"prompt_prefix", lora.prompt_prefix},
});
}
res_ok(res, result);
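With `task_name` and `prompt_prefix` added to each entry, a `GET /lora-adapters` response entry might look like the sketch below. This is illustrative only; the path, task name, and prefix are hypothetical values, and the snippet just rebuilds the same JSON shape with nlohmann::json, which server.cpp already uses.

```cpp
// Illustrative only: one entry in the shape the /lora-adapters endpoint now returns.
#include <nlohmann/json.hpp>

#include <iostream>

int main() {
    nlohmann::json entry = {
        {"id", 0},
        {"path", "lora-retrieval.query-jina-embeddings-v3-f16.gguf"},  // hypothetical file name
        {"scale", 1.0},
        {"task_name", "retrieval.query"},                              // hypothetical task name
        {"prompt_prefix", "Represent the query for retrieving evidence documents: "},
    };
    std::cout << entry.dump(2) << std::endl;
    return 0;
}
```

A client can then prepend `prompt_prefix` to its input before requesting embeddings with that adapter enabled.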