Skip to content

Commit acd5b84

Browse files
committed
sampling : add name API + option to disable timings
1 parent e3396f3 commit acd5b84

File tree

5 files changed

+33
-14
lines changed

5 files changed

+33
-14
lines changed

common/sampling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
3131

3232
for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
3333
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
34-
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
34+
result += std::string(" -> ") + llama_constraint_name(cnstr) + " ";
3535
}
3636

3737
return result;

include/llama.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ extern "C" {
379379

380380
// TODO: will be used by the llama_decode_with_sampler() API in the future
381381
enum llama_sampler_type type;
382+
383+
bool no_timing; // whether to measure performance timings
382384
} llama_sampler_params;
383385

384386
// performance timing information
@@ -1095,9 +1097,10 @@ extern "C" {
10951097
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
10961098
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
10971099

1098-
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
1099-
LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * cur_p);
1100-
LLAMA_API void llama_constraint_reset (struct llama_constraint * cnstr);
1100+
LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr);
1101+
LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token);
1102+
LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p);
1103+
LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr);
11011104

11021105
// samplers
11031106

src/llama-sampling.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,6 +1190,14 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
11901190
// TODO: should we reset the timings?
11911191
}
11921192

1193+
const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) {
1194+
if (!cnstr.iface) {
1195+
return "(null)";
1196+
}
1197+
1198+
return cnstr.iface->name(&cnstr);
1199+
}
1200+
11931201
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
11941202
smpl.prev.push_back(token);
11951203

src/llama-sampling.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ struct llama_constraint * llama_constraint_clone_impl(const struct llama_constra
6262

6363
void llama_constraint_free_impl(struct llama_constraint * cnstr);
6464

65-
void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token);
66-
void llama_constraint_apply_impl (struct llama_constraint & cnstr, struct llama_token_data_array * cur_p);
67-
void llama_constraint_reset_impl (struct llama_constraint & cnstr);
65+
const char * llama_constraint_name_impl (const struct llama_constraint & cnstr);
66+
void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token);
67+
void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p);
68+
void llama_constraint_reset_impl ( struct llama_constraint & cnstr);
6869

6970
// samplers
7071

src/llama.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,12 @@ static void zeros(std::ofstream & file, size_t n) {
148148
}
149149

150150
struct time_meas {
151-
time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
151+
time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
152152

153153
~time_meas() {
154-
t_acc += ggml_time_us() - t_start_us;
154+
if (t_start_us >= 0) {
155+
t_acc += ggml_time_us() - t_start_us;
156+
}
155157
}
156158

157159
const int64_t t_start_us;
@@ -17908,9 +17910,10 @@ struct llama_context_params llama_context_default_params() {
1790817910

1790917911
struct llama_sampler_params llama_sampler_default_params() {
1791017912
struct llama_sampler_params result = {
17911-
/*.seed =*/ LLAMA_DEFAULT_SEED,
17912-
/*.n_prev =*/ 256,
17913-
/*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
17913+
/*.seed =*/ LLAMA_DEFAULT_SEED,
17914+
/*.n_prev =*/ 256,
17915+
/*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
17916+
/*.no_timing =*/ false, // TODO: change to true and set explicitly in examples
1791417917
};
1791517918

1791617919
return result;
@@ -20651,6 +20654,10 @@ void llama_constraint_free(struct llama_constraint * cnstr) {
2065120654
llama_constraint_free_impl(cnstr);
2065220655
}
2065320656

20657+
const char * llama_constraint_name(const struct llama_constraint * cnstr) {
20658+
return llama_constraint_name_impl(*cnstr);
20659+
}
20660+
2065420661
void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) {
2065520662
llama_constraint_accept_impl(*cnstr, token);
2065620663
}
@@ -20688,7 +20695,7 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
2068820695
}
2068920696

2069020697
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
20691-
time_meas tm(smpl->t_sample_us);
20698+
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
2069220699

2069320700
if (cur_p == nullptr) {
2069420701
cur_p = &smpl->cur_p;
@@ -20726,7 +20733,7 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample
2072620733
}
2072720734

2072820735
llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
20729-
time_meas tm(smpl->t_sample_us);
20736+
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
2073020737

2073120738
if (cur_p == nullptr) {
2073220739
cur_p = &smpl->cur_p;

0 commit comments

Comments
 (0)