Commit 8c701d7

Merge commit '72b090da2c50e540143fd312a2f9aa5f151e6136' into concedo_experimental
# Conflicts:
# docs/backend/CANN.md
# docs/function-calling.md
# examples/embedding/embedding.cpp
# examples/retrieval/retrieval.cpp
# ggml/src/ggml-cann/CMakeLists.txt
# ggml/src/ggml-cann/Doxyfile
# ggml/src/ggml-cann/acl_tensor.cpp
# ggml/src/ggml-cann/acl_tensor.h
# ggml/src/ggml-cann/aclnn_ops.cpp
# ggml/src/ggml-cann/aclnn_ops.h
# ggml/src/ggml-cann/common.h
# ggml/src/ggml-cann/ggml-cann.cpp
# ggml/src/ggml-cpu/CMakeLists.txt
# ggml/src/ggml-sycl/binbcast.cpp
# ggml/src/ggml-sycl/common.hpp
# ggml/src/ggml-sycl/concat.cpp
# ggml/src/ggml-sycl/conv.cpp
# ggml/src/ggml-sycl/cpy.cpp
# ggml/src/ggml-sycl/dmmv.cpp
# ggml/src/ggml-sycl/element_wise.cpp
# ggml/src/ggml-sycl/getrows.cpp
# ggml/src/ggml-sycl/ggml-sycl.cpp
# ggml/src/ggml-sycl/gla.cpp
# ggml/src/ggml-sycl/mmvq.cpp
# ggml/src/ggml-sycl/norm.cpp
# ggml/src/ggml-sycl/outprod.cpp
# ggml/src/ggml-sycl/rope.cpp
# ggml/src/ggml-sycl/softmax.cpp
# ggml/src/ggml-sycl/tsembd.cpp
# ggml/src/ggml-sycl/wkv.cpp
# scripts/compare-commits.sh
# tests/test-chat.cpp
# tests/test-sampling.cpp
2 parents 868cb6a + 72b090d commit 8c701d7

20 files changed: +366 −207 lines changed


common/arg.cpp

Lines changed: 113 additions & 93 deletions
Large diffs are not rendered by default.

common/chat-parser.cpp

Lines changed: 7 additions & 4 deletions
@@ -170,20 +170,23 @@ std::string common_chat_msg_parser::consume_rest() {
 }
 
 // Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
-std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
+std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) {
     auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
     if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
         return std::nullopt;
     }
+    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
+    pos_ = m.groups[0].end;
+
+    if (add_prelude_to_content) {
+        add_content(prelude);
+    }
     if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
         if (is_partial()) {
             throw common_chat_msg_partial_exception(regex.str());
         }
         return std::nullopt;
     }
-    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
-    pos_ = m.groups[0].end;
-
     return find_regex_result{prelude, m.groups};
 }

common/chat-parser.h

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ class common_chat_msg_parser {
     const std::string & healing_marker() const { return healing_marker_; }
     const bool & is_partial() const { return is_partial_; }
     const common_chat_msg & result() const { return result_; }
+    const common_chat_syntax & syntax() const { return syntax_; }
 
     void move_to(size_t pos) {
         if (pos > input_.size()) {
@@ -77,7 +78,7 @@ class common_chat_msg_parser {
         std::vector<common_string_range> groups;
     };
 
-    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
+    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
 
     bool try_consume_literal(const std::string & literal);

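With this change, try_find_regex consumes the match and, unless add_prelude_to_content is false, appends the prelude (the text before the match) to the message content itself, which is why the explicit builder.add_content(res->prelude) calls disappear from the parsers in chat.cpp below. A minimal sketch of a caller under the new default; the <my_tool_call> tag and the parser function name are invented for illustration, and only builder methods that appear elsewhere in this diff are assumed:

```cpp
// Hypothetical format parser, sketched against the new try_find_regex default.
// The prelude before the matched tag is added to the content inside
// try_find_regex, so the caller no longer calls builder.add_content(res->prelude).
static void example_parse_my_format(common_chat_msg_parser & builder) {
    static const common_regex tool_call_open(regex_escape("<my_tool_call>")); // invented tag
    if (builder.try_find_regex(tool_call_open)) {
        // any text before <my_tool_call> is already part of the content at this point
        auto tool_calls = builder.consume_json_with_dumped_args({{"arguments"}});
        if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
            throw common_chat_msg_partial_exception("incomplete tool call");
        }
    } else {
        builder.add_content(builder.consume_rest());
    }
}
```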
common/chat.cpp

Lines changed: 57 additions & 38 deletions
@@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur
         return current;
     }
     if (!string_starts_with(current, last)) {
+        if (string_starts_with(last, current)) {
+            // This happens if the last generation ended on a partial stop word (not erased),
+            // and the current ended on a stop word (erased).
+            return "";
+        }
         throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
     }
     return current.substr(last.size());
@@ -101,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
         if (!args_diff.empty() || pref.id != newf.id) {
             auto & diff = diffs.emplace_back();
             diff.tool_call_index = idx;
-            diff.tool_call_delta.name = newf.name;
             if (pref.id != newf.id) {
                 diff.tool_call_delta.id = newf.id;
+                diff.tool_call_delta.name = newf.name;
             }
             diff.tool_call_delta.arguments = args_diff;
         }
@@ -387,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
         delta["content"] = diff.content_delta;
     }
     if (diff.tool_call_index != std::string::npos) {
+        json tool_call;
+        tool_call["index"] = diff.tool_call_index;
+        if (!diff.tool_call_delta.id.empty()) {
+            tool_call["id"] = diff.tool_call_delta.id;
+            tool_call["type"] = "function";
+        }
         json function = json::object();
         if (!diff.tool_call_delta.name.empty()) {
             function["name"] = diff.tool_call_delta.name;
         }
-        if (!diff.tool_call_delta.id.empty()) {
-            function["id"] = diff.tool_call_delta.id;
-        }
-        if (!diff.tool_call_delta.arguments.empty()) {
-            function["arguments"] = diff.tool_call_delta.arguments;
-        }
-        delta["tool_calls"] = json::array({
-            json {
-                {"index", diff.tool_call_index},
-                {"function", function}
-            }
-        });
+        function["arguments"] = diff.tool_call_delta.arguments;
+        tool_call["function"] = function;
+        delta["tool_calls"] = json::array({tool_call});
     }
     return delta;
 }
@@ -654,7 +656,6 @@ static void parse_json_tool_calls(
                 }
                 from = std::string::npos;
 
-                builder.add_content(res->prelude);
                 auto maybe_raw_python = name == "python" && allow_raw_python;
                 if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
                     if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
@@ -684,7 +685,6 @@ static void parse_json_tool_calls(
     };
     if (block_open) {
         if (auto res = builder.try_find_regex(*block_open)) {
-            builder.add_content(res->prelude);
             parse_tool_calls();
         } else {
             builder.add_content(builder.consume_rest());
@@ -697,7 +697,6 @@
 static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) {
     static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
     if (auto res = builder.try_find_regex(prefix)) {
-        builder.add_content(res->prelude);
         builder.move_back(rstrip_prefix);
         auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
         if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
@@ -833,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
     return data;
 }
 static void common_chat_parse_generic(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
     static const std::vector<std::vector<std::string>> content_paths = {
         {"response"},
     };
@@ -905,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     return data;
 }
 static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
     static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
     parse_prefixed_json_tool_call_array(builder, prefix);
 }
@@ -999,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
 
     if (auto res = builder.try_find_regex(start_action_regex)) {
         // If we didn't extract thoughts, prelude includes them.
-        builder.add_content(res->prelude);
         auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
         for (const auto & tool_call : tool_calls.value) {
             std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
@@ -1014,11 +1021,7 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
         }
         builder.consume_regex(end_action_regex);
     } else if (auto res = builder.try_find_regex(start_response_regex)) {
-        // If we didn't extract thoughts, prelude includes them.
-        builder.add_content(res->prelude);
-        if (auto res = builder.try_find_regex(end_response_regex)) {
-            builder.add_content(res->prelude);
-        } else {
+        if (!builder.try_find_regex(end_response_regex)) {
             builder.add_content(builder.consume_rest());
             throw common_chat_msg_partial_exception(end_response_regex.str());
         }
@@ -1126,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
     return data;
 }
 static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
     static const common_regex function_regex(
         "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
     static const common_regex close_regex("\\}\\s*");
@@ -1136,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w
     if (with_builtin_tools) {
         static const common_regex builtin_call_regex("<\\|python_tag\\|>");
         if (auto res = builder.try_find_regex(builtin_call_regex)) {
-            builder.add_content(res->prelude);
-
             auto fun_res = builder.consume_regex(function_name_regex);
             auto function_name = builder.str(fun_res.groups[1]);
 
@@ -1253,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
 }
 static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
 
     static const common_regex tool_calls_begin("(?:<｜tool▁calls▁begin｜>|<｜tool_calls_begin｜>|<｜tool calls begin｜>|<｜tool\\\\_calls\\\\_begin｜>|<｜tool▁calls｜>)");
     static const common_regex tool_calls_end("<｜tool▁calls▁end｜>");
@@ -1314,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     return data;
 }
 static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
     static const common_regex prefix(regex_escape(" functools["));
     parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
 }
@@ -1455,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
     return data;
 }
 static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
-    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
-    static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
-
-    if (auto res = builder.try_find_regex(python_tag_regex)) {
-        builder.add_content(res->prelude);
-        auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
-        builder.add_tool_call("python", "", arguments);
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
         return;
     }
+    // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
+    static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
 
     static const common_regex function_regex(R"(<function=(\w+)>)");
     static const common_regex close_regex(R"(</function>)");
@@ -1475,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser
         function_regex,
         close_regex,
         std::nullopt);
+
+    if (auto res = builder.try_find_regex(python_tag_regex)) {
+        auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
+        builder.add_tool_call("python", "", arguments);
+        return;
+    }
 }
 
 static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1593,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
 }
 static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
 
     static const common_regex open_regex(
         "(?:"
@@ -1614,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
     );
 
     if (auto res = builder.try_find_regex(open_regex)) {
-        builder.add_content(res->prelude);
-
         const auto & block_start = res->groups[1];
         std::string block_end = block_start.empty() ? "" : "```";
 
@@ -1851,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
-static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) {
-    LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str());
+static void common_chat_parse(common_chat_msg_parser & builder) {
+    LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
 
-    switch (format) {
+    switch (builder.syntax().format) {
         case COMMON_CHAT_FORMAT_CONTENT_ONLY:
             common_chat_parse_content_only(builder);
             break;
@@ -1889,15 +1908,15 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form
             common_chat_parse_command_r7b(builder);
             break;
         default:
-            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format));
+            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
     builder.finish();
 }
 
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
     common_chat_msg_parser builder(input, is_partial, syntax);
     try {
-        common_chat_parse(builder, syntax.format);
+        common_chat_parse(builder);
     } catch (const common_chat_msg_partial_exception & ex) {
         LOG_DBG("Partial parse: %s\n", ex.what());
         if (!is_partial) {

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ struct common_chat_syntax {
     // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
     bool reasoning_in_content = false;
     bool thinking_forced_open = false;
+    bool parse_tool_calls = true;
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

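The new parse_tool_calls flag lets a caller ask the format parsers to leave the whole model output in content instead of extracting tool calls (see the early-return blocks added throughout chat.cpp above). A rough usage sketch, not taken from this commit; the format member and the COMMON_CHAT_FORMAT_DEEPSEEK_R1 enum value are assumed to exist in chat.h:

```cpp
#include "chat.h" // common_chat_syntax, common_chat_parse, common_chat_msg

// Sketch: parse a finished generation as plain content, skipping tool-call extraction.
// COMMON_CHAT_FORMAT_DEEPSEEK_R1 is an assumed example value of common_chat_format.
static common_chat_msg parse_as_plain_content(const std::string & model_output) {
    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_DEEPSEEK_R1; // assumed format value
    syntax.parse_tool_calls = false;                          // new field from this commit
    // reasoning_in_content and thinking_forced_open keep their defaults
    return common_chat_parse(model_output, /* is_partial= */ false, syntax);
}
```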
common/common.h

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@ struct common_params {
     int32_t verbosity = 0;
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1; // layer range for control vector
+    bool offline = false;
 
     int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line

examples/training/README.md

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ Proof of concept:
 
 ``` sh
 export model_name=llama_3.2-1b && export quantization=f32
-./build/bin/finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
-./build/bin/perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
+./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
+./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
 ```
 
 The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.

ggml/src/ggml-backend.cpp

Lines changed: 3 additions & 0 deletions
@@ -1604,6 +1604,9 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
     for (int i = 0; i < sched->n_backends; i++) {
         ggml_backend_synchronize(sched->backends[i]);
     }
+    // reset the current copy to 0 so that the graphs will be similar during generation
+    // necessary for CUDA graphs
+    sched->cur_copy = 0;
 }
 
 void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
-#if !defined(GGML_USE_HIP)
+#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
 static const char * cu_get_error_str(CUresult err) {
     const char * err_str;
     cuGetErrorString(err, &err_str);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 0 deletions
@@ -6476,6 +6476,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
     case GGML_OP_CONV_2D_DW:
+    case GGML_OP_IM2COL:
        return true;
    default:
        return false;
