From 3f2839efdadbbacd51571f12b0036248e1cc31db Mon Sep 17 00:00:00 2001 From: NoAmateur <1972025243@qq.com> Date: Mon, 26 May 2025 09:51:45 +0000 Subject: [PATCH] Add OPT model support - Add OPT architecture support in C++ code - Implement OPT-specific graph builder with separate Q/K/V projections - Add OPT model conversion support in Python - Add OPT tensor mappings and constants in gguf-py - Support some OPT model sizes - Tested with OPT-125M and OPT-13B models --- convert_hf_to_gguf.py | 34 +++++- gguf-py/gguf/constants.py | 16 +++ gguf-py/gguf/tensor_mapping.py | 13 ++- src/llama-arch.cpp | 18 ++++ src/llama-arch.h | 1 + src/llama-model.cpp | 184 +++++++++++++++++++++++++++++++++ 6 files changed, 261 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 91af508a2fb28..a58143153bedf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -803,6 +803,12 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base res = "seed-coder" + if chkhsh == "7f2212c1b7fec62b4b75447509a4ecc8acd82813ce90d715dd99c1460a52d978": + # ref: https://huggingface.co/facebook/opt-13b + res = "gpt-2" + if chkhsh == "2c934e5e1c8275b75011b9942836389a87eaa1a63116104e52424515e7649c46": + # ref: https://huggingface.co/SousChef/OPT-13B-Erebus (OPT-13B-Erebus model) + res = "gpt-2" if res is None: logger.warning("\n") @@ -3902,7 +3908,7 @@ def set_vocab(self): def set_gguf_parameters(self): hparams = self.hparams block_count = hparams["num_hidden_layers"] - + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) @@ -4014,7 +4020,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - + if name.startswith("language_model."): name = name.replace("language_model.", "") @@ -4520,7 +4526,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith("k_proj.weight"): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - + return [(self.map_tensor_name(name), data_torch)] @@ -5231,7 +5237,7 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - + # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored # in the safetensors files. 
We use the first tensor from these three as the token embeddings for both encoder @@ -6124,6 +6130,26 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return cls._wrap_fn(func)(*args, **kwargs) +@ModelBase.register("OPTForCausalLM") +class OPTModel(TextModel): + model_arch = gguf.MODEL_ARCH.OPT + + def set_vocab(self): + # OPT typically uses GPT2 tokenizer + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + # OPT-specific parameters that are not handled by the base class + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # OPT model uses standard tensor mapping - let the mapping handle the conversion + return [(self.map_tensor_name(name), data_torch)] + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Convert a huggingface model to a GGML compatible file") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c6255d6867a15..30615ec4b1976 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum): WAVTOKENIZER_DEC = auto() PLM = auto() BAILINGMOE = auto() + OPT = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -620,6 +621,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", MODEL_ARCH.PLM: "plm", MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.OPT: "opt", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2040,6 +2042,20 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.OPT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4a0615b656812..dc566a6050ffb 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -31,6 +31,7 @@ class TensorNameMap: "model.embeddings", # rwkv7 "model.word_embeddings", # bailingmoe "language_model.model.embed_tokens", # llama4 + "model.decoder.embed_tokens", # opt ), # Token type embeddings @@ -56,6 +57,7 @@ class TensorNameMap: "transformer.wpe", # gpt2 "embeddings.position_embeddings", # bert "wpe", # gpt2 + "model.decoder.embed_positions", # opt ), # Output @@ -68,7 +70,7 @@ class TensorNameMap: "output_layer", # chatglm "head", # rwkv "head.out", # wavtokenizer - "lm_head", # llama4 + "lm_head", # llama4 opt ), # Output norm @@ -92,6 +94,7 @@ class TensorNameMap: "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer "model.norm", # llama4 + "model.decoder.final_layer_norm", # opt ), # Rope frequencies @@ -134,6 +137,7 @@ class TensorNameMap: "rwkv.blocks.{bid}.ln1", # rwkv6 "model.layers.{bid}.ln1", # rwkv7 "model.layers.{bid}.input_layernorm", # llama4 + "model.decoder.layers.{bid}.self_attn_layer_norm", # opt ), # Attention norm 2 @@ -174,6 +178,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone "model.layers.{bid}.self_attn.q_proj", # llama4 + "model.decoder.layers.{bid}.self_attn.q_proj", # opt ), # Attention key @@ -189,6 +194,7 @@ class TensorNameMap: 
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone "model.layers.{bid}.self_attn.k_proj", # llama4 + "model.decoder.layers.{bid}.self_attn.k_proj", # opt ), # Attention value @@ -203,6 +209,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.h.{bid}.attn.attention.v_proj", # exaone "model.layers.{bid}.self_attn.v_proj", # llama4 + "model.decoder.layers.{bid}.self_attn.v_proj", # opt ), # Attention output @@ -230,6 +237,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.out_proj", # openelm "transformer.h.{bid}.attn.attention.out_proj", # exaone "model.layers.{bid}.self_attn.o_proj", # llama4 + "model.decoder.layers.{bid}.self_attn.out_proj", # opt ), # Attention output norm @@ -269,6 +277,7 @@ class TensorNameMap: "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm "model.layers.{bid}.post_attention_layernorm", # llama4 + "model.decoder.layers.{bid}.final_layer_norm", # opt ), # Post feed-forward norm @@ -330,6 +339,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone "model.layers.{bid}.feed_forward.up_proj", # llama4 + "model.decoder.layers.{bid}.fc1", # opt ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -411,6 +421,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.{bid}.feed_forward.down_proj", # llama4 + "model.decoder.layers.{bid}.fc2", # opt ), MODEL_TENSOR.FFN_DOWN_EXP: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index abf436adac416..eb7a1890fdd52 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -72,6 +72,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_OPT, "opt" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1530,6 +1531,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_OPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BAILINGMOE, { diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..d85bce1aeecfc 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -76,6 +76,7 @@ enum llm_arch { LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, + LLM_ARCH_OPT, LLM_ARCH_UNKNOWN, }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e99f5309f9904..35e0d23d3777d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1429,6 +1429,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_OPT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + if (hparams.f_norm_eps == 0.0f) { + hparams.f_norm_eps = 1e-5f; // default value for OPT + } + + // OPT does not use RoPE, so set n_rot to 0 + hparams.n_rot = 0; + + // OPT uses uniform 
architecture, so if n_ff_arr is not set in GGUF, set it manually + if (hparams.n_ff_arr[0] == 0) { + const uint32_t n_ff = hparams.n_embd * 4; // OPT uses 4x expansion ratio + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.n_ff_arr[il] = n_ff; + } + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 32: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_30B; break; + case 64: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4111,6 +4139,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_OPT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train + 2}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7732,6 +7799,118 @@ struct llm_build_gpt2 : public llm_graph_context { } }; +struct llm_build_opt : public llm_graph_context { + llm_build_opt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = 
build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // OPT uses separate Q, K, V projections + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_codeshell : public llm_graph_context { llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -13526,6 +13705,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_OPT: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -13635,6 +13818,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_OPT: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values
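A few implementation notes to go with the patch.

The two chkhsh entries added to get_vocab_base_pre map both OPT checkpoints onto the existing "gpt-2" pre-tokenizer, since OPT reuses the GPT-2 byte-level BPE tokenizer. For reference, a minimal sketch of how such a hash is obtained; this is only an approximation of what convert_hf_to_gguf.py does, and chktxt here is a stand-in for the script's much longer built-in check string:

    from hashlib import sha256
    from transformers import AutoTokenizer

    chktxt = "placeholder check text 1234567890"  # stand-in, not the script's real check string

    def pretokenizer_hash(model_id: str) -> str:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        chktok = tokenizer.encode(chktxt)
        return sha256(str(chktok).encode()).hexdigest()

    # the value for "facebook/opt-13b" is what the new chkhsh branch above compares against
    print(pretokenizer_hash("facebook/opt-13b"))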
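The OPTModel converter class does no renaming of its own: modify_tensors forwards every tensor through map_tensor_name, so the conversion relies entirely on the tensor_mapping.py entries added above. A small sketch of what those entries resolve to, assuming the gguf-py package from this tree is importable; the printed names follow the TENSOR_NAMES / llama-arch.cpp templates added in this patch:

    import gguf

    # 40 decoder layers corresponds to OPT-13B, one of the tested checkpoints
    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.OPT, 40)

    for hf_name in (
        "model.decoder.embed_tokens.weight",
        "model.decoder.embed_positions.weight",
        "model.decoder.layers.0.self_attn.q_proj.weight",
        "model.decoder.layers.0.self_attn.out_proj.bias",
        "model.decoder.layers.0.fc1.weight",
        "model.decoder.final_layer_norm.weight",
    ):
        print(f"{hf_name} -> {tmap.get_name(hf_name, try_suffixes=('.weight', '.bias'))}")

    # expected: token_embd.weight, position_embd.weight, blk.0.attn_q.weight,
    #           blk.0.attn_output.bias, blk.0.ffn_up.weight, output_norm.weight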
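The n_ff fallback in load_hparams leans on the fact that every published OPT size uses ffn_dim = 4 * hidden_size; the n_layer switch next to it only attaches a rough size label (OPT-1.3B has 24 layers like OPT-350M, and 2.7B/6.7B both have 32). The figures below are taken from the facebook/opt-* configs on Hugging Face. Note that OPT-350M additionally projects between a 512-dim embedding and its 1024-dim hidden state (word_embed_proj_dim), which this patch does not handle, hence "some OPT model sizes" in the summary:

    # hidden_size, ffn_dim, num_hidden_layers per published OPT checkpoint
    OPT_CONFIGS = {
        "opt-125m": (768,  3072,  12),
        "opt-350m": (1024, 4096,  24),
        "opt-1.3b": (2048, 8192,  24),
        "opt-2.7b": (2560, 10240, 32),
        "opt-6.7b": (4096, 16384, 32),
        "opt-13b":  (5120, 20480, 40),
        "opt-30b":  (7168, 28672, 48),
        "opt-66b":  (9216, 36864, 64),
    }

    for name, (n_embd, n_ff, n_layer) in OPT_CONFIGS.items():
        assert n_ff == 4 * n_embd, name  # the 4x ratio the C++ fallback assumes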
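Finally, on the {n_embd, n_ctx_train + 2} position-embedding tensor: Hugging Face's OPT uses a learned positional embedding with a fixed offset of 2, so token position p reads row p + 2 of the table (the first two rows are reserved). Because load_tensors keeps the full table, the ggml_get_rows lookup in llm_build_opt presumably needs that same +2 shift applied to the positions, or the first two rows stripped during conversion, for the embeddings to line up with the original model. A simplified PyTorch sketch of the Hugging Face behaviour; the forward signature is trimmed down, so treat it as illustrative:

    import torch
    import torch.nn as nn

    class OPTLearnedPositionalEmbedding(nn.Embedding):
        offset = 2  # OPT reserves the first two rows of the table

        def __init__(self, num_embeddings: int, embedding_dim: int):
            super().__init__(num_embeddings + self.offset, embedding_dim)

        def forward(self, positions: torch.Tensor) -> torch.Tensor:
            # token position p looks up row p + 2
            return super().forward(positions + self.offset)

    emb = OPTLearnedPositionalEmbedding(2048, 768)  # OPT-125M-like: max_position_embeddings=2048, hidden_size=768
    print(emb.weight.shape)                         # torch.Size([2050, 768]) -> the n_ctx_train + 2 rows above
    print(emb(torch.arange(4)).shape)               # torch.Size([4, 768]), read from rows 2..5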