Skip to content

Commit 0f50960

Browse files
z80maniac authored and ggerganov committed
common : use enums for sampler types (ggml-org#5418)
* common: use enums for sampler types

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* minor : spaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 1d4fbd8 commit 0f50960

File tree

4 files changed

+122
-57
lines changed

4 files changed

+122
-57
lines changed

common/common.cpp

Lines changed: 84 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
340340
invalid_param = true;
341341
break;
342342
}
343-
sparams.samplers_sequence = parse_samplers_input(argv[i]);
343+
const auto sampler_names = string_split(argv[i], ';');
344+
sparams.samplers_sequence = sampler_types_from_names(sampler_names);
344345
} else if (arg == "--sampling-seq") {
345346
if (++i >= argc) {
346347
invalid_param = true;
347348
break;
348349
}
349-
sparams.samplers_sequence = argv[i];
350+
sparams.samplers_sequence = sampler_types_from_chars(argv[i]);
350351
} else if (arg == "--top-p") {
351352
if (++i >= argc) {
352353
invalid_param = true;
@@ -906,6 +907,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
906907
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
907908
const llama_sampling_params & sparams = params.sparams;
908909

910+
std::string sampler_type_chars;
911+
std::string sampler_type_names;
912+
for (const auto sampler_type : sparams.samplers_sequence) {
913+
sampler_type_chars += static_cast<char>(sampler_type);
914+
sampler_type_names += sampler_type_to_name_string(sampler_type) + ";";
915+
}
916+
sampler_type_names.pop_back();
917+
909918
printf("\n");
910919
printf("usage: %s [options]\n", argv[0]);
911920
printf("\n");
@@ -947,8 +956,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
947956
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
948957
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
949958
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
950-
printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
951-
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
959+
printf(" --samplers samplers that will be used for generation in the order, separated by \';\' (default: %s)\n", sampler_type_names.c_str());
960+
printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str());
952961
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
953962
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
954963
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
@@ -1097,45 +1106,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
10971106
}
10981107

10991108
//
1100-
// String parsing
1109+
// String utils
11011110
//
11021111

1103-
std::string parse_samplers_input(std::string input) {
1104-
std::string output = "";
1112+
std::vector<std::string> string_split(std::string input, char separator) {
1113+
std::vector<std::string> parts;
1114+
size_t separator_pos = input.find(separator);
1115+
while (separator_pos != std::string::npos) {
1116+
std::string part = input.substr(0, separator_pos);
1117+
parts.emplace_back(part);
1118+
input = input.substr(separator_pos + 1);
1119+
separator_pos = input.find(separator);
1120+
}
1121+
parts.emplace_back(input);
1122+
return parts;
1123+
}
1124+
1125+
std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names) {
11051126
// since samplers names are written multiple ways
11061127
// make it ready for both system names and input names
1107-
std::unordered_map<std::string, char> samplers_symbols {
1108-
{"top_k", 'k'},
1109-
{"top-k", 'k'},
1110-
{"top_p", 'p'},
1111-
{"top-p", 'p'},
1112-
{"nucleus", 'p'},
1113-
{"typical_p", 'y'},
1114-
{"typical-p", 'y'},
1115-
{"typical", 'y'},
1116-
{"min_p", 'm'},
1117-
{"min-p", 'm'},
1118-
{"tfs_z", 'f'},
1119-
{"tfs-z", 'f'},
1120-
{"tfs", 'f'},
1121-
{"temp", 't'},
1122-
{"temperature",'t'}
1128+
std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
1129+
{"top_k", llama_sampler_type::TOP_K},
1130+
{"top-k", llama_sampler_type::TOP_K},
1131+
{"top_p", llama_sampler_type::TOP_P},
1132+
{"top-p", llama_sampler_type::TOP_P},
1133+
{"nucleus", llama_sampler_type::TOP_P},
1134+
{"typical_p", llama_sampler_type::TYPICAL_P},
1135+
{"typical-p", llama_sampler_type::TYPICAL_P},
1136+
{"typical", llama_sampler_type::TYPICAL_P},
1137+
{"min_p", llama_sampler_type::MIN_P},
1138+
{"min-p", llama_sampler_type::MIN_P},
1139+
{"tfs_z", llama_sampler_type::TFS_Z},
1140+
{"tfs-z", llama_sampler_type::TFS_Z},
1141+
{"tfs", llama_sampler_type::TFS_Z},
1142+
{"temp", llama_sampler_type::TEMP},
1143+
{"temperature", llama_sampler_type::TEMP}
1144+
};
1145+
1146+
std::vector<llama_sampler_type> sampler_types;
1147+
sampler_types.reserve(names.size());
1148+
for (const auto& name : names) {
1149+
const auto sampler_item = sampler_name_map.find(name);
1150+
if (sampler_item != sampler_name_map.end()) {
1151+
sampler_types.push_back(sampler_item->second);
1152+
}
1153+
}
1154+
return sampler_types;
1155+
}
1156+
1157+
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string) {
1158+
std::unordered_map<char, llama_sampler_type> sampler_name_map {
1159+
{'k', llama_sampler_type::TOP_K},
1160+
{'p', llama_sampler_type::TOP_P},
1161+
{'y', llama_sampler_type::TYPICAL_P},
1162+
{'m', llama_sampler_type::MIN_P},
1163+
{'f', llama_sampler_type::TFS_Z},
1164+
{'t', llama_sampler_type::TEMP}
11231165
};
1124-
// expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
1125-
size_t separator = input.find(';');
1126-
while (separator != input.npos) {
1127-
std::string name = input.substr(0,separator);
1128-
input = input.substr(separator+1);
1129-
separator = input.find(';');
1130-
1131-
if (samplers_symbols.find(name) != samplers_symbols.end()) {
1132-
output += samplers_symbols[name];
1166+
1167+
std::vector<llama_sampler_type> sampler_types;
1168+
sampler_types.reserve(names_string.size());
1169+
for (const auto & c : names_string) {
1170+
const auto sampler_item = sampler_name_map.find(c);
1171+
if (sampler_item != sampler_name_map.end()) {
1172+
sampler_types.push_back(sampler_item->second);
11331173
}
11341174
}
1135-
if (samplers_symbols.find(input) != samplers_symbols.end()) {
1136-
output += samplers_symbols[input];
1175+
return sampler_types;
1176+
}
1177+
1178+
std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
1179+
switch (sampler_type) {
1180+
case llama_sampler_type::TOP_K: return "top_k";
1181+
case llama_sampler_type::TFS_Z: return "tfs_z";
1182+
case llama_sampler_type::TYPICAL_P: return "typical_p";
1183+
case llama_sampler_type::TOP_P: return "top_p";
1184+
case llama_sampler_type::MIN_P: return "min_p";
1185+
case llama_sampler_type::TEMP: return "temp";
1186+
default : return "";
11371187
}
1138-
return output;
11391188
}
11401189

11411190
//

common/common.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,13 @@ std::string gpt_random_prompt(std::mt19937 & rng);
162162
void process_escapes(std::string& input);
163163

164164
//
165-
// String parsing
165+
// String utils
166166
//
167167

168-
std::string parse_samplers_input(std::string input);
168+
std::vector<llama_sampler_type> sampler_types_from_names(const std::vector<std::string> & names);
169+
std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & names_string);
170+
std::vector<std::string> string_split(std::string input, char separator);
171+
std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
169172

170173
//
171174
// Model utils

common/sampling.cpp

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
103103
std::string llama_sampling_order_print(const llama_sampling_params & params) {
104104
std::string result = "CFG -> Penalties ";
105105
if (params.mirostat == 0) {
106-
for (auto s : params.samplers_sequence) {
107-
switch (s) {
108-
case 'k': result += "-> top_k "; break;
109-
case 'f': result += "-> tfs_z "; break;
110-
case 'y': result += "-> typical_p "; break;
111-
case 'p': result += "-> top_p "; break;
112-
case 'm': result += "-> min_p "; break;
113-
case 't': result += "-> temp "; break;
114-
default : break;
106+
for (auto sampler_type : params.samplers_sequence) {
107+
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
108+
if (!sampler_type_name.empty()) {
109+
result += "-> " + sampler_type_name + " ";
115110
}
116111
}
117112
} else {
@@ -135,16 +130,16 @@ static void sampler_queue(
135130
const float min_p = params.min_p;
136131
const float tfs_z = params.tfs_z;
137132
const float typical_p = params.typical_p;
138-
const std::string & samplers_sequence = params.samplers_sequence;
139-
140-
for (auto s : samplers_sequence) {
141-
switch (s){
142-
case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
143-
case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
144-
case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
145-
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
146-
case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
147-
case 't':
133+
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
134+
135+
for (auto sampler_type : samplers_sequence) {
136+
switch (sampler_type) {
137+
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
138+
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
139+
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
140+
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
141+
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
142+
case llama_sampler_type::TEMP:
148143
if (dynatemp_range > 0) {
149144
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
150145
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);

common/sampling.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
#include <vector>
99
#include <unordered_map>
1010

11+
// sampler types
12+
enum class llama_sampler_type : char {
13+
TOP_K = 'k',
14+
TOP_P = 'p',
15+
MIN_P = 'm',
16+
TFS_Z = 'f',
17+
TYPICAL_P = 'y',
18+
TEMP = 't'
19+
};
20+
1121
// sampling parameters
1222
typedef struct llama_sampling_params {
1323
int32_t n_prev = 64; // number of previous tokens to remember
@@ -28,7 +38,15 @@ typedef struct llama_sampling_params {
2838
float mirostat_tau = 5.00f; // target entropy
2939
float mirostat_eta = 0.10f; // learning rate
3040
bool penalize_nl = true; // consider newlines as a repeatable token
31-
std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
41+
42+
std::vector<llama_sampler_type> samplers_sequence = {
43+
llama_sampler_type::TOP_K,
44+
llama_sampler_type::TFS_Z,
45+
llama_sampler_type::TYPICAL_P,
46+
llama_sampler_type::TOP_P,
47+
llama_sampler_type::MIN_P,
48+
llama_sampler_type::TEMP
49+
};
3250

3351
std::string grammar; // optional BNF-like grammar to constrain sampling
3452

0 commit comments

Comments (0)