@@ -291,9 +291,12 @@ class chat
291
291
std::string logit_bias_strings_display = " " ;
292
292
std::string logit_bias_strings_ext_display = " " ;
293
293
std::string logit_bias_strings_start_display = " " ;
294
+ std::string logit_bias_strings_manual_display = " " ;
294
295
295
296
std::string last_candidates_logits_display = " " ;
296
297
298
+ std::string dry_sequence_breakers_display = " " ;
299
+
297
300
struct llama_perf_context_data ctx_performance_data;
298
301
299
302
// std::map<std::string,std::string> stats;
@@ -601,6 +604,57 @@ class chat
601
604
// std::getline(std::cin, pause);
602
605
}
603
606
607
+ void sparams_postfill2 () {
608
+ // std::string space = " ";
609
+ if (params.sparams .logit_bias_strings_manual .size ()) {
610
+ for (llama_token i = 0 ; i < llama_vocab_n_tokens (vocab); i++) {
611
+ std::string token_str = common_token_to_piece (ctx, i);
612
+ // cutting spaces since there are "duplicated" tokens with them
613
+ if (token_str.front () == ' ' ) {
614
+ token_str = token_str.substr (1 );
615
+ }
616
+
617
+ // almost never happens
618
+ if (token_str.back () == ' ' ) {
619
+ token_str.pop_back ();
620
+ }
621
+
622
+ bool restricted = false ;
623
+ float bias = -INFINITY;
624
+
625
+ if (token_str.length () > 2 ) {
626
+ for (auto word : params.sparams .logit_bias_strings_manual ) {
627
+ auto token_str_pos = word.find (token_str);
628
+
629
+ if (token_str_pos == 0 || token_str_pos == (word.size () - 1 )) {
630
+ restricted = true ;
631
+ break ;
632
+ } else if (token_str.find (word) == 0 && (token_str.length () - word.length ()) < 4 ) {
633
+ restricted = true ;
634
+ break ;
635
+ }
636
+ }
637
+ } else if (token_str.length () > 0 ) {
638
+ for (auto word : params.sparams .logit_bias_strings_manual ) {
639
+ if (token_str == word) {
640
+ restricted = true ;
641
+ break ;
642
+ }
643
+ }
644
+ }
645
+
646
+ if (restricted == true ) {
647
+ params.sparams .logit_bias_tokens_manual .push_back (i);
648
+ }
649
+ }
650
+ }
651
+
652
+ // std::string pause = "";
653
+ // std::getline(std::cin, pause);
654
+ }
655
+
656
+
657
+
604
658
bool logit_bias_check_exact (std::string_view token_str) {
605
659
for (auto word : params.sparams .logit_bias_strings_exact ) {
606
660
if (token_str == word) return true ;
@@ -757,6 +811,7 @@ class chat
757
811
logit_bias_strings_display = " " ;
758
812
logit_bias_strings_ext_display = " " ;
759
813
logit_bias_strings_start_display = " " ;
814
+ logit_bias_strings_manual_display = " " ;
760
815
761
816
for (auto l : params.sparams .logit_bias ) {
762
817
if (l.bias == -INFINITY) {
@@ -769,6 +824,10 @@ class chat
769
824
for (auto l : logit_bias_tokens_start) {
770
825
logit_bias_strings_start_display += std::format (" '{}';" , common_token_to_piece (ctx, l));
771
826
}
827
+
828
+ for (auto l : params.sparams .logit_bias_tokens_manual ) {
829
+ logit_bias_strings_manual_display += std::format (" '{}';" , common_token_to_piece (ctx, l));
830
+ }
772
831
}
773
832
774
833
void get_last_candidates_logits_display () {
@@ -779,6 +838,14 @@ class chat
779
838
}
780
839
}
781
840
841
/**
 * Rebuilds dry_sequence_breakers_display from
 * params.sparams.dry_sequence_breakers: the string is cleared and every
 * breaker is appended using the format below.
 */
void get_dry_sequence_breakers_display() {
    dry_sequence_breakers_display.clear();

    // Breakers are only read here, so bind by const reference instead of
    // copying each element on every iteration.
    for (const auto & breaker : params.sparams.dry_sequence_breakers) {
        dry_sequence_breakers_display += std::format(" {}; ", breaker);
    }
}
848
+
782
849
void params_postfill () {
783
850
if (params.kv_overrides_pair .size ()) kv_override_prefill ();
784
851
common_process_override_tensors (params);
@@ -1296,11 +1363,12 @@ class chat
1296
1363
printf (" %s: llama_n_ctx = %d\n " , __func__, n_ctx);
1297
1364
1298
1365
// processing restricted words into logit_bias
1299
- // sparams_postfill ();
1366
+ sparams_postfill2 ();
1300
1367
// sparams_postfill_ext();
1301
1368
// get_safeguard_token("Title");
1302
1369
processByVocab (" Title" );
1303
-
1370
+ get_logit_bias_str ();
1371
+ get_dry_sequence_breakers_display ();
1304
1372
1305
1373
smpl = common_sampler_init (model, sparams);
1306
1374
printf (" %s: common_sampler_init\n " , __func__);
@@ -1611,6 +1679,7 @@ class chat
1611
1679
void check_antiprompt_tkns () {
1612
1680
// check for reverse prompt using special tokens
1613
1681
llama_token last_token = common_sampler_last (smpl);
1682
+
1614
1683
for (std::vector<llama_token> ids : antiprompt_ids) {
1615
1684
if (std::size (ids) == 1 && last_token == ids[0 ]) {
1616
1685
if (params.interactive ) {
@@ -1623,6 +1692,24 @@ class chat
1623
1692
}
1624
1693
}
1625
1694
1695
/**
 * Checks whether the last sampled token is a single-token antiprompt
 * (reverse prompt given as a special token).
 *
 * On a match is_antiprompt is set; in interactive mode is_interacting is set
 * as well and has_antiprompt receives a short status message.
 *
 * @return true when the last token equals one of the single-token entries of
 *         antiprompt_ids, false otherwise.
 */
bool check_antiprompt_tkns_bool() {
    // check for reverse prompt using special tokens
    const llama_token last_token = common_sampler_last(smpl);

    // Iterate by const reference: copying a whole token vector per entry is
    // unnecessary for a read-only comparison.
    for (const auto & ids : antiprompt_ids) {
        if (std::size(ids) == 1 && last_token == ids[0]) {
            if (params.interactive) {
                is_interacting = true;
                has_antiprompt = std::format(" {}: already has antiprompt", __func__);
            }
            is_antiprompt = true;
            return true;
        }
    }

    return false;
}
1712
+
1626
1713
// checking already existing context
1627
1714
int checkEmbd (){
1628
1715
if (debug) printf (" -ce" );
@@ -1678,15 +1765,19 @@ class chat
1678
1765
id = common_sampler_shift (smpl, ctx, -1 , id);
1679
1766
}
1680
1767
1681
- for (auto l_b : params.sparams .logit_bias ) {
1682
- if (l_b.bias < -99 && id == l_b.token ) {
1683
- std::string c_bias_tkn_string = common_token_to_piece (ctx, id);
1684
- writeTextFile (" logit_biasing.txt" , std::format (" Restricted: '{}';" , c_bias_tkn_string));
1768
+ int checks = 0 ;
1769
+ while (checks < params.sparams .logit_bias_tokens_manual .size ()) {
1770
+ for (auto tkn : params.sparams .logit_bias_tokens_manual ) {
1771
+ ++checks;
1772
+ if (id == tkn) {
1773
+ std::string c_bias_tkn_string = common_token_to_piece (ctx, id);
1774
+ writeTextFile (" logit_biasing.txt" , std::format (" {}: Restricted: '{}';" , params.sparams .seed , c_bias_tkn_string));
1685
1775
1686
- id = common_sampler_shift (smpl, ctx, -1 , id);
1776
+ id = common_sampler_shift (smpl, ctx, -1 , id);
1687
1777
1688
- c_bias_tkn_string = common_token_to_piece (ctx, id);
1689
- writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_bias_tkn_string));
1778
+ c_bias_tkn_string = common_token_to_piece (ctx, id);
1779
+ writeTextFile (" logit_biasing.txt" , std::format (" replaced with: '{}'\n " , c_bias_tkn_string));
1780
+ }
1690
1781
}
1691
1782
}
1692
1783
@@ -2009,8 +2100,6 @@ class chat
2009
2100
2010
2101
if (debug) printf (" Starting initial prompt processing...\n " );
2011
2102
2012
- get_logit_bias_str ();
2013
-
2014
2103
2015
2104
std::string result;
2016
2105
// std::cout << " * " << std::endl;
@@ -2075,9 +2164,9 @@ class chat
2075
2164
const std::string getTknFromEmbd (){
2076
2165
if (debug) printf (" -gp" );
2077
2166
2078
- for (auto id : embd) {
2079
- // return llama_token_to_string(ctx, id);
2080
- return common_token_to_piece (ctx, id);
2167
+ for (auto id : embd) {
2168
+ // return llama_token_to_string(ctx, id);
2169
+ return common_token_to_piece (ctx, id);
2081
2170
}
2082
2171
}
2083
2172
@@ -2224,14 +2313,36 @@ class chat
2224
2313
return getTknFromEmbd ();
2225
2314
}
2226
2315
2316
+ std::string getMultiBit (int numTkns = 2 , bool emptyMessage = false , bool shortMessage = false ) { // 1 2 3 4
2317
+ std::string result = " " ;
2318
+
2319
+ for (int i = 0 ; i < numTkns; i++) {
2320
+ if (checkAndClearEmbd () == 0 ) {
2321
+ finished = true ;
2322
+ return txt_vocab_eos;
2323
+ }
2324
+
2325
+ if (!is_interacting) sampleTknIntoEmbd (emptyMessage, shortMessage); // 2
2326
+
2327
+ result += getTknFromEmbd ();
2328
+
2329
+ if (llama_token_is_eog (vocab, common_sampler_last (smpl))) {
2330
+ return result;
2331
+ }
2332
+ }
2333
+
2334
+ return result;
2335
+ }
2336
+
2227
2337
// token by token generation and pushing
2228
2338
std::string cycleStringsOnly (bool emptyMessage = false , bool shortMessage = false ) {
2229
2339
2230
2340
dynamicParamsPrepare ();
2231
2341
// process_prompt(false); // do not forget to include it elsewhere after loading the model
2232
2342
// inputOnly(input); // MOVED
2233
2343
2234
- std::string bit = getBit (emptyMessage, shortMessage);
2344
+ // std::string bit = getBit(emptyMessage, shortMessage);
2345
+ std::string bit = getMultiBit (2 , emptyMessage, shortMessage);
2235
2346
2236
2347
if ((int ) std::size (embd_inp) <= n_consumed) {
2237
2348
if (debug) printf (" -cso" );
0 commit comments