
Commit b0b280e

Merge branch 'master' into compilade/refactor-kv-cache
2 parents: f716358 + 6efcd65

32 files changed: +999 -219 lines

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ jobs:
        cd build
        export GGML_VK_VISIBLE_DEVICES=0
        # This is using llvmpipe and runs slower than other backends
-        ctest -L main --verbose --timeout 3600
+        ctest -L main --verbose --timeout 4200

  ubuntu-22-cmake-hip:
    runs-on: ubuntu-22.04

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2734,6 +2734,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.public_path = value;
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH"));
+    add_opt(common_arg(
+        {"--api-prefix"}, "PREFIX",
+        string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()),
+        [](common_params & params, const std::string & value) {
+            params.api_prefix = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
    add_opt(common_arg(
        {"--no-webui"},
        string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"),

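A note on the `--api-prefix` option added above: it can also be set through the `LLAMA_ARG_API_PREFIX` environment variable, and it makes llama-server serve its routes from under the given path. Below is a minimal client-side sketch, not part of this commit: the `/llm` prefix, port 8080, and the use of the `requests` library are illustrative assumptions, and it presumes the server's usual `/health` and `/v1/chat/completions` routes are exposed under the prefix.

```python
# Hypothetical usage sketch, assuming the server was started as:
#   llama-server -m model.gguf --port 8080 --api-prefix /llm
import requests

BASE = "http://localhost:8080/llm"  # host:port plus the --api-prefix value, no trailing slash

# Health probe, expected under the prefix.
print(requests.get(f"{BASE}/health").status_code)

# OpenAI-compatible chat completion, also expected under the prefix.
resp = requests.post(
    f"{BASE}/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 16},
)
print(resp.json())
```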
common/common.h

Lines changed: 1 addition & 0 deletions
@@ -370,6 +370,7 @@ struct common_params {

    std::string hostname      = "127.0.0.1";
    std::string public_path   = ""; // NOLINT
+    std::string api_prefix    = ""; // NOLINT
    std::string chat_template = ""; // NOLINT
    bool use_jinja = false; // NOLINT
    bool enable_chat_template = true;

convert_hf_to_gguf.py

Lines changed: 157 additions & 0 deletions
@@ -815,6 +815,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
        if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
            # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
            res = "minerva-7b"
+        if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
+            # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
+            res = "hunyuan"

        if res is None:
            logger.warning("\n")
@@ -6652,6 +6655,160 @@ def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])

+
+@ModelBase.register("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # For handling tied embeddings
+        self._tok_embd = None
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        # 1. Get the pre-tokenizer identifier hash
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        # 2. Reverse-engineer the merges list from mergeable_ranks
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            if len(merged) == 2:  # todo this is an assert in Qwen, why?
+                merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # 3. Generate the tokens and toktypes lists
+        vocab_size = self.hparams["vocab_size"]
+        assert tokenizer.vocab_size == vocab_size
+        special_tokens = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+        tokens: list[str] = []
+        toktypes: list[int] = []
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token = reverse_vocab[i]
+                tokens.append(token)
+                if i in special_tokens.values():
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+
+        # 4. Write all vocab-related fields to the GGUF writer
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_token_merges(merges)
+
+        # 5. Add special tokens and chat templates
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        self.gguf_writer.add_bos_token_id(127959)  # <|bos|>
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_expert_count(hparams["num_experts"])
+        self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
+
+        moe_intermediate_size = hparams["moe_intermediate_size"]
+        assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
+
+        moe_topk = hparams["moe_topk"]
+        assert all(topk == moe_topk[0] for topk in moe_topk)
+        self.gguf_writer.add_expert_used_count(moe_topk[0])
+
+        moe_shared_expert = hparams["num_shared_expert"]
+        assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+        self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 1000)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = (hparams["hidden_size"] // hparams["num_attention_heads"])  # 128
+            scaled_base = base * (alpha ** (dim / (dim - 2)))  # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)  # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024)  # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "model.embed_tokens.weight":
+            self._tok_embd = data_torch.clone()
+
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                # merge the experts into a single 3d tensor
+                tensors: list[tuple[str, Tensor]] = []
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
+@ModelBase.register("SmolLM3ForCausalLM")
+class SmolLM3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.SMOLLM3
+
###### CONVERSION LOGIC ######

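As a standalone sanity check of the NTK-aware alpha scaling in `HunYuanMoEModel.set_gguf_parameters()` above, the snippet below recomputes the RoPE frequency base from the values the converter asserts (alpha = 1000, base = 10000, head dimension 128); it reproduces the 11158839.9251 noted in the code comment.

```python
# Recompute the scaled RoPE base written for HunYuan-A13B above.
# The inputs mirror the values asserted in set_gguf_parameters().
alpha = 1000    # rope_scaling["alpha"]
base = 10000.0  # rope_theta
dim = 128       # hidden_size // num_attention_heads

scaled_base = base * (alpha ** (dim / (dim - 2)))
print(f"{scaled_base:.4f}")  # 11158839.9251, the value passed to add_rope_freq_base()
```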
convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ class TOKENIZER_TYPE(IntEnum):
    {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
    {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+    {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
]

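The `chkhsh` values registered here are the same fingerprints matched in `get_vocab_base_pre()` in `convert_hf_to_gguf.py`. Below is a rough sketch of how such a fingerprint could be reproduced, under the assumption that it is a SHA-256 over the stringified token ids of the fixed probe text used by the update script; the probe text in the snippet is a placeholder, not the real string.

```python
# Hypothetical sketch of computing a tokenizer fingerprint like the chkhsh values above.
# Assumption: the hash is SHA-256 over str() of the token ids of a fixed probe text.
from hashlib import sha256
from transformers import AutoTokenizer

probe_text = "..."  # placeholder; the real probe text lives in convert_hf_to_gguf_update.py

tok = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
ids = tok.encode(probe_text)
print(sha256(str(ids).encode()).hexdigest())  # compare against the registered chkhsh
```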
docs/development/HOWTO-add-model.md

Lines changed: 11 additions & 9 deletions
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv

### 2. Define the model architecture in `llama.cpp`

-The model params and tensors layout must be defined in `llama.cpp`:
-1. Define a new `llm_arch`
-2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non-standard metadata in `llm_load_hparams`
-4. Create the tensors for inference in `llm_load_tensors`
-5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
+The model params and tensors layout must be defined in `llama.cpp` source files:
+1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
+2. In `src/llama-arch.cpp`:
+    - Add the architecture name to the `LLM_ARCH_NAMES` map.
+    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
+4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.

NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.

### 3. Build the GGML graph implementation

-This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
-
-Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
+This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
+Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
+Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
+Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.

Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -495,7 +495,7 @@ extern "C" {
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,
        GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,
        GGML_OP_PAD,
        GGML_OP_PAD_REFLECT_1D,
        GGML_OP_ROLL,

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 10 deletions
@@ -176,17 +176,20 @@ static const char * cu_get_error_str(CUresult err) {
#endif

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
-#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
-    do { \
-        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
-        const int id = ggml_cuda_get_device(); \
-        if (!shared_memory_limit_raised[id]) { \
-            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
-            shared_memory_limit_raised[id] = true; \
-        } \
-    } while (0)
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
+        const int id = ggml_cuda_get_device(); \
+        if (!shared_memory_limit_raised[id]) { \
+            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+            shared_memory_limit_raised[id] = true; \
+        } \
+    } while (0)
#else
-#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+    do { \
+        GGML_UNUSED(nbytes); \
+    } while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)

#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 8 additions & 8 deletions
@@ -299,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32(
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 9 additions & 7 deletions
@@ -337,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32(
    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
-    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
-    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
-    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
    NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
