@@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) {
49
49
}
50
50
}
51
51
52
- static void llama_constraint_top_k_impl (llama_token_data_array * cur_p, int32_t k, size_t min_keep ) {
52
+ static void llama_constraint_top_k_impl (llama_token_data_array * cur_p, int32_t k) {
53
53
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
54
54
// if (k >= (int32_t)cur_p->size) {
55
55
// return;
@@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t
59
59
k = cur_p->size ;
60
60
}
61
61
62
- k = std::max (k, (int ) min_keep);
63
62
k = std::min (k, (int ) cur_p->size );
64
63
65
64
// Sort scores in descending order
@@ -449,32 +448,30 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
449
448
450
449
struct llama_constraint_context_top_k {
451
450
const int32_t k;
452
- const size_t min_keep;
453
451
};
454
452
455
453
static struct llama_constraint_i llama_constraint_top_k_i = {
456
454
/* .name = */ [](const struct llama_constraint * /* cnstr*/ ) { return " top-k" ; },
457
455
/* .accept = */ nullptr ,
458
456
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
459
457
const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx ;
460
- llama_constraint_top_k_impl (cur_p, ctx->k , ctx-> min_keep );
458
+ llama_constraint_top_k_impl (cur_p, ctx->k );
461
459
},
462
460
/* .reset = */ nullptr ,
463
461
/* .copy = */ [](const struct llama_constraint * cnstr) {
464
462
const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx ;
465
- return llama_constraint_init_top_k_impl (ctx->k , ctx-> min_keep );
463
+ return llama_constraint_init_top_k_impl (ctx->k );
466
464
},
467
465
/* .free = */ [](struct llama_constraint * cnstr) {
468
466
delete (llama_constraint_context_top_k *) cnstr->ctx ;
469
467
},
470
468
};
471
469
472
- struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep ) {
470
+ struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k) {
473
471
return new llama_constraint {
474
472
/* .iface = */ &llama_constraint_top_k_i,
475
473
/* .ctx = */ new llama_constraint_context_top_k {
476
- /* .k =*/ k,
477
- /* .min_keep =*/ min_keep,
474
+ /* .k = */ k,
478
475
},
479
476
};
480
477
}
@@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
507
504
return new llama_constraint {
508
505
/* .iface = */ &llama_constraint_top_p_i,
509
506
/* .ctx = */ new llama_constraint_context_top_p {
510
- /* .p =*/ p,
511
- /* .min_keep =*/ min_keep,
507
+ /* .p = */ p,
508
+ /* .min_keep = */ min_keep,
512
509
},
513
510
};
514
511
}
@@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
541
538
return new llama_constraint {
542
539
/* .iface = */ &llama_constraint_min_p_i,
543
540
/* .ctx = */ new llama_constraint_context_min_p {
544
- /* .p =*/ p,
545
- /* .min_keep =*/ min_keep,
541
+ /* .p = */ p,
542
+ /* .min_keep = */ min_keep,
546
543
},
547
544
};
548
545
}
@@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
575
572
return new llama_constraint {
576
573
/* .iface = */ &llama_constraint_tail_free_i,
577
574
/* .ctx = */ new llama_constraint_context_tail_free {
578
- /* .z =*/ z,
579
- /* .min_keep =*/ min_keep,
575
+ /* .z = */ z,
576
+ /* . min_keep = */ min_keep,
580
577
},
581
578
};
582
579
}
@@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
609
606
return new llama_constraint {
610
607
/* .iface = */ &llama_constraint_typical_i,
611
608
/* .ctx = */ new llama_constraint_context_typical {
612
- /* .p =*/ p,
613
- /* .min_keep =*/ min_keep,
609
+ /* .p = */ p,
610
+ /* .min_keep = */ min_keep,
614
611
},
615
612
};
616
613
}
@@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
642
639
return new llama_constraint {
643
640
/* .iface = */ &llama_constraint_temp_i,
644
641
/* .ctx = */ new llama_constraint_context_temp {
645
- /* .temp =*/ temp,
642
+ /* .temp = */ temp,
646
643
},
647
644
};
648
645
}
@@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
683
680
return new llama_constraint {
684
681
/* .iface = */ &llama_constraint_temp_ext_i,
685
682
/* .ctx = */ new llama_constraint_context_temp_ext {
686
- /* .temp =*/ temp,
687
- /* .delta =*/ delta,
688
- /* .exponent =*/ exponent,
683
+ /* .temp = */ temp,
684
+ /* .delta = */ delta,
685
+ /* .exponent = */ exponent,
689
686
},
690
687
};
691
688
}
@@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
745
742
float epsilon_hat = s_hat - 1 ;
746
743
float k = powf ((epsilon_hat * powf (2 , ctx->mu )) / (1 - powf (ctx->vocab ->n_vocab , -epsilon_hat)), 1 / s_hat);
747
744
748
- llama_constraint_top_k_impl (cur_p, int (k), 1 );
745
+ llama_constraint_top_k_impl (cur_p, std::max ( int (k), 1 ) );
749
746
750
747
// remember the order to be able to compute the distance later when accepting the token
751
748
ctx->cur .resize (cur_p->size );
@@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
755
752
},
756
753
/* .reset = */ [](struct llama_constraint * cnstr) {
757
754
auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx ;
758
- ctx->mu = 0 .0f ;
755
+ ctx->mu = 2 .0f *ctx-> tau ;
759
756
},
760
757
/* .copy = */ [](const struct llama_constraint * cnstr) {
761
758
const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx ;
@@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama
770
767
return new llama_constraint {
771
768
/* .iface = */ &llama_constraint_mirostat_i,
772
769
/* .ctx = */ new llama_constraint_context_mirostat {
773
- /* .vocab =*/ &vocab,
774
- /* .tau =*/ tau,
775
- /* .eta =*/ eta,
776
- /* .m =*/ m,
777
- /* .mu =*/ 0 .0f ,
778
- /* .cur =*/ {},
770
+ /* .vocab = */ &vocab,
771
+ /* .tau = */ tau,
772
+ /* .eta = */ eta,
773
+ /* .m = */ m,
774
+ /* .mu = */ 2 .0f *tau ,
775
+ /* .cur = */ {},
779
776
},
780
777
};
781
778
}
@@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
826
823
827
824
// Normalize the probabilities of the remaining words
828
825
llama_constraint_softmax_impl (cur_p);
826
+
827
+ // remember the order to be able to compute the distance later when accepting the token
828
+ ctx->cur .resize (cur_p->size );
829
+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
830
+ ctx->cur [i] = cur_p->data [i];
831
+ }
829
832
},
830
833
/* .reset = */ [](struct llama_constraint * cnstr) {
831
834
auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx ;
832
- ctx->mu = 0 .0f ;
835
+ ctx->mu = 2 .0f *ctx-> tau ;
833
836
},
834
837
/* .copy = */ [](const struct llama_constraint * cnstr) {
835
838
const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx ;
@@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
844
847
return new llama_constraint {
845
848
/* .iface = */ &llama_constraint_mirostat_v2_i,
846
849
/* .ctx = */ new llama_constraint_context_mirostat_v2 {
847
- /* .tau =*/ tau,
848
- /* .eta =*/ eta,
849
- /* .mu =*/ 0 .0f ,
850
- /* .cur =*/ {},
850
+ /* .tau = */ tau,
851
+ /* .eta = */ eta,
852
+ /* .mu = */ 2 .0f *tau ,
853
+ /* .cur = */ {},
851
854
},
852
855
};
853
856
}
@@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
919
922
920
923
if (grammar_str != nullptr && grammar_str[0 ] != ' \0 ' ) {
921
924
*ctx = {
922
- /* .vocab = */ &vocab,
923
- /* .grammar_str = */ grammar_str,
924
- /* .grammar_root = */ grammar_root,
925
- /* .grammar = */ llama_grammar_init_impl (&vocab, grammar_str, grammar_root),
925
+ /* .vocab = */ &vocab,
926
+ /* .grammar_str = */ grammar_str,
927
+ /* .grammar_root = */ grammar_root,
928
+ /* .grammar = */ llama_grammar_init_impl (&vocab, grammar_str, grammar_root),
926
929
};
927
930
} else {
928
931
*ctx = {
929
- /* .vocab = */ &vocab,
930
- /* .grammar_str = */ {},
931
- /* .grammar_root = */ {},
932
- /* .grammar = */ nullptr ,
932
+ /* .vocab = */ &vocab,
933
+ /* .grammar_str = */ {},
934
+ /* .grammar_root = */ {},
935
+ /* .grammar = */ nullptr ,
933
936
};
934
937
}
935
938
@@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
1023
1026
return new llama_constraint {
1024
1027
/* .iface = */ &llama_constraint_penalties_i,
1025
1028
/* .ctx = */ new llama_constraint_context_penalties {
1026
- /* .vocab =*/ &vocab,
1027
- /* .penalty_last_n =*/ penalty_last_n,
1028
- /* .penalty_repeat =*/ penalty_repeat,
1029
- /* .penalty_freq =*/ penalty_freq,
1030
- /* .penalty_present =*/ penalty_present,
1031
- /* .penalize_nl =*/ penalize_nl,
1032
- /* .ignore_eos =*/ ignore_eos,
1033
- /* .prev =*/ ring_buffer<llama_token>(penalty_last_n),
1029
+ /* .vocab = */ &vocab,
1030
+ /* .penalty_last_n = */ penalty_last_n,
1031
+ /* .penalty_repeat = */ penalty_repeat,
1032
+ /* .penalty_freq = */ penalty_freq,
1033
+ /* .penalty_present = */ penalty_present,
1034
+ /* .penalize_nl = */ penalize_nl,
1035
+ /* .ignore_eos = */ ignore_eos,
1036
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
1034
1037
},
1035
1038
};
1036
1039
}
@@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
1072
1075
return new llama_constraint {
1073
1076
/* .iface = */ &llama_constraint_logit_bias_i,
1074
1077
/* .ctx = */ new llama_constraint_context_logit_bias {
1075
- /* .vocab = */ &vocab,
1076
- /* .logit_bias= */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
1078
+ /* .vocab = */ &vocab,
1079
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
1077
1080
},
1078
1081
};
1079
1082
}
0 commit comments