
Commit 443e7e7

Merge branch 'mamba2-sync' into GraniteFour
* mamba2-sync: (24 commits)
  sync : ggml
  Add `ggml_roll` (ggml/1274)
  docs : fix the link to llama.h (ggml-org#14293)
  CUDA: add conv_2d_transpose (ggml-org#14287)
  lint : remove trailing whitepace (ggml-org#14304)
  vocab : prevent tokenizer overflow (ggml-org#14301)
  sycl: add usage of enqueue_functions extension (ggml-org#14244)
  Implement GGML_CPU_ALL_VARIANTS for PowerPC (ggml-org#14286)
  llama : improve sep token handling (ggml-org#14272)
  cuda : synchronize graph capture and cublas handle destruction (ggml-org#14288)
  ggml : fix repack work size for mul_mat_id (ggml-org#14292)
  ggml: Update KleidiAI to v1.9.0 (ggml-org#14277)
  model : more uniform output id handling (ggml-org#14275)
  ubatch : new splitting logic (ggml-org#14217)
  CUDA: add conv_2d_dw (ggml-org#14265)
  ggml-cpu : remove unnecesary arm feature detection (ggml-org#14281)
  gguf-py : make sentencepiece optional (ggml-org#14200)
  server : add server parameters for draft model cache type (ggml-org#13782)
  build : suppress gcc15 compile warnings (ggml-org#14261)
  sycl: Cleanup codepaths in Get Rows in sycl backend (ggml-org#14215)
  ...
2 parents 8f3af99 + b605bb9 · commit 443e7e7


82 files changed: +4212 −3693 lines

ci/run.sh

Lines changed: 1 addition & 1 deletion
@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
     model_f16="${path_models}/ggml-model-f16.gguf"

     # for this model, the SEP token is "</s>"
-    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log

     # sample output
     # rerank score 0: 0.029

common/arg.cpp

Lines changed: 33 additions & 0 deletions
@@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.embd_sep = value;
        }
    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    add_opt(common_arg(
+        {"--cls-separator"}, "STRING",
+        "separator of classification sequences (default \\t) for example \"<#seq#>\"",
+        [](common_params & params, const std::string & value) {
+            params.cls_sep = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
    add_opt(common_arg(
        {"--host"}, "HOST",
        string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
@@ -3210,6 +3217,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.speculative.model.path = value;
        }
    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+    add_opt(common_arg(
+        {"-ctkd", "--cache-type-k-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for K for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_k)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_k = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-ctvd", "--cache-type-v-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for V for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_v)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_v = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));

    add_opt(common_arg(
        {"-mv", "--model-vocoder"}, "FNAME",

common/common.cpp

Lines changed: 9 additions & 0 deletions
@@ -706,11 +706,17 @@ bool fs_validate_filename(const std::string & filename) {
    // disable C++17 deprecation warning for std::codecvt_utf8
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
+
    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;

#if defined(__clang__)
#    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
#endif

    filename_utf32 = converter.from_bytes(filename);
@@ -1284,6 +1290,9 @@ std::vector<llama_token> common_tokenize(
    int n_tokens = text.length() + 2 * add_special;
    std::vector<llama_token> result(n_tokens);
    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens == std::numeric_limits<int32_t>::min()) {
+        throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
+    }
    if (n_tokens < 0) {
        result.resize(-n_tokens);
        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
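The guard leans on the existing two-pass convention: `llama_tokenize` returns the negated required token count when the caller's buffer is too small, so `INT32_MIN` is reserved to signal that the count itself overflowed. A minimal caller-side sketch of handling the new exception (variable names are illustrative):

    // Illustrative caller-side handling of the new overflow exception:
    try {
        std::vector<llama_token> toks = common_tokenize(ctx, text, /*add_special=*/true, /*parse_special=*/true);
        // ... use toks ...
    } catch (const std::runtime_error & err) {
        LOG_ERR("tokenization failed: %s\n", err.what());
    }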

common/common.h

Lines changed: 4 additions & 0 deletions
@@ -199,6 +199,9 @@ struct common_params_speculative {
    float p_split = 0.1f;  // speculative decoding split probability
    float p_min   = 0.75f; // minimum speculative decoding probability (greedy)

+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+
    struct cpu_params cpuparams;
    struct cpu_params cpuparams_batch;

@@ -355,6 +358,7 @@ struct common_params {
    int32_t embd_normalize = 2;    // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
    std::string embd_out   = "";   // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
    std::string embd_sep   = "\n"; // separator of embeddings
+    std::string cls_sep    = "\t"; // separator of classification sequences

    // server params
    int32_t port = 8080; // server listens on this network port

convert_hf_to_gguf.py

Lines changed: 11 additions & 23 deletions
@@ -2145,7 +2145,6 @@ def __init__(self, *args, **kwargs):

    def set_vocab(self):
        self._set_vocab_gpt2()
-        self.gguf_writer.add_add_bos_token(True)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
@@ -3918,9 +3917,6 @@ def _xlmroberta_set_vocab(self) -> None:
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
-

@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ def set_vocab(self):
        bpe_tok_path = self.dir_model / "tokenizer.json"
        if bpe_tok_path.exists():
            self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)

            # we need this to validate the size of the token_type embeddings
            # though currently we are passing all zeros to the token_type embeddings
@@ -5056,8 +5050,6 @@ def set_vocab(self):
            self.gguf_writer.add_token_type_count(2)
        else:
            raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)


@ModelBase.register("OpenELMForCausalLM")
@@ -5659,9 +5651,6 @@ def set_vocab(self):
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
    def set_gguf_parameters(self):
        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5799,9 +5788,6 @@ def set_vocab(self):
        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)

-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
    def set_gguf_parameters(self):
        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -6630,8 +6616,8 @@ def parse_args() -> argparse.Namespace:
        help="model is executed on big endian machine",
    )
    parser.add_argument(
-        "model", type=Path,
-        help="directory containing model file",
+        "model", type=str,
+        help="directory containing model file or huggingface repository ID (if --remote)",
        nargs="?",
    )
    parser.add_argument(
@@ -6742,18 +6728,20 @@ def main() -> None:
    else:
        logging.basicConfig(level=logging.INFO)

-    dir_model = args.model
-
    if args.remote:
+        hf_repo_id = args.model
        from huggingface_hub import snapshot_download
        local_dir = snapshot_download(
-            repo_id=str(dir_model),
+            repo_id=hf_repo_id,
            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
        dir_model = Path(local_dir)
        logger.info(f"Downloaded config and tokenizer to {local_dir}")
+    else:
+        hf_repo_id = None
+        dir_model = Path(args.model)

    if not dir_model.is_dir():
-        logger.error(f'Error: {args.model} is not a directory')
+        logger.error(f'Error: {dir_model} is not a directory')
        sys.exit(1)

    ftype_map: dict[str, gguf.LlamaFileType] = {
@@ -6773,9 +6761,9 @@ def main() -> None:

    if args.outfile is not None:
        fname_out = args.outfile
-    elif args.remote:
+    elif hf_repo_id:
        # if remote, use the model ID as the output file name
-        fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
+        fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
    else:
        fname_out = dir_model
@@ -6804,7 +6792,7 @@ def main() -> None:
        split_max_tensors=args.split_max_tensors,
        split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
        small_first_shard=args.no_tensor_first_split,
-        remote_hf_model_id=str(args.model) if args.remote else None)
+        remote_hf_model_id=hf_repo_id)

    if args.vocab_only:
        logger.info("Exporting model vocab...")

docs/build.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
# Build llama.cpp locally

-The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
+The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).

The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
66

examples/embedding/embedding.cpp

Lines changed: 30 additions & 4 deletions
@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
    // max batch size
    const uint64_t n_batch = params.n_batch;

+    // get added sep and eos token, if any
+    const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
+    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+
    // tokenize the prompts and trim
    std::vector<std::vector<int32_t>> inputs;
    for (const auto & prompt : prompts) {
-        auto inp = common_tokenize(ctx, prompt, true, true);
+        std::vector<llama_token> inp;
+
+        // split classification pairs and insert expected separator tokens
+        if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
+            std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            std::string final_prompt;
+
+            for (size_t i = 0; i < pairs.size(); i++) {
+                final_prompt += pairs[i];
+                if (i != pairs.size() - 1) {
+                    if (!added_eos_token.empty()) {
+                        final_prompt += added_eos_token;
+                    }
+                    if (!added_sep_token.empty()) {
+                        final_prompt += added_sep_token;
+                    }
+                }
+            }
+
+            inp = common_tokenize(ctx, final_prompt, true, true);
+        } else {
+            inp = common_tokenize(ctx, prompt, true, true);
+        }
        if (inp.size() > n_batch) {
            LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                    __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
        inputs.push_back(inp);
    }

-    // check if the last token is SEP
+    // check if the last token is SEP/EOS
    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
    for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
-            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+        if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
+            LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
            LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
        }
    }
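This is the counterpart to the ci/run.sh change above: for rerank (RANK pooling) inputs, the prompt is split on `params.cls_sep` (default `\t`, settable with the new `--cls-separator` flag) and the model's own added EOS/SEP token strings are re-inserted between the pieces. A small illustration, under the assumption that both added token strings are `</s>` for the rerank model used in CI:

    // Illustration only (assumes added_eos_token == "</s>" and added_sep_token == "</s>"):
    std::string query = "what is panda?";
    std::string doc   = "it's a bear";
    std::string final_prompt = query + "</s>" + "</s>" + doc;
    // -> "what is panda?</s></s>it's a bear", i.e. the separator string that ci/run.sh
    //    previously hard-coded, now derived from the vocab instead of the prompt text.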

ggml/include/ggml.h

Lines changed: 12 additions & 0 deletions
@@ -489,6 +489,7 @@ extern "C" {
        GGML_OP_UPSCALE, // nearest interpolate
        GGML_OP_PAD,
        GGML_OP_PAD_REFLECT_1D,
+        GGML_OP_ROLL,
        GGML_OP_ARANGE,
        GGML_OP_TIMESTEP_EMBEDDING,
        GGML_OP_ARGSORT,
@@ -1801,6 +1802,17 @@ extern "C" {
            int                   p0,
            int                   p1);

+    // Move tensor elements by an offset given for each dimension. Elements that
+    // are shifted beyond the last position are wrapped around to the beginning.
+    GGML_API struct ggml_tensor * ggml_roll(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   shift0,
+            int                   shift1,
+            int                   shift2,
+            int                   shift3);
+
    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
    // timesteps: [N,]
    // return: [N, dim]
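A minimal usage sketch for the new operator, assuming a standard CPU-side ggml context and graph; this is illustrative and not taken from this commit:

    // Hedged sketch: rotate a 1-D tensor by 2 positions along dim 0; elements
    // shifted past the end wrap around to the beginning, per the comment above.
    struct ggml_init_params ip = { /*mem_size=*/ 16*1024*1024, /*mem_buffer=*/ NULL, /*no_alloc=*/ false };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor  * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor  * r   = ggml_roll(ctx, a, 2, 0, 0, 0);
    struct ggml_cgraph  * gf  = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, r);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);
    ggml_free(ctx);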

ggml/src/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -286,6 +286,10 @@ function(ggml_add_cpu_backend_variant tag_name)
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+        foreach (feat ${ARGN})
+            set(GGML_INTERNAL_${feat} ON)
+        endforeach()
    endif()

    ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -337,6 +341,19 @@ if (GGML_CPU_ALL_VARIANTS)
        else()
            message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
+    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
+        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+            ggml_add_cpu_backend_variant(power0)
+            ggml_add_cpu_backend_variant(power7_1 POWER7)
+            ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
+            ggml_add_cpu_backend_variant(power8_1 POWER8)
+            ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
+            ggml_add_cpu_backend_variant(power9 POWER9 VSX)
+            ggml_add_cpu_backend_variant(power10 POWER10 VSX)
+            ggml_add_cpu_backend_variant(power11 POWER11 VSX)
+        else()
+            message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
+        endif()
    else()
        message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
    endif()

ggml/src/ggml-backend-reg.cpp

Lines changed: 5 additions & 0 deletions
@@ -69,6 +69,9 @@
#if defined(__clang__)
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif

namespace fs = std::filesystem;
@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {

#if defined(__clang__)
#    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
#endif

#ifdef _WIN32
