@@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
229
229
params.logit_bias .data ()));
230
230
231
231
if (params.mirostat == 0 ) {
232
- if (params.top_n_sigma >= 0 ) {
233
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
234
- llama_sampler_chain_add (result->chain , llama_sampler_init_temp (params.temp ));
235
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
236
- } else {
237
- for (const auto & cnstr : params.samplers ) {
238
- switch (cnstr) {
239
- case COMMON_SAMPLER_TYPE_DRY:
240
- {
241
- std::vector<const char *> c_breakers;
242
- c_breakers.reserve (params.dry_sequence_breakers .size ());
243
- for (const auto & str : params.dry_sequence_breakers ) {
244
- c_breakers.push_back (str.c_str ());
245
- }
246
-
247
- llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
232
+ for (const auto & cnstr : params.samplers ) {
233
+ switch (cnstr) {
234
+ case COMMON_SAMPLER_TYPE_DRY:
235
+ {
236
+ std::vector<const char *> c_breakers;
237
+ c_breakers.reserve (params.dry_sequence_breakers .size ());
238
+ for (const auto & str : params.dry_sequence_breakers ) {
239
+ c_breakers.push_back (str.c_str ());
248
240
}
249
- break ;
250
- case COMMON_SAMPLER_TYPE_TOP_K:
251
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
252
- break ;
253
- case COMMON_SAMPLER_TYPE_TOP_P:
254
- llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
255
- break ;
256
- case COMMON_SAMPLER_TYPE_MIN_P:
257
- llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
258
- break ;
259
- case COMMON_SAMPLER_TYPE_XTC:
260
- llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
261
- break ;
262
- case COMMON_SAMPLER_TYPE_TYPICAL_P:
263
- llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
264
- break ;
265
- case COMMON_SAMPLER_TYPE_TEMPERATURE:
266
- llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
267
- break ;
268
- case COMMON_SAMPLER_TYPE_INFILL:
269
- llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
270
- break ;
271
- case COMMON_SAMPLER_TYPE_PENALTIES:
272
- llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
273
- break ;
274
- default :
275
- GGML_ASSERT (false && " unknown sampler type" );
276
- }
241
+
242
+ llama_sampler_chain_add (result->chain , llama_sampler_init_dry (vocab, llama_model_n_ctx_train (model), params.dry_multiplier , params.dry_base , params.dry_allowed_length , params.dry_penalty_last_n , c_breakers.data (), c_breakers.size ()));
243
+ }
244
+ break ;
245
+ case COMMON_SAMPLER_TYPE_TOP_K:
246
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_k (params.top_k ));
247
+ break ;
248
+ case COMMON_SAMPLER_TYPE_TOP_P:
249
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_p (params.top_p , params.min_keep ));
250
+ break ;
251
+ case COMMON_SAMPLER_TYPE_MIN_P:
252
+ llama_sampler_chain_add (result->chain , llama_sampler_init_min_p (params.min_p , params.min_keep ));
253
+ break ;
254
+ case COMMON_SAMPLER_TYPE_XTC:
255
+ llama_sampler_chain_add (result->chain , llama_sampler_init_xtc (params.xtc_probability , params.xtc_threshold , params.min_keep , params.seed ));
256
+ break ;
257
+ case COMMON_SAMPLER_TYPE_TYPICAL_P:
258
+ llama_sampler_chain_add (result->chain , llama_sampler_init_typical (params.typ_p , params.min_keep ));
259
+ break ;
260
+ case COMMON_SAMPLER_TYPE_TEMPERATURE:
261
+ llama_sampler_chain_add (result->chain , llama_sampler_init_temp_ext (params.temp , params.dynatemp_range , params.dynatemp_exponent ));
262
+ break ;
263
+ case COMMON_SAMPLER_TYPE_INFILL:
264
+ llama_sampler_chain_add (result->chain , llama_sampler_init_infill (vocab));
265
+ break ;
266
+ case COMMON_SAMPLER_TYPE_PENALTIES:
267
+ llama_sampler_chain_add (result->chain , llama_sampler_init_penalties (params.penalty_last_n , params.penalty_repeat , params.penalty_freq , params.penalty_present ));
268
+ break ;
269
+ case COMMON_SAMPLER_TYPE_TOP_NSIGMA:
270
+ llama_sampler_chain_add (result->chain , llama_sampler_init_top_n_sigma (params.top_n_sigma ));
271
+ break ;
272
+ default :
273
+ GGML_ASSERT (false && " unknown sampler type" );
277
274
}
278
275
}
279
276
llama_sampler_chain_add (result->chain , llama_sampler_init_dist (params.seed ));
@@ -480,6 +477,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
480
477
case COMMON_SAMPLER_TYPE_XTC: return ' x' ;
481
478
case COMMON_SAMPLER_TYPE_INFILL: return ' i' ;
482
479
case COMMON_SAMPLER_TYPE_PENALTIES: return ' e' ;
480
+ case COMMON_SAMPLER_TYPE_TOP_NSIGMA: return ' s' ;
483
481
default : return ' ?' ;
484
482
}
485
483
}
@@ -495,6 +493,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
495
493
case COMMON_SAMPLER_TYPE_XTC: return " xtc" ;
496
494
case COMMON_SAMPLER_TYPE_INFILL: return " infill" ;
497
495
case COMMON_SAMPLER_TYPE_PENALTIES: return " penalties" ;
496
+ case COMMON_SAMPLER_TYPE_TOP_NSIGMA: return " top_n_sigma" ;
498
497
default : return " " ;
499
498
}
500
499
}
@@ -510,6 +509,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
510
509
{ " xtc" , COMMON_SAMPLER_TYPE_XTC },
511
510
{ " infill" , COMMON_SAMPLER_TYPE_INFILL },
512
511
{ " penalties" , COMMON_SAMPLER_TYPE_PENALTIES },
512
+ { " top_n_sigma" , COMMON_SAMPLER_TYPE_TOP_NSIGMA},
513
513
};
514
514
515
515
// since samplers names are written multiple ways
@@ -524,6 +524,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
524
524
{ " typ" , COMMON_SAMPLER_TYPE_TYPICAL_P },
525
525
{ " min-p" , COMMON_SAMPLER_TYPE_MIN_P },
526
526
{ " temp" , COMMON_SAMPLER_TYPE_TEMPERATURE },
527
+ { " top-n-sigma" , COMMON_SAMPLER_TYPE_TOP_NSIGMA},
527
528
};
528
529
529
530
std::vector<common_sampler_type> samplers;
@@ -557,6 +558,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
557
558
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
558
559
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
559
560
{ common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
561
+ { common_sampler_type_to_chr (COMMON_SAMPLER_TYPE_TOP_NSIGMA), COMMON_SAMPLER_TYPE_TOP_NSIGMA},
560
562
};
561
563
562
564
std::vector<common_sampler_type> samplers;
0 commit comments