Commit 7e6f733

Latest commits, manual biasing
* added `logit_bias_strings_manual` for real-time shifting of unwanted tokens
* minor fixes
1 parent a247762 commit 7e6f733

19 files changed (+1066, -465 lines)

Makefile

Lines changed: 3 additions & 3 deletions
@@ -424,8 +424,8 @@ HEADERS_GGUF_BASE = \
 	$(ggmlsrc_f_s)/ggml-threading.h \
 	$(ggmlsrc_cpu_f)/hbm.h \
 	$(ggmlsrc_cpu_f)/ggml-cpu-impl.h \
-	$(ggmlsrc_cpu_f)/ggml-cpu-quants.h \
-	$(ggmlsrc_cpu_f)/ggml-cpu-traits.h \
+	$(ggmlsrc_cpu_f)/quants.h \
+	$(ggmlsrc_cpu_f)/traits.h \
 	$(ggmlsrc_cpu_f)/common.h \
 	$(ggmlsrc_cpu_f)/binary-ops.h \
 	$(ggmlsrc_cpu_f)/unary-ops.h \
@@ -751,7 +751,7 @@ ui_simple = $(uibackend_f)/UI_simple.h
 endif

 # Final parts
-$(TMP)$(PREFIX)_class_chat.o:$(conapp) $(COMMON_H_DEPS) $(json_layer) $(chat_layer) $(settings_layer) $(OBJS_GGUF)
+$(TMP)$(PREFIX)_class_chat.o:$(conapp) $(HEADERS_GGUF_BASE) $(COMMON_H_DEPS) $(json_layer) $(chat_layer) $(settings_layer) $(OBJS_GGUF)
 	@echo ------------------------------------------------------------------------
 	$(CXX) $(I_GGUF) $(CXXFLAGS) $(LDFLAGS) -c $< -o $@
 	@echo ---------------CHAT COMPILED with: $(PREFIX)

base_sampling2/chat_layer.h

Lines changed: 126 additions & 15 deletions
@@ -291,9 +291,12 @@ class chat
     std::string logit_bias_strings_display = "";
     std::string logit_bias_strings_ext_display = "";
     std::string logit_bias_strings_start_display = "";
+    std::string logit_bias_strings_manual_display = "";

     std::string last_candidates_logits_display = "";

+    std::string dry_sequence_breakers_display = "";
+
     struct llama_perf_context_data ctx_performance_data;

     //std::map<std::string,std::string> stats;
@@ -601,6 +604,57 @@ class chat
         // std::getline(std::cin, pause);
     }

+    void sparams_postfill2() {
+        // std::string space = " ";
+        if (params.sparams.logit_bias_strings_manual.size()) {
+            for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+                std::string token_str = common_token_to_piece(ctx, i);
+                // cutting spaces since there are "duplicated" tokens with them
+                if (token_str.front() == ' ') {
+                    token_str = token_str.substr(1);
+                }
+
+                // almost never happens
+                if (token_str.back() == ' ') {
+                    token_str.pop_back();
+                }
+
+                bool restricted = false;
+                float bias = -INFINITY;
+
+                if (token_str.length() > 2) {
+                    for (auto word : params.sparams.logit_bias_strings_manual) {
+                        auto token_str_pos = word.find(token_str);
+
+                        if (token_str_pos == 0 || token_str_pos == (word.size() - 1)) {
+                            restricted = true;
+                            break;
+                        } else if (token_str.find(word) == 0 && (token_str.length() - word.length()) < 4) {
+                            restricted = true;
+                            break;
+                        }
+                    }
+                } else if (token_str.length() > 0) {
+                    for (auto word : params.sparams.logit_bias_strings_manual) {
+                        if (token_str == word) {
+                            restricted = true;
+                            break;
+                        }
+                    }
+                }
+
+                if (restricted == true) {
+                    params.sparams.logit_bias_tokens_manual.push_back(i);
+                }
+            }
+        }
+
+        // std::string pause = "";
+        // std::getline(std::cin, pause);
+    }
+
+
     bool logit_bias_check_exact(std::string_view token_str) {
         for (auto word : params.sparams.logit_bias_strings_exact) {
             if (token_str == word) return true;
@@ -757,6 +811,7 @@ class chat
         logit_bias_strings_display = "";
         logit_bias_strings_ext_display = "";
         logit_bias_strings_start_display = "";
+        logit_bias_strings_manual_display = "";

         for (auto l : params.sparams.logit_bias) {
             if (l.bias == -INFINITY) {
@@ -769,6 +824,10 @@ class chat
         for (auto l : logit_bias_tokens_start) {
             logit_bias_strings_start_display += std::format(" '{}';", common_token_to_piece(ctx, l));
         }
+
+        for (auto l : params.sparams.logit_bias_tokens_manual) {
+            logit_bias_strings_manual_display += std::format(" '{}';", common_token_to_piece(ctx, l));
+        }
     }

     void get_last_candidates_logits_display() {
@@ -779,6 +838,14 @@ class chat
         }
     }

+    void get_dry_sequence_breakers_display() {
+        dry_sequence_breakers_display.clear();
+
+        for (auto breaker : params.sparams.dry_sequence_breakers) {
+            dry_sequence_breakers_display += std::format("{}; ", breaker);
+        }
+    }
+
     void params_postfill() {
         if (params.kv_overrides_pair.size()) kv_override_prefill();
         common_process_override_tensors(params);
@@ -1296,11 +1363,12 @@ class chat
         printf("%s: llama_n_ctx = %d\n", __func__, n_ctx);

         // processing restricted words into logit_bias
-        // sparams_postfill();
+        sparams_postfill2();
         //sparams_postfill_ext();
         // get_safeguard_token("Title");
         processByVocab("Title");
-
+        get_logit_bias_str();
+        get_dry_sequence_breakers_display();

         smpl = common_sampler_init(model, sparams);
         printf("%s: common_sampler_init\n", __func__);
@@ -1611,6 +1679,7 @@ class chat
     void check_antiprompt_tkns() {
         // check for reverse prompt using special tokens
        llama_token last_token = common_sampler_last(smpl);
+
        for (std::vector<llama_token> ids : antiprompt_ids) {
            if (std::size(ids) == 1 && last_token == ids[0]) {
                if (params.interactive) {
@@ -1623,6 +1692,24 @@ class chat
            }
        }

+    bool check_antiprompt_tkns_bool() {
+        // check for reverse prompt using special tokens
+        llama_token last_token = common_sampler_last(smpl);
+
+        for (std::vector<llama_token> ids : antiprompt_ids) {
+            if (std::size(ids) == 1 && last_token == ids[0]) {
+                if (params.interactive) {
+                    is_interacting = true;
+                    has_antiprompt = std::format("{}: already has antiprompt", __func__);
+                }
+                is_antiprompt = true;
+                return true;
+            }
+        }
+
+        return false;
+    }
+
     //checking already existing contex
     int checkEmbd(){
         if (debug) printf("-ce");
@@ -1678,15 +1765,19 @@ class chat
                 id = common_sampler_shift(smpl, ctx, -1, id);
             }

-            for (auto l_b : params.sparams.logit_bias) {
-                if (l_b.bias < -99 && id == l_b.token) {
-                    std::string c_bias_tkn_string = common_token_to_piece(ctx, id);
-                    writeTextFile("logit_biasing.txt", std::format("Restricted: '{}';", c_bias_tkn_string));
+            int checks = 0;
+            while (checks < params.sparams.logit_bias_tokens_manual.size()) {
+                for (auto tkn : params.sparams.logit_bias_tokens_manual) {
+                    ++checks;
+                    if (id == tkn) {
+                        std::string c_bias_tkn_string = common_token_to_piece(ctx, id);
+                        writeTextFile("logit_biasing.txt", std::format("{}: Restricted: '{}';", params.sparams.seed, c_bias_tkn_string));

-                    id = common_sampler_shift(smpl, ctx, -1, id);
+                        id = common_sampler_shift(smpl, ctx, -1, id);

-                    c_bias_tkn_string = common_token_to_piece(ctx, id);
-                    writeTextFile("logit_biasing.txt", std::format(" replaced with: '{}'\n", c_bias_tkn_string));
+                        c_bias_tkn_string = common_token_to_piece(ctx, id);
+                        writeTextFile("logit_biasing.txt", std::format(" replaced with: '{}'\n", c_bias_tkn_string));
+                    }
                 }
             }

@@ -2009,8 +2100,6 @@ class chat

         if (debug) printf("Starting initial prompt processing...\n");

-        get_logit_bias_str();
-

         std::string result;
         //std::cout << " * " << std::endl;
@@ -2075,9 +2164,9 @@ class chat
     const std::string getTknFromEmbd(){
         if (debug) printf("-gp");

-      for (auto id : embd) {
-          //return llama_token_to_string(ctx, id);
-          return common_token_to_piece(ctx, id);
+        for (auto id : embd) {
+            //return llama_token_to_string(ctx, id);
+            return common_token_to_piece(ctx, id);
         }
     }

@@ -2224,14 +2313,36 @@ class chat
         return getTknFromEmbd();
     }

+    std::string getMultiBit(int numTkns = 2, bool emptyMessage = false, bool shortMessage = false) { // 1 2 3 4
+        std::string result = "";
+
+        for (int i = 0; i < numTkns; i++) {
+            if (checkAndClearEmbd() == 0) {
+                finished = true;
+                return txt_vocab_eos;
+            }
+
+            if (!is_interacting) sampleTknIntoEmbd(emptyMessage, shortMessage); // 2
+
+            result += getTknFromEmbd();
+
+            if (llama_token_is_eog(vocab, common_sampler_last(smpl))) {
+                return result;
+            }
+        }
+
+        return result;
+    }
+
     // token by token generation and pushing
     std::string cycleStringsOnly(bool emptyMessage = false, bool shortMessage = false) {

         dynamicParamsPrepare();
         //process_prompt(false); // do not forget to include it elsewhere after loading the model
         //inputOnly(input); // MOVED

-        std::string bit = getBit(emptyMessage, shortMessage);
+        // std::string bit = getBit(emptyMessage, shortMessage);
+        std::string bit = getMultiBit(2, emptyMessage, shortMessage);

         if ((int) std::size(embd_inp) <= n_consumed) {
             if (debug) printf("-cso");
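
For context on the new manual biasing path: `sparams_postfill2()` walks the whole vocabulary once, trims a single leading or trailing space from each token piece, and records the token id in `logit_bias_tokens_manual` whenever the piece overlaps a word from `logit_bias_strings_manual`; during generation, a sampled id found in that list is logged and shifted to another candidate via `common_sampler_shift`. Below is a minimal standalone sketch of just the matching heuristic — the helper name `is_manually_restricted` and the sample words are hypothetical, only the matching rules mirror the diff.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical helper mirroring the matching rules in sparams_postfill2():
// a trimmed token piece is restricted when it starts or "ends" a listed word,
// when it begins with a listed word and is at most 3 characters longer,
// or (for very short pieces) when it equals a listed word exactly.
static bool is_manually_restricted(std::string token_str,
                                   const std::vector<std::string> & words) {
    // drop a single leading/trailing space, as the commit does for "duplicated" tokens
    if (!token_str.empty() && token_str.front() == ' ') token_str.erase(0, 1);
    if (!token_str.empty() && token_str.back()  == ' ') token_str.pop_back();

    if (token_str.length() > 2) {
        for (const auto & word : words) {
            const auto pos = word.find(token_str);
            if (pos == 0 || pos == word.size() - 1) return true;
            if (token_str.find(word) == 0 && token_str.length() - word.length() < 4) return true;
        }
    } else if (!token_str.empty()) {
        for (const auto & word : words) {
            if (token_str == word) return true;
        }
    }
    return false;
}

int main() {
    const std::vector<std::string> banned = { "example" }; // illustrative word list
    for (const std::string piece : { " exam", "examples", "nothing" }) {
        std::cout << '\'' << piece << "' -> "
                  << (is_manually_restricted(piece, banned) ? "restricted" : "kept") << '\n';
    }
    return 0;
}
```

Note that the `length() > 2` split means very short pieces are only blocked on exact matches, which keeps common one- and two-character tokens usable.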

base_sampling2/common.h

Lines changed: 3 additions & 0 deletions
@@ -200,6 +200,9 @@ struct common_params_sampling
     std::vector<std::string> logit_bias_strings_beginning; // words for logit biases, beginning of the word matches
     std::vector<std::string> logit_bias_strings_ending; // words for logit biases, ending of the word matches

+    std::vector<llama_token> logit_bias_tokens_manual; // tokens for manual restricting
+    std::vector<std::string> logit_bias_strings_manual; // words for manual restricting
+
     std::map<std::string, float> logit_bias_strings_ext; // words for logit biases, but with extra configuration
     std::vector<std::string> logit_bias_strings_start; // restricted beginnings of messages

base_sampling2/include/jsonParams.h

Lines changed: 2 additions & 0 deletions
@@ -551,6 +551,7 @@ static void getSamplingParamsFromJson(nlohmann::json& config, common_params& par
     if (checkJNum(config, "dry_base")) params.sparams.dry_base = config["dry_base"];
     if (checkJNum(config, "dry_allowed_length")) params.sparams.dry_allowed_length = config["dry_allowed_length"];
     if (checkJNum(config, "dry_penalty_last_n")) params.sparams.dry_penalty_last_n = config["dry_penalty_last_n"];
+    if (checkJArr(config, "dry_sequence_breakers")) params.sparams.dry_sequence_breakers = config["dry_sequence_breakers"];

     //mirostat
     if (checkJNum(config, "mirostat")) params.sparams.mirostat = config["mirostat"];
@@ -562,6 +563,7 @@ static void getSamplingParamsFromJson(nlohmann::json& config, common_params& par
     if (checkJArr(config, "logit_bias_strings_exact")) params.sparams.logit_bias_strings_exact = config["logit_bias_strings_exact"];
     if (checkJArr(config, "logit_bias_strings_beginning")) params.sparams.logit_bias_strings_beginning = config["logit_bias_strings_beginning"];
     if (checkJArr(config, "logit_bias_strings_ending")) params.sparams.logit_bias_strings_ending = config["logit_bias_strings_ending"];
+    if (checkJArr(config, "logit_bias_strings_manual")) params.sparams.logit_bias_strings_manual = config["logit_bias_strings_manual"];


     if (checkJObj(config, "logit_bias_strings_ext")) params.sparams.logit_bias_strings_ext = config["logit_bias_strings_ext"];
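
Both new keys go through `checkJArr` and are assigned directly, so `dry_sequence_breakers` and `logit_bias_strings_manual` are expected to be plain JSON arrays of strings in the config file. A minimal sketch of such a config fragment and how nlohmann::json reads it follows — the values are invented for illustration, only the key names come from this commit.

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
    // Illustrative config fragment: only the key names appear in the commit,
    // the values here are made up for the example.
    const auto config = nlohmann::json::parse(R"({
        "dry_sequence_breakers": ["\n", ":", "\"", "*"],
        "logit_bias_strings_manual": ["badword", "another phrase"]
    })");

    for (const auto & w : config.at("logit_bias_strings_manual")) {
        std::cout << "manually restricted word: " << w.get<std::string>() << '\n';
    }
    for (const auto & b : config.at("dry_sequence_breakers")) {
        std::cout << "dry sequence breaker: " << b.get<std::string>() << '\n';
    }
    return 0;
}
```

In the real flow, `getSamplingParamsFromJson` copies these arrays into `params.sparams.dry_sequence_breakers` and `params.sparams.logit_bias_strings_manual`, and `sparams_postfill2()` then resolves the manual word list into token ids at model load time.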

base_sampling2/master/ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 4 additions & 1 deletion
@@ -518,11 +518,14 @@ void ggml_barrier(struct ggml_threadpool * tp);
 #elif defined(__GNUC__)
 // GCC/Clang on *nix
 # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT
-#elif defined(_MSC_VER) && defined (_WIN64)
+#elif defined(_MSC_VER) && defined(_WIN64)
 // MSVC
 // Note: C name mangling varies across different calling conventions
 // see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170
 # define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias))
+#elif defined(_MSC_VER) && defined(WIN32)
+// ref: https://github.com/ggml-org/whisper.cpp/pull/3239#issuecomment-2958224591
+# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:_" #name "=_" #alias))
 #else
 # error "Unsupported compiler for GGML_WEAK_ALIAS"
 #endif

base_sampling2/master/ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_1d_8x_flat
     mul_mv_q4_0_f32_1d_16x_flat
     mul_mv_q6_k
+    mul_mv_id_q4_0_f32_8x_flat
     mul
     norm
     relu
