Commit 1b70cde
fix(llama): Overhaul use of sampling module for llama.cpp changes
The changes here mirror the big llama.cpp sampling refactor in ggml-org/llama.cpp#9294, which split the sampling functionality into a base interface (llama_sampler) and a generation implementation (gpt_sampler). Since the sampling.h/sampling.cpp code uses C++ STL headers, the sampling_ext.[h|cpp] wrapper is maintained to give Go access to a pure-C interface.

Branch: IBMGraniteArchitectureSupport

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Parent: bcdae7c

File tree: 4 files changed, 33 insertions(+), 37 deletions(-)

  llama/llama.go
  llama/runner/runner.go
  llama/sampling_ext.cpp
  llama/sampling_ext.h
llama/llama.go (8 additions, 12 deletions)

@@ -445,7 +445,7 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
-    c *C.struct_llama_sampling_context
+    c *C.struct_llama_sampler
 }
 
 type SamplingParams struct {
@@ -467,7 +467,8 @@ type SamplingParams struct {
     Grammar string
 }
 
-func NewSamplingContext(params SamplingParams) *SamplingContext {
+func NewSamplingContext(model *Model, params SamplingParams) *SamplingContext {
+
     var cparams C.struct_llama_sampling_cparams
     cparams.top_k = C.int32_t(params.TopK)
     cparams.top_p = C.float(params.TopP)
@@ -489,7 +490,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
     defer C.free(unsafe.Pointer(grammar))
 
     cparams.grammar = grammar
-    context := &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
+    context := &SamplingContext{c: C.llama_sampling_cinit(model.c, &cparams)}
     runtime.SetFinalizer(context, func(s *SamplingContext) { C.llama_sampling_cfree(s.c) })
 
     return context
@@ -499,15 +500,10 @@ func (s *SamplingContext) Reset() {
     C.llama_sampling_creset(s.c)
 }
 
-func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
-    // TODO (jmorganca): handle nil for all args
-    if ctxConfig == nil {
-        return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
-    }
-
-    return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
+func (s *SamplingContext) Sample(ctxMain *Context, idx int) int {
+    return int(C.llama_sampling_csample(s.c, ctxMain.c, C.int(idx)))
 }
 
-func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
-    C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
+func (s *SamplingContext) Accept(id int, applyGrammar bool) {
+    C.llama_sampling_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }

llama/runner/runner.go (4 additions, 4 deletions)

@@ -126,10 +126,10 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 
     var sc *llama.SamplingContext
     if params.samplingParams != nil {
-        sc = llama.NewSamplingContext(*params.samplingParams)
+        sc = llama.NewSamplingContext(s.model, *params.samplingParams)
         for _, input := range inputs {
             if input.embed == nil {
-                sc.Accept(s.lc, input.token, false)
+                sc.Accept(input.token, false)
             }
         }
     }
@@ -429,8 +429,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
         }
 
         // sample a token
-        token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-        seq.samplingCtx.Accept(s.lc, token, true)
+        token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
+        seq.samplingCtx.Accept(token, true)
         piece := s.model.TokenToPiece(token)
 
         seq.numPredicted++

llama/sampling_ext.cpp (14 additions, 14 deletions)

@@ -2,14 +2,15 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 
-struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params)
+struct llama_sampler *llama_sampling_cinit(
+    const struct llama_model *model, struct llama_sampling_cparams *params)
 {
-    llama_sampling_params sparams;
+    gpt_sampler_params sparams;
     sparams.top_k = params->top_k;
     sparams.top_p = params->top_p;
     sparams.min_p = params->min_p;
     sparams.tfs_z = params->tfs_z;
-    sparams.typical_p = params->typical_p;
+    sparams.typ_p = params->typical_p;
     sparams.temp = params->temp;
     sparams.penalty_last_n = params->penalty_last_n;
     sparams.penalty_repeat = params->penalty_repeat;
@@ -21,33 +22,32 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam
     sparams.penalize_nl = params->penalize_nl;
     sparams.seed = params->seed;
     sparams.grammar = params->grammar;
-    return llama_sampling_init(sparams);
+    return (llama_sampler*)gpt_sampler_init(model, sparams);
 }
 
-void llama_sampling_cfree(struct llama_sampling_context *ctx)
+void llama_sampling_cfree(struct llama_sampler *sampler)
 {
-    llama_sampling_free(ctx);
+    gpt_sampler_free((gpt_sampler*)sampler);
 }
 
-void llama_sampling_creset(struct llama_sampling_context *ctx)
+void llama_sampling_creset(struct llama_sampler *sampler)
 {
-    llama_sampling_reset(ctx);
+    gpt_sampler_reset((gpt_sampler*)sampler);
 }
 
 llama_token llama_sampling_csample(
-    struct llama_sampling_context *ctx_sampling,
+    struct llama_sampler *sampler,
     struct llama_context *ctx_main,
-    struct llama_context *ctx_cfg,
     int idx)
 {
-    return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
+    // TODO (ggoodhart): Do we need to support grammar_first?
+    return gpt_sampler_sample((gpt_sampler*)sampler, ctx_main, idx);
 }
 
 void llama_sampling_caccept(
-    struct llama_sampling_context *ctx_sampling,
-    struct llama_context *ctx_main,
+    struct llama_sampler *sampler,
     llama_token id,
     bool apply_grammar)
 {
-    llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
+    gpt_sampler_accept((gpt_sampler*)sampler, id, apply_grammar);
 }
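A design note on the shim, judging from this diff: the (llama_sampler*)/(gpt_sampler*) casts are purely for opacity. gpt_sampler is a C++ type, not an actual llama_sampler, so the pointer is handed across the C boundary only as an opaque token and is cast back to gpt_sampler* before any use; no llama_sampler API is ever invoked on it directly.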

llama/sampling_ext.h (7 additions, 7 deletions)

@@ -29,19 +29,19 @@ extern "C"
     char *grammar;
 };
 
-struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params);
-void llama_sampling_cfree(struct llama_sampling_context *ctx);
-void llama_sampling_creset(struct llama_sampling_context *ctx);
+struct llama_sampler *llama_sampling_cinit(
+    const struct llama_model *model,
+    struct llama_sampling_cparams *params);
+void llama_sampling_cfree(struct llama_sampler *sampler);
+void llama_sampling_creset(struct llama_sampler *sampler);
 
 llama_token llama_sampling_csample(
-    struct llama_sampling_context *ctx_sampling,
+    struct llama_sampler *sampler,
     struct llama_context *ctx_main,
-    struct llama_context *ctx_cfg,
     int idx);
 
 void llama_sampling_caccept(
-    struct llama_sampling_context *ctx_sampling,
-    struct llama_context *ctx_main,
+    struct llama_sampler *sampler,
     llama_token id,
     bool apply_grammar);
 
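This header is what keeps the wrapper consumable from Go: cgo can only bind plain C symbols, so the C++ gpt_sampler stays hidden behind the opaque llama_sampler pointer declared here. A rough sketch of the binding pattern follows; the real preamble and build directives in llama.go are more involved, and samplerHandle is a hypothetical name (the real wrapper is the SamplingContext shown above).

package llama

/*
#include "sampling_ext.h"
*/
import "C"

// samplerHandle holds the opaque C handle; the C++ gpt_sampler type
// never crosses into Go.
type samplerHandle struct {
    c *C.struct_llama_sampler
}

func (h *samplerHandle) reset() {
    // Plain C symbol from sampling_ext.h; the shim casts the handle
    // back to gpt_sampler* internally.
    C.llama_sampling_creset(h.c)
}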