Commit b85edd7

sampling : remove top-k min_keep, fix mirostat init and state

1 parent 935a4d0 commit b85edd7

8 files changed: +76 −73 lines
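In short: `min_keep` is dropped from the top-k constraint because `k` already fixes exactly how many candidates survive, so a separate lower bound is redundant there; the probabilistic truncations (top-p, min-p, tail-free, typical) keep the parameter, since for them it is what guarantees a minimum number of surviving candidates. The mirostat constraints now initialize and reset their dynamic threshold `mu` to `2*tau` instead of `0.0f`, and mirostat v2's apply step now records the candidate order so the distance can be computed when a token is accepted, as mirostat v1 already did. For reference, the feedback loop from the Mirostat paper, whose initialization the new seeding matches (notation mine, not from the diff):

$$ s(x) = -\log_2 p(x), \qquad \mu \leftarrow \mu - \eta\,\bigl(s(x) - \tau\bigr), \qquad \mu_0 = 2\tau $$

where $s(x)$ is the surprise of the sampled token $x$, $\tau$ the target surprise, and $\eta$ the learning rate.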

common/sampling.cpp

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     for (const auto & cnstr : params.constraints) {
         switch (cnstr) {
             case GPT_CONSTRAINT_TYPE_TOP_K:
-                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
+                llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k));
                 break;
             case GPT_CONSTRAINT_TYPE_TOP_P:
                 llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ defer {
     llama_sampler_free(smpl)
 }
 
-llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
+llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));
 llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
 llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_init(model, sparams);
 
-    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k));
     llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
     llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));

include/llama.h

Lines changed: 1 addition & 1 deletion

@@ -1043,7 +1043,7 @@ extern "C" {
     };
 
     LLAMA_API struct llama_constraint * llama_constraint_init_softmax   (void);
-    LLAMA_API struct llama_constraint * llama_constraint_init_top_k     (int32_t k, int32_t min_keep);
+    LLAMA_API struct llama_constraint * llama_constraint_init_top_k     (int32_t k);
     LLAMA_API struct llama_constraint * llama_constraint_init_top_p     (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_min_p     (float p, int32_t min_keep);
     LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep);
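For callers, the only public-API break here is the top-k initializer. A minimal sketch of a constraint chain built against the updated declarations (`model` and `sparams` assumed in scope, as at the examples/batched call sites; the values are illustrative):

```cpp
llama_sampler * smpl = llama_sampler_init(model, sparams);

llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40));      // k only -- no min_keep
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));  // min_keep still applies here
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

// ... decode and sample ...

llama_sampler_free(smpl);
```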

src/llama-sampling.cpp

Lines changed: 54 additions & 51 deletions

@@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) {
     }
 }
 
-static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) {
+static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
@@ -59,7 +58,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t
         k = cur_p->size;
     }
 
-    k = std::max(k, (int) min_keep);
     k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
@@ -449,32 +448,30 @@ struct llama_constraint * llama_constraint_init_softmax_impl() {
 
 struct llama_constraint_context_top_k {
     const int32_t k;
-    const size_t  min_keep;
 };
 
 static struct llama_constraint_i llama_constraint_top_k_i = {
     /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; },
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) {
         const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx;
-        llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep);
+        llama_constraint_top_k_impl(cur_p, ctx->k);
     },
     /* .reset  = */ nullptr,
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx;
-        return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep);
+        return llama_constraint_init_top_k_impl(ctx->k);
     },
     /* .free   = */ [](struct llama_constraint * cnstr) {
         delete (llama_constraint_context_top_k *) cnstr->ctx;
     },
 };
 
-struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
+struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) {
     return new llama_constraint {
         /* .iface = */ &llama_constraint_top_k_i,
         /* .ctx   = */ new llama_constraint_context_top_k {
-            /*.k        =*/ k,
-            /*.min_keep =*/ min_keep,
+            /* .k = */ k,
         },
     };
 }
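For orientation, a hypothetical standalone sketch of the truncation the simplified constraint performs (`top_k_sketch` and the `std::partial_sort` choice are illustrative, not the library code; the real `llama_constraint_top_k_impl` also keeps a bucket-sort fast path, per the TODO above):

```cpp
#include <algorithm>

// Keep only the k highest-logit candidates. With min_keep gone,
// k alone determines how many candidates survive.
static void top_k_sketch(llama_token_data_array * cur_p, int32_t k) {
    if (k <= 0 || k > (int32_t) cur_p->size) {
        k = cur_p->size;
    }
    std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size,
            [](const llama_token_data & a, const llama_token_data & b) {
                return a.logit > b.logit; // descending by logit
            });
    cur_p->size = k;
}
```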
@@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
     return new llama_constraint {
         /* .iface = */ &llama_constraint_top_p_i,
         /* .ctx   = */ new llama_constraint_context_top_p {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k
     return new llama_constraint {
         /* .iface = */ &llama_constraint_min_p_i,
         /* .ctx   = */ new llama_constraint_context_min_p {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m
     return new llama_constraint {
         /* .iface = */ &llama_constraint_tail_free_i,
         /* .ctx   = */ new llama_constraint_context_tail_free {
-            /*.z        =*/ z,
-            /*.min_keep =*/ min_keep,
+            /* .z        = */ z,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min
     return new llama_constraint {
         /* .iface = */ &llama_constraint_typical_i,
         /* .ctx   = */ new llama_constraint_context_typical {
-            /*.p        =*/ p,
-            /*.min_keep =*/ min_keep,
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
         },
     };
 }
@@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
     return new llama_constraint {
         /* .iface = */ &llama_constraint_temp_i,
         /* .ctx   = */ new llama_constraint_context_temp {
-            /*.temp =*/ temp,
+            /* .temp = */ temp,
         },
     };
 }
@@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
     return new llama_constraint {
         /* .iface = */ &llama_constraint_temp_ext_i,
         /* .ctx   = */ new llama_constraint_context_temp_ext {
-            /*.temp     =*/ temp,
-            /*.delta    =*/ delta,
-            /*.exponent =*/ exponent,
+            /* .temp     = */ temp,
+            /* .delta    = */ delta,
+            /* .exponent = */ exponent,
         },
     };
 }
@@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
         float epsilon_hat = s_hat - 1;
         float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
 
-        llama_constraint_top_k_impl(cur_p, int(k), 1);
+        llama_constraint_top_k_impl(cur_p, std::max(int(k), 1));
 
         // remember the order to be able to compute the distance later when accepting the token
         ctx->cur.resize(cur_p->size);
@@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
     },
     /* .reset  = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx;
@@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama
     return new llama_constraint {
         /* .iface = */ &llama_constraint_mirostat_i,
         /* .ctx   = */ new llama_constraint_context_mirostat {
-            /*.vocab =*/ &vocab,
-            /*.tau   =*/ tau,
-            /*.eta   =*/ eta,
-            /*.m     =*/ m,
-            /*.mu    =*/ 0.0f,
-            /*.cur   =*/ {},
+            /* .vocab = */ &vocab,
+            /* .tau   = */ tau,
+            /* .eta   = */ eta,
+            /* .m     = */ m,
+            /* .mu    = */ 2.0f*tau,
+            /* .cur   = */ {},
         },
     };
 }
@@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
 
         // Normalize the probabilities of the remaining words
         llama_constraint_softmax_impl(cur_p);
+
+        // remember the order to be able to compute the distance later when accepting the token
+        ctx->cur.resize(cur_p->size);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            ctx->cur[i] = cur_p->data[i];
+        }
     },
     /* .reset  = */ [](struct llama_constraint * cnstr) {
         auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx;
-        ctx->mu = 0.0f;
+        ctx->mu = 2.0f*ctx->tau;
     },
     /* .copy   = */ [](const struct llama_constraint * cnstr) {
         const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx;
@@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
     return new llama_constraint {
         /* .iface = */ &llama_constraint_mirostat_v2_i,
         /* .ctx   = */ new llama_constraint_context_mirostat_v2 {
-            /*.tau =*/ tau,
-            /*.eta =*/ eta,
-            /*.mu  =*/ 0.0f,
-            /*.cur =*/ {},
+            /* .tau = */ tau,
+            /* .eta = */ eta,
+            /* .mu  = */ 2.0f*tau,
+            /* .cur = */ {},
         },
     };
 }
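The init and reset changes above only touch `mu`'s starting value; the per-token feedback that consumes it is unchanged. A self-contained illustration of that loop (hypothetical names, not the library code), assuming the `2*tau` seeding introduced here:

```cpp
#include <cmath>

// Sketch of the mirostat v2 feedback step.
struct mirostat_sketch {
    float tau; // target surprise (cross-entropy)
    float eta; // learning rate
    float mu;  // dynamic threshold

    mirostat_sketch(float tau_, float eta_) : tau(tau_), eta(eta_), mu(2.0f*tau_) {}

    // call after a token sampled with probability p is accepted
    void accept(float p) {
        const float surprise = -std::log2(p); // observed surprise
        mu -= eta * (surprise - tau);         // nudge mu toward the target tau
    }
};
```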
@@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /*.vocab        = */ &vocab,
-            /*.grammar_str  = */ grammar_str,
-            /*.grammar_root = */ grammar_root,
-            /*.grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
         };
     } else {
         *ctx = {
-            /*.vocab        = */ &vocab,
-            /*.grammar_str  = */ {},
-            /*.grammar_root = */ {},
-            /*.grammar      = */ nullptr,
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ {},
+            /* .grammar_root = */ {},
+            /* .grammar      = */ nullptr,
         };
     }
 
@@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
     return new llama_constraint {
         /* .iface = */ &llama_constraint_penalties_i,
         /* .ctx   = */ new llama_constraint_context_penalties {
-            /*.vocab           =*/ &vocab,
-            /*.penalty_last_n  =*/ penalty_last_n,
-            /*.penalty_repeat  =*/ penalty_repeat,
-            /*.penalty_freq    =*/ penalty_freq,
-            /*.penalty_present =*/ penalty_present,
-            /*.penalize_nl     =*/ penalize_nl,
-            /*.ignore_eos      =*/ ignore_eos,
-            /*.prev            =*/ ring_buffer<llama_token>(penalty_last_n),
+            /* .vocab           = */ &vocab,
+            /* .penalty_last_n  = */ penalty_last_n,
+            /* .penalty_repeat  = */ penalty_repeat,
+            /* .penalty_freq    = */ penalty_freq,
+            /* .penalty_present = */ penalty_present,
+            /* .penalize_nl     = */ penalize_nl,
+            /* .ignore_eos      = */ ignore_eos,
+            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
         },
     };
 }
@@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
     return new llama_constraint {
         /* .iface = */ &llama_constraint_logit_bias_i,
         /* .ctx   = */ new llama_constraint_context_logit_bias {
-            /*.vocab     =*/ &vocab,
-            /*.logit_bias=*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .vocab      = */ &vocab,
+            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
         },
     };
 }

src/llama-sampling.h

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ void llama_constraint_penalties_impl(
 // constraints
 
 struct llama_constraint * llama_constraint_init_softmax_impl   ();
-struct llama_constraint * llama_constraint_init_top_k_impl     (int32_t k, size_t min_keep);
+struct llama_constraint * llama_constraint_init_top_k_impl     (int32_t k);
 struct llama_constraint * llama_constraint_init_top_p_impl     (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_min_p_impl     (float p, size_t min_keep);
 struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep);

src/llama.cpp

Lines changed: 2 additions & 2 deletions

@@ -20581,8 +20581,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) {
     return llama_constraint_init_softmax_impl();
 }
 
-struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) {
-    return llama_constraint_init_top_k_impl(k, min_keep);
+struct llama_constraint * llama_constraint_init_top_k(int32_t k) {
+    return llama_constraint_init_top_k_impl(k);
 }
 
 struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) {
