From 0061955a067be69104655f0677d367c680ac5a43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Thu, 2 Jan 2025 10:14:39 +0100
Subject: [PATCH 1/9] convert : add support for DeepSeek V3 model

---
 convert_hf_to_gguf.py          | 23 +++++++++++++++++++++++
 convert_hf_to_gguf_update.py   |  1 +
 gguf-py/gguf/constants.py      | 10 ++++++++++
 gguf-py/gguf/gguf_writer.py    |  7 +++++++
 gguf-py/gguf/tensor_mapping.py |  4 ++++
 5 files changed, 45 insertions(+)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 4e6c0f60c0621..43e61b500315e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -687,6 +687,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1":
             # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct
             res = "megrez"
+        if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
+            # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
+            res = "deepseek-v3"
 
         if res is None:
             logger.warning("\n")
@@ -3831,6 +3834,7 @@ def prepare_tensors(self):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -3852,6 +3856,15 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3864,6 +3877,16 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model.layers.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index fea23ddb4ae48..56edc64a72761 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -107,6 +107,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"},
     {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"},
     {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
+    {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
 ]
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 273370370e6ca..13b77979b1718 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -102,6 +102,8 @@ class LLM:
     EXPERT_USED_COUNT = "{arch}.expert_used_count"
     EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
     EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
+    EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
+    EXPERT_WEIGHTS_FUNC = "{arch}.expert_weights_func"
     POOLING_TYPE = "{arch}.pooling_type"
     LOGIT_SCALE = "{arch}.logit_scale"
     DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -312,6 +314,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
+    FFN_EXPERT_WEIGHTS_B = auto()
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
     LAYER_OUT_NORM = auto()
@@ -496,6 +499,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
@@ -1276,6 +1280,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B,
     ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -1576,6 +1581,11 @@ class GGMLQuantizationType(IntEnum):
     TQ2_0 = 35
 
 
+class ExpertWeightsFuncType(IntEnum):
+    SOFTMAX = 1
+    SIGMOID = 2
+
+
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 3023b539ae82b..a0dadeaf8183a 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -26,6 +26,7 @@
     RopeScalingType,
     PoolingType,
     TokenType,
+    ExpertWeightsFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -715,6 +716,12 @@ def add_expert_shared_count(self, count: int) -> None:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_expert_weights_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
+
+    def add_expert_weights_func(self, value: ExpertWeightsFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_WEIGHTS_FUNC.format(arch=self.arch), value.value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 7009a11d46bc8..f48769ad1aae6 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -276,6 +276,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
+        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: (
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
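
As a quick sanity check of the metadata this patch writes, the gguf-py reader can dump the new keys back out of a converted file. A minimal sketch, assuming a locally converted GGUF (the path is hypothetical); DeepSeek-V3 reuses the deepseek2 architecture, so the keys carry that prefix, and note that a later patch in this series renames expert_weights_func to expert_gating_func:

```python
from gguf import GGUFReader  # gguf-py package from this repo

reader = GGUFReader("deepseek-v3.gguf")  # hypothetical converted file

# key names as written by this first patch (renamed in PATCH 3)
for key in ("deepseek2.expert_weights_norm", "deepseek2.expert_weights_func"):
    field = reader.fields.get(key)
    if field is not None:
        # each field stores raw parts; the indexed part holds the scalar value
        print(key, "=", field.parts[field.data[0]])
```
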
From a43d4953ba77dda8ece5f46d21d6675e20f8c696 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Thu, 2 Jan 2025 10:15:53 +0100
Subject: [PATCH 2/9] llama : add support for DeepSeek V3 model.

---
 include/llama.h     |  1 +
 src/llama-vocab.cpp |  7 ++++
 src/llama.cpp       | 83 ++++++++++++++++++++++++++++++++++++++++++---
 3 files changed, 86 insertions(+), 5 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index a4abf395bcd93..e340e3a8eb844 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
+        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
     };
 
     enum llama_rope_type {
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 0a477d6dd85f1..3f7afefed2be7 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -396,6 +396,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "\\p{N}+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
+                regex_exprs = {
+                    "\\p{N}{1,3}",
+                    "[一-龥぀-ゟ゠-ヿ]+",
+                    "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
diff --git a/src/llama.cpp b/src/llama.cpp
index 4d41602fe2010..1c8479c965c9d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -78,7 +78,7 @@
 // bump if necessary
 #define LLAMA_MAX_LAYERS 512
-#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2
+#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
 
 //
 // helpers
@@ -289,6 +289,8 @@ enum llm_kv {
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_EXPERT_SHARED_COUNT,
     LLM_KV_EXPERT_WEIGHTS_SCALE,
+    LLM_KV_EXPERT_WEIGHTS_NORM,
+    LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -415,6 +417,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
     { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
     { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
+    { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
+    { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
     { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
     { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -560,6 +564,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_EXPERT_WEIGHTS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
@@ -1429,6 +1434,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
             { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, "blk.%d.expert_weights_b" },
         },
     },
     {
@@ -2558,6 +2564,7 @@ enum e_model {
     MODEL_70B,
     MODEL_236B,
     MODEL_314B,
+    MODEL_671B,
     MODEL_SMALL,
     MODEL_MEDIUM,
     MODEL_LARGE,
@@ -2586,6 +2593,19 @@ struct llama_hparams_convnext {
     uint32_t n_layer;
 };
 
+enum llm_expert_gating_func_type {
+    LLM_EXPERT_GATING_FUNC_SOFTMAX = 1,
+    LLM_EXPERT_GATING_FUNC_SIGMOID = 2,
+};
+
+static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) {
+    switch (type) {
+        case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax";
+        case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid";
+        default: return "unknown";
+    }
+}
+
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
@@ -2621,6 +2641,8 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
+    bool     expert_weights_norm = false;
+    uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2912,6 +2934,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b = nullptr; // b3
     struct ggml_tensor * ffn_act = nullptr;
+    struct ggml_tensor * ffn_expert_weights_bias = nullptr;
 
     // mamba proj
     struct ggml_tensor * ssm_in = nullptr;
@@ -5577,6 +5600,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_70B: return "70B";
         case MODEL_236B: return "236B";
         case MODEL_314B: return "314B";
+        case MODEL_671B: return "671B";
         case MODEL_SMALL: return "0.1B";
         case MODEL_MEDIUM: return "0.4B";
         case MODEL_LARGE: return "0.8B";
@@ -6288,11 +6312,14 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
 
                 switch (hparams.n_layer) {
                     case 27: model.type = e_model::MODEL_16B; break;
                     case 60: model.type = e_model::MODEL_236B; break;
+                    case 61: model.type = e_model::MODEL_671B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -6616,6 +6643,10 @@ static void llm_load_vocab(
                 tokenizer_pre == "deepseek-coder") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
                 vocab.tokenizer_clean_spaces = false;
+            } else if (
+                tokenizer_pre == "deepseek-v3") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "falcon") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
@@ -7300,6 +7331,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n", __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n", __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+        LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n", __func__, hparams.expert_weights_norm);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((enum llm_expert_gating_func_type) hparams.expert_gating_func));
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }
 
@@ -7447,6 +7480,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
@@ -9249,6 +9283,7 @@ static bool llm_load_tensors(
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                     } else {
                         layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_expert_weights_bias = create_tensor(tn(LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         if (n_expert == 0) {
                             throw std::runtime_error("n_expert must be > 0");
@@ -10229,12 +10264,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
+         struct ggml_tensor * expert_weights_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
                        bool   norm_w,
                        bool   scale_w,
                      float   w_scale,
+llm_expert_gating_func_type   gating_op,
          const llm_build_cb & cb,
                         int   il) {
     int64_t n_embd = cur->ne[0];
@@ -10243,11 +10280,31 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
-    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-    cb(probs, "ffn_moe_probs", il);
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLM_EXPERT_GATING_FUNC_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+                cb(probs, "ffn_moe_probs", il);
+            } break;
+        case LLM_EXPERT_GATING_FUNC_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+                cb(probs, "ffn_moe_sigm", il);
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+
+    // add experts selection bias - introduced in DeepSeek V3
+    ggml_tensor * selection_probs = probs;
+    if (expert_weights_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, expert_weights_b);
+        cb(selection_probs, "ffn_moe_sigm_biased", il);
+    }
 
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
@@ -11368,9 +11425,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
             }
@@ -12020,9 +12079,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_GELU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -12161,9 +12222,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -13409,9 +13472,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -15403,9 +15468,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -15800,9 +15867,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
 
@@ -15941,9 +16010,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, hparams.expert_weights_scale,
+                        LLM_EXPERT_GATING_FUNC_SOFTMAX,
                         cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -16170,9 +16241,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        model.layers[il].ffn_expert_weights_bias,
                         n_expert, n_expert_used,
-                        LLM_FFN_SILU, false,
+                        LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
+                        (enum llm_expert_gating_func_type) hparams.expert_gating_func,
                         cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
From 93aca64520f907cb1b56ee35e6c485af567e6ecd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Thu, 2 Jan 2025 12:04:58 +0100
Subject: [PATCH 3/9] convert : renamed expert_weights_func to
 expert_gating_func

---
 convert_hf_to_gguf.py       | 4 ++--
 gguf-py/gguf/constants.py   | 4 ++--
 gguf-py/gguf/gguf_writer.py | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 43e61b500315e..bb15707ff6a70 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3859,9 +3859,9 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
 
         if hparams["scoring_func"] == "sigmoid":
-            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SIGMOID)
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
         elif hparams["scoring_func"] == "softmax":
-            self.gguf_writer.add_expert_weights_func(gguf.ExpertWeightsFuncType.SOFTMAX)
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
         else:
             raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 13b77979b1718..1302000ee95d0 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -103,7 +103,7 @@ class LLM:
     EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
     EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
     EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
-    EXPERT_WEIGHTS_FUNC = "{arch}.expert_weights_func"
+    EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
     POOLING_TYPE = "{arch}.pooling_type"
     LOGIT_SCALE = "{arch}.logit_scale"
     DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -1581,7 +1581,7 @@ class GGMLQuantizationType(IntEnum):
     TQ2_0 = 35
 
 
-class ExpertWeightsFuncType(IntEnum):
+class ExpertGatingFuncType(IntEnum):
     SOFTMAX = 1
     SIGMOID = 2
 
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index a0dadeaf8183a..4a0a65e3cc33e 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -26,7 +26,7 @@
     RopeScalingType,
     PoolingType,
     TokenType,
-    ExpertWeightsFuncType,
+    ExpertGatingFuncType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -719,8 +719,8 @@ def add_expert_weights_scale(self, value: float) -> None:
     def add_expert_weights_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
 
-    def add_expert_weights_func(self, value: ExpertWeightsFuncType) -> None:
-        self.add_uint32(Keys.LLM.EXPERT_WEIGHTS_FUNC.format(arch=self.arch), value.value)
+    def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
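
The net effect of patches 2 and 3 on expert routing is easier to see outside of ggml graph code. Below is a minimal NumPy sketch of the selection logic llm_build_moe_ffn now implements for DeepSeek V3: sigmoid gating, a correction bias that affects only which experts are chosen, unbiased probabilities for the actual mixture weights, optional top-k normalization, and the routed scaling factor. Shapes and values are toy (a single token's logit vector); the real code operates on [n_expert, n_tokens] tensors, and the default w_scale of 2.5 is an assumption based on the published model config.

```python
import numpy as np

def deepseek_v3_route(logits, e_score_bias, n_expert_used, norm_w=True, w_scale=2.5):
    # sigmoid gating (LLM_EXPERT_GATING_FUNC_SIGMOID); DeepSeek V2 used softmax
    probs = 1.0 / (1.0 + np.exp(-logits))               # [n_expert]

    # the e_score_correction bias only influences *which* experts are picked...
    selection_scores = probs + e_score_bias
    top_k = np.argsort(-selection_scores)[:n_expert_used]

    # ...while the mixture weights come from the unbiased probs
    weights = probs[top_k]
    if norm_w:                                          # norm_topk_prob / expert_weights_norm
        weights = weights / weights.sum()
    return top_k, w_scale * weights                     # routed_scaling_factor

# toy example: 8 experts, route each token to 2
rng = np.random.default_rng(0)
experts, weights = deepseek_v3_route(rng.normal(size=8), 0.1 * rng.normal(size=8), 2)
print(experts, weights)
```
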
From d2f784d50d3b64ce247a29f7c449bd255fe6e18a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Thu, 2 Jan 2025 21:13:35 +0100
Subject: [PATCH 4/9] convert : correct indentation

---
 convert_hf_to_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bb15707ff6a70..7839875159cd9 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3879,7 +3879,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # rename e_score_correction_bias tensors
         if name.endswith("e_score_correction_bias"):
-             name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
 
         # skip Multi-Token Prediction (MTP) layers
         block_count = self.hparams["num_hidden_layers"]
From 140eb292644f201aadc042392419dea0da236ecc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Fri, 3 Jan 2025 13:51:14 +0100
Subject: [PATCH 5/9] gguf-py, llama : rename expert_weights to exp_probs in
 tensor and variable names

---
 gguf-py/gguf/constants.py      |  6 +++---
 gguf-py/gguf/tensor_mapping.py |  2 +-
 src/llama.cpp                  | 18 +++++++++---------
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 1302000ee95d0..ef795c04e1ca5 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -314,7 +314,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_SHEXP = auto()
     FFN_DOWN_SHEXP = auto()
     FFN_UP_SHEXP = auto()
-    FFN_EXPERT_WEIGHTS_B = auto()
+    FFN_EXP_PROBS_B = auto()
     ATTN_Q_NORM = auto()
     ATTN_K_NORM = auto()
     LAYER_OUT_NORM = auto()
@@ -499,7 +499,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
-    MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: "blk.{bid}.expert_weights_b",
+    MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
@@ -1280,7 +1280,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
-        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
     ],
     MODEL_ARCH.CHATGLM : [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index f48769ad1aae6..efe2a4aa4fe28 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -276,7 +276,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
         ),
 
-        MODEL_TENSOR.FFN_EXPERT_WEIGHTS_B: (
+        MODEL_TENSOR.FFN_EXP_PROBS_B: (
             "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
         ),
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 1c8479c965c9d..1ab930e3e27c8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -564,7 +564,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
-    LLM_TENSOR_FFN_EXPERT_WEIGHTS_B,
+    LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
@@ -1434,7 +1434,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
             { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
-            { LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, "blk.%d.expert_weights_b" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
     {
@@ -2934,7 +2934,7 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_b = nullptr; // b2
     struct ggml_tensor * ffn_up_b = nullptr; // b3
     struct ggml_tensor * ffn_act = nullptr;
-    struct ggml_tensor * ffn_expert_weights_bias = nullptr;
+    struct ggml_tensor * ffn_exp_probs_b = nullptr;
 
     // mamba proj
     struct ggml_tensor * ssm_in = nullptr;
@@ -7480,7 +7480,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
-    {LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
@@ -9283,7 +9283,7 @@ static bool llm_load_tensors(
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                     } else {
                         layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_expert_weights_bias = create_tensor(tn(LLM_TENSOR_FFN_EXPERT_WEIGHTS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         if (n_expert == 0) {
                             throw std::runtime_error("n_expert must be > 0");
@@ -10285,22 +10285,22 @@ llm_expert_gating_func_type gating_op,
         case LLM_EXPERT_GATING_FUNC_SOFTMAX:
             {
                 probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-                cb(probs, "ffn_moe_probs", il);
             } break;
         case LLM_EXPERT_GATING_FUNC_SIGMOID:
             {
                 probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-                cb(probs, "ffn_moe_sigm", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
+    cb(probs, "ffn_moe_probs", il);
 
     // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
     ggml_tensor * selection_probs = probs;
     if (expert_weights_b != nullptr) {
         selection_probs = ggml_add(ctx, probs, expert_weights_b);
-        cb(selection_probs, "ffn_moe_sigm_biased", il);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
     }
 
     // select experts
@@ -16241,7 +16241,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
-                        model.layers[il].ffn_expert_weights_bias,
+                        model.layers[il].ffn_exp_probs_b,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, hparams.expert_weights_norm,
                         true, hparams.expert_weights_scale,
From 5b4673b3dd8e65f74b81538f992395a89180e1f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Fri, 3 Jan 2025 14:57:56 +0100
Subject: [PATCH 6/9] llama : rename expert_weights_b to exp_probs_b

---
 src/llama.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 1ab930e3e27c8..9e1094f8d5aef 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10264,7 +10264,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
-         struct ggml_tensor * expert_weights_b,
+         struct ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
@@ -10298,8 +10298,8 @@ llm_expert_gating_func_type gating_op,
     // add experts selection bias - introduced in DeepSeek V3
     // leave probs unbiased as it's later used to get expert weights
     ggml_tensor * selection_probs = probs;
-    if (expert_weights_b != nullptr) {
-        selection_probs = ggml_add(ctx, probs, expert_weights_b);
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }
 
From dfffe676118b3878d8465602ea5bbada7abd2d34 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Fri, 3 Jan 2025 18:39:44 +0100
Subject: [PATCH 7/9] llama : add support for ACCENT_MARK (\p{M}) and SYMBOL
 (\p{S}) unicode categories in pre-tokenization regex

---
 src/unicode.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/unicode.cpp b/src/unicode.cpp
index 8ed6b1a51c251..7aca6544bc73d 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -667,18 +667,24 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{N}", unicode_cpt_flags::NUMBER },
         { "\\p{L}", unicode_cpt_flags::LETTER },
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
         { unicode_cpt_flags::NUMBER, 0xD1 },
        { unicode_cpt_flags::LETTER, 0xD2 },
         { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL, 0xD5 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
         { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
         { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
         { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex
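
The pre-tokenizer expressions added for LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM in patch 2 depend on the \p{M} (combining mark) and \p{S} (symbol) classes that this patch teaches unicode.cpp to expand. Their intent can be checked against Python's third-party `regex` module, which supports the same property classes; a small sketch with arbitrary sample strings:

```python
import regex  # third-party module with \p{...} property support (pip install regex)

# the three DeepSeek-V3 expressions from PATCH 2
num   = regex.compile(r"\p{N}{1,3}")
cjk   = regex.compile(r"[一-龥぀-ゟ゠-ヿ]+")
mixed = regex.compile(
    r"[!\"#$%&'()*+,\-./:;<=>?@\[\\\]^_`{|}~][A-Za-z]+"
    r"|[^\r\n\p{L}\p{P}\p{S}]?[\p{L}\p{M}]+"
    r"| ?[\p{P}\p{S}]+[\r\n]*"
    r"|\s*[\r\n]+"
    r"|\s+(?!\S)"
    r"|\s+")

print(num.findall("page 12345"))    # ['123', '45'] - digits split into runs of <= 3
print(cjk.findall("abc世界def"))     # ['世界'] - CJK/kana runs grouped together
print(mixed.findall("voilà $100"))  # letter runs keep combining marks (\p{M});
                                    # currency and math signs match via \p{S}
```
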
From a48c3df3df2220fc0df3a7038cdbdd2b9ed4eb3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Sat, 4 Jan 2025 14:36:03 +0100
Subject: [PATCH 8/9] llama : add DeepSeek-V3 chat template

---
 src/llama-chat.cpp | 18 ++++++++++++++++++
 src/llama-chat.h   |  1 +
 2 files changed, 19 insertions(+)

diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index a07e9cf00b942..44670d3d83934 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -45,6 +45,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA },
     { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK },
     { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 },
+    { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 },
     { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R },
     { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 },
     { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 },
@@ -148,6 +149,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_MINICPM;
     } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
         return LLM_CHAT_TEMPLATE_DEEPSEEK_2;
+    } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) {
+        return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
     } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
@@ -453,6 +456,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) {
+        // DeepSeek-V3
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << LU8("<|User|>") << message->content;
+            } else if (role == "assistant") {
+                ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << LU8("<|Assistant|>");
+        }
     } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
         // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
         // EXAONE-3.0-7.8B-Instruct
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 364318c2775db..b8e94d9ef2b3b 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -25,6 +25,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_VICUNA_ORCA,
     LLM_CHAT_TEMPLATE_DEEPSEEK,
     LLM_CHAT_TEMPLATE_DEEPSEEK_2,
+    LLM_CHAT_TEMPLATE_DEEPSEEK_3,
     LLM_CHAT_TEMPLATE_COMMAND_R,
     LLM_CHAT_TEMPLATE_LLAMA_3,
     LLM_CHAT_TEMPLATE_CHATGML_3,
From 4a58b99777d357c1457f6c97d9462bc6aa3e6646 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Sat, 4 Jan 2025 17:28:17 +0100
Subject: [PATCH 9/9] llama : move llama_expert_gating_func_type to
 llama-hparams.h

---
 include/llama.h     | 6 ------
 src/llama-hparams.h | 6 ++++++
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 5d4afe9bf867c..a0d5ba5ddcfd2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -116,12 +116,6 @@ extern "C" {
         LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
     };
 
-    enum llama_expert_gating_func_type {
-        LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
-        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
-        LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
-    };
-
     enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED = 0,
         LLAMA_TOKEN_TYPE_NORMAL = 1,
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 5a5cf025877af..a29f20ec49665 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -8,6 +8,12 @@
 #define LLAMA_MAX_LAYERS 512
 #define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
 
+enum llama_expert_gating_func_type {
+    LLAMA_EXPERT_GATING_FUNC_TYPE_NONE    = 0,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
+    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
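
Rendered outside of llama.cpp, the template branch added in patch 8 produces prompts like the following. A Python sketch mirroring the C++ code for LLM_CHAT_TEMPLATE_DEEPSEEK_3 (illustrative only; the special tokens come from the model's tokenizer):

```python
def deepseek3_prompt(messages, add_assistant=True):
    # mirrors the LLM_CHAT_TEMPLATE_DEEPSEEK_3 branch in llm_chat_apply_template
    out = []
    for m in messages:
        if m["role"] == "system":
            out.append(m["content"] + "\n\n")
        elif m["role"] == "user":
            out.append("<|User|>" + m["content"])
        elif m["role"] == "assistant":
            out.append("<|Assistant|>" + m["content"] + "<|end▁of▁sentence|>")
    if add_assistant:
        out.append("<|Assistant|>")  # the add_ass generation prompt
    return "".join(out)

print(deepseek3_prompt([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]))
# You are a helpful assistant.
#
# <|User|>Hello<|Assistant|>
```
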