@@ -340,13 +340,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
340
340
invalid_param = true ;
341
341
break ;
342
342
}
343
- sparams.samplers_sequence = parse_samplers_input (argv[i]);
343
+ const auto sampler_names = string_split (argv[i], ' ;' );
344
+ sparams.samplers_sequence = sampler_types_from_names (sampler_names);
344
345
} else if (arg == " --sampling-seq" ) {
345
346
if (++i >= argc) {
346
347
invalid_param = true ;
347
348
break ;
348
349
}
349
- sparams.samplers_sequence = argv[i];
350
+ sparams.samplers_sequence = sampler_types_from_chars ( argv[i]) ;
350
351
} else if (arg == " --top-p" ) {
351
352
if (++i >= argc) {
352
353
invalid_param = true ;
@@ -906,6 +907,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
906
907
void gpt_print_usage (int /* argc*/ , char ** argv, const gpt_params & params) {
907
908
const llama_sampling_params & sparams = params.sparams ;
908
909
910
+ std::string sampler_type_chars;
911
+ std::string sampler_type_names;
912
+ for (const auto sampler_type : sparams.samplers_sequence ) {
913
+ sampler_type_chars += static_cast <char >(sampler_type);
914
+ sampler_type_names += sampler_type_to_name_string (sampler_type) + " ;" ;
915
+ }
916
+ sampler_type_names.pop_back ();
917
+
909
918
printf (" \n " );
910
919
printf (" usage: %s [options]\n " , argv[0 ]);
911
920
printf (" \n " );
@@ -947,8 +956,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
947
956
printf (" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n " , params.n_predict );
948
957
printf (" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n " , params.n_ctx );
949
958
printf (" -b N, --batch-size N batch size for prompt processing (default: %d)\n " , params.n_batch );
950
- printf (" --samplers samplers that will be used for generation in the order, separated by \' ;\' , for example: \" top_k;tfs;typical;top_p;min_p;temp \"\n " );
951
- printf (" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n " , sparams. samplers_sequence .c_str ());
959
+ printf (" --samplers samplers that will be used for generation in the order, separated by \' ;\' (default: %s) \n " , sampler_type_names. c_str () );
960
+ printf (" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n " , sampler_type_chars .c_str ());
952
961
printf (" --top-k N top-k sampling (default: %d, 0 = disabled)\n " , sparams.top_k );
953
962
printf (" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n " , (double )sparams.top_p );
954
963
printf (" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n " , (double )sparams.min_p );
@@ -1097,45 +1106,85 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
1097
1106
}
1098
1107
1099
1108
//
1100
- // String parsing
1109
+ // String utils
1101
1110
//
1102
1111
1103
- std::string parse_samplers_input (std::string input) {
1104
- std::string output = " " ;
1112
// Splits `input` at every occurrence of `separator` and returns the pieces in order.
//
// Empty pieces are preserved: leading, trailing or consecutive separators produce
// empty strings, and an input that contains no separator at all (including the
// empty string) produces a single-element vector holding the whole input.
std::vector<std::string> string_split(std::string input, char separator) {
    std::vector<std::string> parts;

    // walk the string with a moving offset rather than re-copying the remaining
    // tail after every match, so each character is copied exactly once
    size_t part_start    = 0;
    size_t separator_pos = input.find(separator, part_start);
    while (separator_pos != std::string::npos) {
        // construct the piece [part_start, separator_pos) directly inside the vector
        parts.emplace_back(input, part_start, separator_pos - part_start);
        part_start    = separator_pos + 1;
        separator_pos = input.find(separator, part_start);
    }

    // whatever follows the last separator (possibly nothing) is the final piece
    parts.emplace_back(input, part_start, std::string::npos);
    return parts;
}
1124
+
1125
+ std::vector<llama_sampler_type> sampler_types_from_names (const std::vector<std::string> & names) {
1105
1126
// since samplers names are written multiple ways
1106
1127
// make it ready for both system names and input names
1107
- std::unordered_map<std::string, char > samplers_symbols {
1108
- {" top_k" , ' k' },
1109
- {" top-k" , ' k' },
1110
- {" top_p" , ' p' },
1111
- {" top-p" , ' p' },
1112
- {" nucleus" , ' p' },
1113
- {" typical_p" , ' y' },
1114
- {" typical-p" , ' y' },
1115
- {" typical" , ' y' },
1116
- {" min_p" , ' m' },
1117
- {" min-p" , ' m' },
1118
- {" tfs_z" , ' f' },
1119
- {" tfs-z" , ' f' },
1120
- {" tfs" , ' f' },
1121
- {" temp" , ' t' },
1122
- {" temperature" ,' t' }
1128
+ std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
1129
+ {" top_k" , llama_sampler_type::TOP_K},
1130
+ {" top-k" , llama_sampler_type::TOP_K},
1131
+ {" top_p" , llama_sampler_type::TOP_P},
1132
+ {" top-p" , llama_sampler_type::TOP_P},
1133
+ {" nucleus" , llama_sampler_type::TOP_P},
1134
+ {" typical_p" , llama_sampler_type::TYPICAL_P},
1135
+ {" typical-p" , llama_sampler_type::TYPICAL_P},
1136
+ {" typical" , llama_sampler_type::TYPICAL_P},
1137
+ {" min_p" , llama_sampler_type::MIN_P},
1138
+ {" min-p" , llama_sampler_type::MIN_P},
1139
+ {" tfs_z" , llama_sampler_type::TFS_Z},
1140
+ {" tfs-z" , llama_sampler_type::TFS_Z},
1141
+ {" tfs" , llama_sampler_type::TFS_Z},
1142
+ {" temp" , llama_sampler_type::TEMP},
1143
+ {" temperature" , llama_sampler_type::TEMP}
1144
+ };
1145
+
1146
+ std::vector<llama_sampler_type> sampler_types;
1147
+ sampler_types.reserve (names.size ());
1148
+ for (const auto & name : names) {
1149
+ const auto sampler_item = sampler_name_map.find (name);
1150
+ if (sampler_item != sampler_name_map.end ()) {
1151
+ sampler_types.push_back (sampler_item->second );
1152
+ }
1153
+ }
1154
+ return sampler_types;
1155
+ }
1156
+
1157
+ std::vector<llama_sampler_type> sampler_types_from_chars (const std::string & names_string) {
1158
+ std::unordered_map<char , llama_sampler_type> sampler_name_map {
1159
+ {' k' , llama_sampler_type::TOP_K},
1160
+ {' p' , llama_sampler_type::TOP_P},
1161
+ {' y' , llama_sampler_type::TYPICAL_P},
1162
+ {' m' , llama_sampler_type::MIN_P},
1163
+ {' f' , llama_sampler_type::TFS_Z},
1164
+ {' t' , llama_sampler_type::TEMP}
1123
1165
};
1124
- // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
1125
- size_t separator = input.find (' ;' );
1126
- while (separator != input.npos ) {
1127
- std::string name = input.substr (0 ,separator);
1128
- input = input.substr (separator+1 );
1129
- separator = input.find (' ;' );
1130
-
1131
- if (samplers_symbols.find (name) != samplers_symbols.end ()) {
1132
- output += samplers_symbols[name];
1166
+
1167
+ std::vector<llama_sampler_type> sampler_types;
1168
+ sampler_types.reserve (names_string.size ());
1169
+ for (const auto & c : names_string) {
1170
+ const auto sampler_item = sampler_name_map.find (c);
1171
+ if (sampler_item != sampler_name_map.end ()) {
1172
+ sampler_types.push_back (sampler_item->second );
1133
1173
}
1134
1174
}
1135
- if (samplers_symbols.find (input) != samplers_symbols.end ()) {
1136
- output += samplers_symbols[input];
1175
+ return sampler_types;
1176
+ }
1177
+
1178
+ std::string sampler_type_to_name_string (llama_sampler_type sampler_type) {
1179
+ switch (sampler_type) {
1180
+ case llama_sampler_type::TOP_K: return " top_k" ;
1181
+ case llama_sampler_type::TFS_Z: return " tfs_z" ;
1182
+ case llama_sampler_type::TYPICAL_P: return " typical_p" ;
1183
+ case llama_sampler_type::TOP_P: return " top_p" ;
1184
+ case llama_sampler_type::MIN_P: return " min_p" ;
1185
+ case llama_sampler_type::TEMP: return " temp" ;
1186
+ default : return " " ;
1137
1187
}
1138
- return output;
1139
1188
}
1140
1189
1141
1190
//
0 commit comments