Skip to content

Commit dae3cae

Browse files
committed
llama : suffix the internal APIs with "_impl"
ggml-ci
1 parent 39fbaf9 commit dae3cae

File tree

7 files changed

+181
-168
lines changed

7 files changed

+181
-168
lines changed

src/llama-grammar.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
464464
return result;
465465
}
466466

467-
void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
467+
void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468468
GGML_ASSERT(grammar);
469469
GGML_ASSERT(vocab);
470470

@@ -488,7 +488,7 @@ void llama_grammar_sample(const struct llama_grammar * grammar, const struct lla
488488
const llama_token id = candidates->data[i].id;
489489
const std::string & piece = vocab->cache_token_to_piece.at(id);
490490

491-
if (llama_token_is_eog(*vocab, id)) {
491+
if (llama_token_is_eog_impl(*vocab, id)) {
492492
if (!allow_eog) {
493493
candidates->data[i].logit = -INFINITY;
494494
}
@@ -508,10 +508,10 @@ void llama_grammar_sample(const struct llama_grammar * grammar, const struct lla
508508
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
509509
}
510510

511-
void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
511+
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
512512
const int64_t t_start_sample_us = ggml_time_us();
513513

514-
if (llama_token_is_eog(*vocab, token)) {
514+
if (llama_token_is_eog_impl(*vocab, token)) {
515515
for (const auto & stack : grammar->stacks) {
516516
if (stack.empty()) {
517517
return;

src/llama-grammar.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ struct llama_grammar {
1515

1616
struct llama_grammar * llama_get_grammar(struct llama_context * ctx);
1717

18+
//
19+
// internal API
20+
//
21+
1822
struct llama_grammar * llama_grammar_init_impl(
1923
const llama_grammar_element ** rules,
2024
size_t n_rules,
@@ -24,13 +28,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar);
2428

2529
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
2630

27-
void llama_grammar_sample(
31+
void llama_grammar_sample_impl(
2832
const struct llama_grammar * grammar,
2933
const struct llama_vocab * vocab,
3034
const struct llama_sampling * smpl,
3135
llama_token_data_array * candidates);
3236

33-
void llama_grammar_accept_token(
37+
void llama_grammar_accept_token_impl(
3438
struct llama_grammar * grammar,
3539
const struct llama_vocab * vocab,
3640
const struct llama_sampling * smpl,

src/llama-sampling.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ static void llama_log_softmax(float * array, size_t size) {
2121
}
2222
}
2323

24-
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
24+
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
2525
if (seed == LLAMA_DEFAULT_SEED) {
2626
seed = time(NULL);
2727
}
2828

2929
smpl->rng.seed(seed);
3030
}
3131

32-
void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
32+
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
3333
GGML_ASSERT(candidates->size > 0);
3434

3535
const int64_t t_start_sample_us = ggml_time_us();
@@ -58,7 +58,7 @@ void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array *
5858
}
5959
}
6060

61-
void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
61+
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
6262
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
6363
// if (k >= (int32_t)candidates->size) {
6464
// return;
@@ -139,12 +139,12 @@ void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * c
139139
}
140140
}
141141

142-
void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
142+
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
143143
if (p >= 1.0f) {
144144
return;
145145
}
146146

147-
llama_sample_softmax(smpl, candidates);
147+
llama_sample_softmax_impl(smpl, candidates);
148148

149149
const int64_t t_start_sample_us = ggml_time_us();
150150

@@ -171,7 +171,7 @@ void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * c
171171
}
172172
}
173173

174-
void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
174+
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
175175
if (p <= 0.0f || !candidates->size) {
176176
return;
177177
}
@@ -232,12 +232,12 @@ void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * c
232232
}
233233
}
234234

235-
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
235+
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
236236
if (z >= 1.0f || candidates->size <= 2) {
237237
return;
238238
}
239239

240-
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
240+
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
241241
const int64_t t_start_sample_us = ggml_time_us();
242242

243243
// Compute the first and second derivatives
@@ -291,15 +291,15 @@ void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array
291291
}
292292
}
293293

294-
void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
294+
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
295295
// Reference implementation:
296296
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
297297
if (p >= 1.0f) {
298298
return;
299299
}
300300

301301
// Compute the softmax of logits and calculate entropy
302-
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
302+
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
303303

304304
const int64_t t_start_sample_us = ggml_time_us();
305305

@@ -355,7 +355,7 @@ void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array *
355355
}
356356
}
357357

358-
void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
358+
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
359359
const int64_t t_start_sample_us = ggml_time_us();
360360

361361
// no need to do anything if there is only one (or zero) candidates
@@ -366,7 +366,7 @@ void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array *
366366
// Calculate maximum possible entropy
367367
float max_entropy = -logf(1.0f / candidates->size);
368368

369-
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
369+
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
370370

371371
// Calculate entropy of the softmax probabilities
372372
float entropy = 0.0f;
@@ -422,7 +422,7 @@ void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array *
422422
}
423423
}
424424

425-
void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
425+
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
426426
const int64_t t_start_sample_us = ggml_time_us();
427427

428428
for (size_t i = 0; i < candidates->size; ++i) {
@@ -434,7 +434,7 @@ void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * ca
434434
}
435435
}
436436

437-
void llama_sample_repetition_penalties(
437+
void llama_sample_repetition_penalties_impl(
438438
struct llama_sampling * smpl,
439439
llama_token_data_array * candidates,
440440
const llama_token * last_tokens,
@@ -481,7 +481,7 @@ void llama_sample_repetition_penalties(
481481
}
482482
}
483483

484-
void llama_sample_apply_guidance(
484+
void llama_sample_apply_guidance_impl(
485485
struct llama_sampling * smpl,
486486
float * logits,
487487
float * logits_guidance,
@@ -504,14 +504,14 @@ void llama_sample_apply_guidance(
504504
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
505505
}
506506

507-
llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
507+
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
508508
GGML_ASSERT(smpl);
509509

510510
const int32_t n_vocab = float(smpl->n_vocab);
511511

512512
int64_t t_start_sample_us = ggml_time_us();
513513

514-
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
514+
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
515515

516516
// Estimate s_hat using the most probable m tokens
517517
float s_hat = 0.0;
@@ -530,9 +530,9 @@ llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_toke
530530
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
531531

532532
// Sample the next word X using top-k sampling
533-
llama_sample_top_k((struct llama_sampling *) nullptr, candidates, int(k), 1);
533+
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
534534
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
535-
llama_token X = llama_sample_token(smpl, candidates);
535+
llama_token X = llama_sample_token_impl(smpl, candidates);
536536
t_start_sample_us = ggml_time_us();
537537

538538
// Compute error as the difference between observed surprise and target surprise value
@@ -549,11 +549,11 @@ llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_toke
549549
return X;
550550
}
551551

552-
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
552+
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
553553
int64_t t_start_sample_us;
554554
t_start_sample_us = ggml_time_us();
555555

556-
llama_sample_softmax(smpl, candidates);
556+
llama_sample_softmax_impl(smpl, candidates);
557557

558558
// Truncate the words with surprise values greater than mu
559559
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
@@ -569,10 +569,10 @@ llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_t
569569
}
570570

571571
// Normalize the probabilities of the remaining words
572-
llama_sample_softmax(smpl, candidates);
572+
llama_sample_softmax_impl(smpl, candidates);
573573

574574
// Sample the next word X from the remaining words
575-
llama_token X = llama_sample_token(smpl, candidates);
575+
llama_token X = llama_sample_token_impl(smpl, candidates);
576576
t_start_sample_us = ggml_time_us();
577577

578578
// Compute error as the difference between observed surprise and target surprise value
@@ -591,7 +591,7 @@ llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_t
591591
return X;
592592
}
593593

594-
llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
594+
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
595595
const int64_t t_start_sample_us = ggml_time_us();
596596

597597
// Find max element
@@ -607,11 +607,11 @@ llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_
607607
return result;
608608
}
609609

610-
llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
610+
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
611611
GGML_ASSERT(smpl);
612612

613613
const int64_t t_start_sample_us = ggml_time_us();
614-
llama_sample_softmax((struct llama_sampling *) nullptr, candidates);
614+
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
615615

616616
std::vector<float> probs;
617617
probs.reserve(candidates->size);
@@ -630,6 +630,6 @@ llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_toke
630630
return result;
631631
}
632632

633-
llama_token llama_sample_token(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634-
return llama_sample_token_with_rng(smpl, candidates, smpl->rng);
633+
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
634+
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
635635
}

src/llama-sampling.h

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ struct llama_sampling {
2020

2121
struct llama_sampling * llama_get_sampling(struct llama_context * ctx);
2222

23-
void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed);
24-
25-
void llama_sample_softmax (struct llama_sampling * smpl, llama_token_data_array * candidates);
26-
void llama_sample_top_k (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
27-
void llama_sample_top_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
28-
void llama_sample_min_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
29-
void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
30-
void llama_sample_typical (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
31-
void llama_sample_entropy (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
32-
void llama_sample_temp (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
33-
34-
void llama_sample_repetition_penalties(
23+
//
24+
// internal API
25+
//
26+
27+
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
28+
29+
void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
30+
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
31+
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
32+
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
33+
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
34+
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
35+
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
36+
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
37+
38+
void llama_sample_repetition_penalties_impl(
3539
struct llama_sampling * smpl,
3640
llama_token_data_array * candidates,
3741
const llama_token * last_tokens,
@@ -40,15 +44,15 @@ void llama_sample_repetition_penalties(
4044
float penalty_freq,
4145
float penalty_present);
4246

43-
void llama_sample_apply_guidance(
47+
void llama_sample_apply_guidance_impl(
4448
struct llama_sampling * smpl,
4549
float * logits,
4650
float * logits_guidance,
4751
float scale);
4852

49-
llama_token llama_sample_token_mirostat (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
50-
llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
51-
llama_token llama_sample_token_greedy (struct llama_sampling * smpl, llama_token_data_array * candidates);
52-
llama_token llama_sample_token_with_rng (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
53-
llama_token llama_sample_token (struct llama_sampling * smpl, llama_token_data_array * candidates);
53+
llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
54+
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
55+
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
56+
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
57+
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
5458

0 commit comments

Comments
 (0)