
Commit ed2b92e ("Updated")
1 parent b84adbc

8 files changed: +140 additions, -36 deletions


README.md
Lines changed: 4 additions & 4 deletions

@@ -202,10 +202,10 @@ The options are as follows:
 | `llm.WithToolKit(llm.ToolKit)` | Cannot be combined with streaming | Yes | Yes | - | The set of tools to use. |
 | `llm.WithStopSequence(string, string, ...)` | Yes | Yes | Yes | - | Stop generation if one of these tokens is detected. |
 | `llm.WithSystemPrompt(string)` | No | Yes | Yes | - | Set the system prompt for the model. |
-| `llm.WithSeed(uint64)` | No | Yes | Yes | - | The seed to use for random sampling. If set, different calls will generate deterministic results. |
-| `llm.WithFormat(string)` | No | Yes | Use `json_format` or `text` | - | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. |
-| `mistral.WithPresencePenalty(float64)` | No | No | Yes | - | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. |
-| `mistral.WithFequencyPenalty(float64)` | No | No | Yes | - | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. |
+| `llm.WithSeed(uint64)` | Yes | Yes | Yes | - | The seed to use for random sampling. If set, different calls will generate deterministic results. |
+| `llm.WithFormat(string)` | Use `json` | Yes | Use `json_format` or `text` | - | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. |
+| `llm.WithPresencePenalty(float64)` | Yes | No | Yes | - | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. |
+| `llm.WithFequencyPenalty(float64)` | Yes | No | Yes | - | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. |
 | `mistral.WithPrediction(string)` | No | No | Yes | - | Enable users to specify expected results, optimizing response times by leveraging known or predictable content. This approach is especially effective for updating text documents or code files with minimal changes, reducing latency while maintaining high-quality results. |
 | `llm.WithSafePrompt()` | No | No | Yes | - | Whether to inject a safety prompt before all conversations. |
 | `llm.WithNumCompletions(uint64)` | No | No | Yes | - | Number of completions to return for each request. |
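A rough usage sketch of the relocated options (hypothetical: client and model are assumed to be set up as in the tests further down; note also that the table's `WithFequencyPenalty` spelling differs from the `WithFrequencyPenalty` function actually added in opt.go below):

	// Sketch only: client/model setup follows the pattern in chat_test.go.
	response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"),
		llm.WithSeed(42),              // deterministic sampling across calls
		llm.WithPresencePenalty(0.5),  // push toward more varied wording
		llm.WithFrequencyPenalty(0.5), // damp words already used often
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(response)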

opt.go
Lines changed: 20 additions & 0 deletions

@@ -271,6 +271,26 @@ func WithTopK(v uint64) Opt {
 	}
 }
 
+func WithPresencePenalty(v float64) Opt {
+	return func(o *Opts) error {
+		if v < -2 || v > 2 {
+			return ErrBadParameter.With("presence_penalty")
+		}
+		o.Set("presence_penalty", v)
+		return nil
+	}
+}
+
+func WithFrequencyPenalty(v float64) Opt {
+	return func(o *Opts) error {
+		if v < -2 || v > 2 {
+			return ErrBadParameter.With("frequency_penalty")
+		}
+		o.Set("frequency_penalty", v)
+		return nil
+	}
+}
+
 // The maximum number of tokens to generate in the completion.
 func WithMaxTokens(v uint64) Opt {
 	return func(o *Opts) error {
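Both new options follow the package's functional-option shape: an Opt is a closure over *Opts that validates its argument before writing it. A minimal sketch of the failure path (hypothetical: how the *Opts value is obtained is assumed and not part of this diff):

	// Out of the accepted [-2, 2] range, so applying the option
	// should return ErrBadParameter for "presence_penalty".
	opt := llm.WithPresencePenalty(3)
	if err := opt(opts); err != nil { // opts is some *llm.Opts (assumed setup)
		log.Println(err)
	}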

pkg/mistral/opt.go
Lines changed: 0 additions & 20 deletions

@@ -9,26 +9,6 @@ import (
 ///////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS
 
-func WithPresencePenalty(v float64) llm.Opt {
-	return func(o *llm.Opts) error {
-		if v < -2 || v > 2 {
-			return llm.ErrBadParameter.With("presence_penalty")
-		}
-		o.Set("presence_penalty", v)
-		return nil
-	}
-}
-
-func WithFrequencyPenalty(v float64) llm.Opt {
-	return func(o *llm.Opts) error {
-		if v < -2 || v > 2 {
-			return llm.ErrBadParameter.With("frequency_penalty")
-		}
-		o.Set("frequency_penalty", v)
-		return nil
-	}
-}
-
 func WithPrediction(v string) llm.Opt {
 	return func(o *llm.Opts) error {
 		o.Set("prediction", v)

pkg/ollama/chat.go
Lines changed: 10 additions & 6 deletions

@@ -68,9 +68,9 @@ func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (llm.Completion, error) {
 
 	// Append the system prompt at the beginning
 	messages := make([]*Message, 0, len(context.(*session).seq)+1)
-	//if system := opt.SystemPrompt(); system != "" {
-	//	messages = append(messages, systemPrompt(system))
-	//}
+	if system := opt.SystemPrompt(); system != "" {
+		messages = append(messages, systemPrompt(system))
+	}
 
 	// Always append the first message of each completion
 	for _, message := range context.(*session).seq {
@@ -92,7 +92,7 @@ func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (llm.Completion, error) {
 	}
 
 	// Response
-	var response Response
+	var response, delta Response
 	reqopts := []client.RequestOpt{
 		client.OptPath("chat"),
 	}
@@ -111,12 +111,16 @@ func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (llm.Completion, error) {
 	}
 
 	// Response
-	if err := ollama.DoWithContext(ctx, req, &response, reqopts...); err != nil {
+	if err := ollama.DoWithContext(ctx, req, &delta, reqopts...); err != nil {
 		return nil, err
 	}
 
 	// Return success
-	return &response, nil
+	if optStream(ollama, opt) {
+		return &response, nil
+	} else {
+		return &delta, nil
+	}
 }
 
 ///////////////////////////////////////////////////////////////////////////////
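Two things change here: the system prompt is re-enabled (the commented-out block becomes live code), and decoding now targets delta so that streaming and non-streaming calls can return different values. When streaming, each chunk is decoded into delta and, in code outside this hunk, presumably folded into the accumulated response that the caller finally receives; without streaming, the single decoded delta is the whole answer. The caller-side contract, mirroring chat_test.go (the callback signature is the one used there):

	response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"),
		llm.WithStream(func(chunk llm.Completion) {
			// invoked once per streamed fragment
			log.Println(chunk)
		}))
	// response is the accumulated completion, not the last fragment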

pkg/ollama/chat_test.go
Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+package ollama_test
+
+import (
+	"context"
+	"testing"
+
+	// Packages
+
+	llm "github.com/mutablelogic/go-llm"
+	ollama "github.com/mutablelogic/go-llm/pkg/ollama"
+	assert "github.com/stretchr/testify/assert"
+)
+
+func Test_chat_001(t *testing.T) {
+	// Pull the model
+	model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
+		t.Log(status)
+	}))
+	if err != nil {
+		t.FailNow()
+	}
+
+	t.Run("Temperature", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTemperature(0.5))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
+
+	t.Run("TopP", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopP(0.5))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
+	t.Run("TopK", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithTopK(50))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
+
+	t.Run("Stream", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStream(func(stream llm.Completion) {
+			t.Log(stream)
+		}))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
+
+	t.Run("Stop", func(t *testing.T) {
+		assert := assert.New(t)
+		response, err := client.Chat(context.TODO(), model.UserPrompt("why is the sky blue?"), llm.WithStopSequence("sky"))
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		t.Log(response)
+	})
+}
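Note that client is not declared in this new file, so it presumably comes from a shared fixture elsewhere in the ollama_test package. A hypothetical sketch of such a fixture, only so the test reads in isolation (the constructor name and endpoint are assumptions, not taken from this commit):

	var client *ollama.Client

	func TestMain(m *testing.M) {
		var err error
		client, err = ollama.New("http://localhost:11434") // assumed constructor and endpoint
		if err != nil {
			os.Exit(1)
		}
		os.Exit(m.Run())
	}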

pkg/ollama/embedding_test.go
Lines changed: 12 additions & 3 deletions

@@ -9,13 +9,22 @@ import (
 	assert "github.com/stretchr/testify/assert"
 )
 
-func Test_embed_001(t *testing.T) {
-	t.Run("Embedding", func(t *testing.T) {
+func Test_embeddings_001(t *testing.T) {
+	t.Run("Embedding1", func(t *testing.T) {
 		assert := assert.New(t)
 		embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"hello, world"})
 		if !assert.NoError(err) {
 			t.FailNow()
 		}
-		t.Log(embedding)
+		assert.Equal(1, len(embedding.Embeddings))
+	})
+
+	t.Run("Embedding2", func(t *testing.T) {
+		assert := assert.New(t)
+		embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"hello, world", "goodbye cruel world"})
+		if !assert.NoError(err) {
+			t.FailNow()
+		}
+		assert.Equal(2, len(embedding.Embeddings))
 	})
 }
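The renamed test now asserts the shape of the result instead of just logging it: one embedding per input string. A hedged sketch of consuming the result (the Embeddings field name comes from the assertions above; treating each entry as a slice of values is an assumption):

	embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b",
		[]string{"hello, world", "goodbye cruel world"})
	if err == nil {
		for i, vec := range embedding.Embeddings {
			log.Printf("input %d -> %d dimensions", i, len(vec)) // assumes one vector per input
		}
	}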

pkg/ollama/model.go
Lines changed: 2 additions & 0 deletions

@@ -162,6 +162,8 @@ func (ollama *Client) GetModel(ctx context.Context, name string) (llm.Model, error) {
 	var response ModelMeta
 	if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("show")); err != nil {
 		return nil, err
+	} else {
+		response.Name = name
 	}
 
 	// Return success
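The two added lines compensate for the show endpoint apparently not echoing the model name back: GetModel copies it from the request argument, so the returned metadata is no longer nameless. Illustrative effect (the Name accessor on the returned value is an assumption about the llm.Model interface):

	model, err := client.GetModel(ctx, "qwen:0.5b")
	if err == nil {
		log.Println(model.Name()) // now "qwen:0.5b" rather than empty
	}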

pkg/ollama/opt.go
Lines changed: 24 additions & 3 deletions

@@ -91,6 +91,15 @@ func optFormat(opts *llm.Opts) string {
 	return opts.GetString("format")
 }
 
+func optStopSequence(opts *llm.Opts) []string {
+	if opts.Has("stop") {
+		if stop, ok := opts.Get("stop").([]string); ok {
+			return stop
+		}
+	}
+	return nil
+}
+
 func optOptions(opts *llm.Opts) map[string]any {
 	result := make(map[string]any)
 	if o, ok := opts.Get("options").(map[string]any); ok {
@@ -101,13 +110,25 @@ func optOptions(opts *llm.Opts) map[string]any {
 
 	// copy across temperature, top_p and top_k
 	if opts.Has("temperature") {
-		result["temperature"] = opts.Get("temperature")
+		result["temperature"] = opts.Get("temperature").(float64)
 	}
 	if opts.Has("top_p") {
-		result["top_p"] = opts.Get("top_p")
+		result["top_p"] = opts.GetFloat64("top_p")
 	}
 	if opts.Has("top_k") {
-		result["top_k"] = opts.Get("top_k")
+		result["top_k"] = opts.GetUint64("top_k")
+	}
+	if opts.Has("stop") {
+		result["stop"] = opts.Get("stop").([]string)
+	}
+	if opts.Has("seed") {
+		result["seed"] = opts.GetUint64("seed")
+	}
+	if opts.Has("presence_penalty") {
+		result["presence_penalty"] = opts.GetFloat64("presence_penalty")
+	}
+	if opts.Has("frequency_penalty") {
+		result["frequency_penalty"] = opts.GetFloat64("frequency_penalty")
 	}
 
 	// Return result
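With these additions optOptions forwards seed, stop and both penalties into Ollama's per-request options map, and the values are copied with concrete types (float64, uint64, []string) rather than raw any; note that the bare type assertions for temperature and stop will panic rather than return an error if an option was stored with another type, unlike the GetFloat64/GetUint64 accessors used for the other keys, and that optStopSequence is added here without a visible call site in this hunk. A hedged sketch of the net effect (the ApplyOpts helper name is hypothetical; the key/type mapping follows the code above):

	// Hypothetical helper; only the resulting key/value mapping is from the diff.
	opts, _ := llm.ApplyOpts(
		llm.WithTemperature(0.5),
		llm.WithSeed(42),
		llm.WithStopSequence("sky"),
	)
	options := optOptions(opts) // unexported, callable only within pkg/ollama
	// options == map[string]any{"temperature": 0.5, "seed": uint64(42), "stop": []string{"sky"}}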
