Skip to content

Commit b84adbc

Browse files
committed
Updated Ollama
1 parent 813e808 commit b84adbc

File tree

11 files changed

+362
-232
lines changed

11 files changed

+362
-232
lines changed

pkg/mistral/message.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mistral
33
import (
44
"encoding/json"
55

6+
// Packages
67
"github.com/mutablelogic/go-llm"
78
"github.com/mutablelogic/go-llm/pkg/tool"
89
)
@@ -23,13 +24,11 @@ type Message struct {
2324

2425
type RoleContent struct {
2526
Role string `json:"role,omitempty"` // assistant, user, tool, system
27+
Content any `json:"content,omitempty"` // string or array of text, reference, image_url
2628
Id string `json:"tool_call_id,omitempty"` // tool call - when role is tool
2729
Name string `json:"name,omitempty"` // function name - when role is tool
28-
Content any `json:"content,omitempty"` // string or array of text, reference, image_url
2930
}
3031

31-
var _ llm.Completion = (*Message)(nil)
32-
3332
// Completion Variation
3433
type Completion struct {
3534
Index uint64 `json:"index"`
@@ -38,6 +37,8 @@ type Completion struct {
3837
Reason string `json:"finish_reason,omitempty"`
3938
}
4039

40+
var _ llm.Completion = (*Message)(nil)
41+
4142
type Content struct {
4243
Type string `json:"type,omitempty"` // text, reference, image_url
4344
*Text `json:"text,omitempty"` // text content

pkg/ollama/chat.go

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ import (
1313
///////////////////////////////////////////////////////////////////////////////
1414
// TYPES
1515

16-
// Chat Response
16+
// Chat Completion Response
1717
type Response struct {
18-
Model string `json:"model"`
19-
CreatedAt time.Time `json:"created_at"`
20-
Message MessageMeta `json:"message"`
21-
Done bool `json:"done"`
22-
Reason string `json:"done_reason,omitempty"`
18+
Model string `json:"model"`
19+
CreatedAt time.Time `json:"created_at"`
20+
Done bool `json:"done"`
21+
Reason string `json:"done_reason,omitempty"`
22+
Message `json:"message"`
2323
Metrics
2424
}
2525

@@ -33,6 +33,8 @@ type Metrics struct {
3333
EvalDuration time.Duration `json:"eval_duration,omitempty"`
3434
}
3535

36+
var _ llm.Completion = (*Response)(nil)
37+
3638
///////////////////////////////////////////////////////////////////////////////
3739
// STRINGIFY
3840

@@ -49,34 +51,36 @@ func (r Response) String() string {
4951

5052
type reqChat struct {
5153
Model string `json:"model"`
52-
Messages []*MessageMeta `json:"messages"`
53-
Tools []ToolFunction `json:"tools,omitempty"`
54+
Messages []*Message `json:"messages"`
55+
Tools []llm.Tool `json:"tools,omitempty"`
5456
Format string `json:"format,omitempty"`
5557
Options map[string]interface{} `json:"options,omitempty"`
5658
Stream bool `json:"stream"`
5759
KeepAlive *time.Duration `json:"keep_alive,omitempty"`
5860
}
5961

60-
func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.Opt) (*Response, error) {
62+
func (ollama *Client) Chat(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) {
63+
// Apply options
6164
opt, err := llm.ApplyOpts(opts...)
6265
if err != nil {
6366
return nil, err
6467
}
6568

6669
// Append the system prompt at the beginning
67-
seq := make([]*MessageMeta, 0, len(prompt.(*session).seq)+1)
68-
if system := opt.SystemPrompt(); system != "" {
69-
seq = append(seq, &MessageMeta{
70-
Role: "system",
71-
Content: opt.SystemPrompt(),
72-
})
70+
messages := make([]*Message, 0, len(context.(*session).seq)+1)
71+
//if system := opt.SystemPrompt(); system != "" {
72+
// messages = append(messages, systemPrompt(system))
73+
//}
74+
75+
// Always append the first message of each completion
76+
for _, message := range context.(*session).seq {
77+
messages = append(messages, message)
7378
}
74-
seq = append(seq, prompt.(*session).seq...)
7579

7680
// Request
7781
req, err := client.NewJSONRequest(reqChat{
78-
Model: prompt.(*session).model.Name(),
79-
Messages: seq,
82+
Model: context.(*session).model.Name(),
83+
Messages: messages,
8084
Tools: optTools(ollama, opt),
8185
Format: optFormat(opt),
8286
Options: optOptions(opt),
@@ -88,52 +92,53 @@ func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.
8892
}
8993

9094
// Response
91-
var response, delta Response
92-
if err := ollama.DoWithContext(ctx, req, &delta, client.OptPath("chat"), client.OptJsonStreamCallback(func(v any) error {
93-
if v, ok := v.(*Response); !ok || v == nil {
94-
return llm.ErrConflict.Withf("Invalid stream response: %v", v)
95-
} else {
96-
response.Model = v.Model
97-
response.CreatedAt = v.CreatedAt
98-
response.Message.Role = v.Message.Role
99-
response.Message.Content += v.Message.Content
100-
if v.Done {
101-
response.Done = v.Done
102-
response.Metrics = v.Metrics
103-
response.Reason = v.Reason
95+
var response Response
96+
reqopts := []client.RequestOpt{
97+
client.OptPath("chat"),
98+
}
99+
if optStream(ollama, opt) {
100+
reqopts = append(reqopts, client.OptJsonStreamCallback(func(v any) error {
101+
if v, ok := v.(*Response); !ok || v == nil {
102+
return llm.ErrConflict.Withf("Invalid stream response: %v", v)
103+
} else if err := streamEvent(&response, v); err != nil {
104+
return err
104105
}
105-
}
106-
107-
//Call the chat callback
108-
if optStream(ollama, opt) {
109106
if fn := opt.StreamFn(); fn != nil {
110107
fn(&response)
111108
}
112-
}
113-
return nil
114-
})); err != nil {
115-
return nil, err
109+
return nil
110+
}))
116111
}
117112

118-
// We return the delta or the response
119-
if optStream(ollama, opt) {
120-
return &response, nil
121-
} else {
122-
return &delta, nil
113+
// Response
114+
if err := ollama.DoWithContext(ctx, req, &response, reqopts...); err != nil {
115+
return nil, err
123116
}
124-
}
125117

126-
///////////////////////////////////////////////////////////////////////////////
127-
// INTERFACE - CONTEXT CONTENT
128-
129-
func (response Response) Role() string {
130-
return response.Message.Role
118+
// Return success
119+
return &response, nil
131120
}
132121

133-
func (response Response) Text() string {
134-
return response.Message.Content
135-
}
122+
///////////////////////////////////////////////////////////////////////////////
123+
// PRIVATE METHODS
136124

137-
func (response Response) ToolCalls() []llm.ToolCall {
125+
func streamEvent(response, delta *Response) error {
126+
if delta.Model != "" {
127+
response.Model = delta.Model
128+
}
129+
if !delta.CreatedAt.IsZero() {
130+
response.CreatedAt = delta.CreatedAt
131+
}
132+
if delta.Message.RoleContent.Role != "" {
133+
response.Message.RoleContent.Role = delta.Message.RoleContent.Role
134+
}
135+
if delta.Message.RoleContent.Content != "" {
136+
response.Message.RoleContent.Content += delta.Message.RoleContent.Content
137+
}
138+
if delta.Done {
139+
response.Done = delta.Done
140+
response.Metrics = delta.Metrics
141+
response.Reason = delta.Reason
142+
}
138143
return nil
139144
}
File renamed without changes.

pkg/ollama/client_test.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package ollama_test
22

33
import (
4+
"flag"
5+
"log"
46
"os"
7+
"strconv"
58
"testing"
69

710
// Packages
@@ -10,23 +13,46 @@ import (
1013
assert "github.com/stretchr/testify/assert"
1114
)
1215

13-
func Test_client_001(t *testing.T) {
14-
assert := assert.New(t)
15-
client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true))
16-
if assert.NoError(err) {
17-
assert.NotNil(client)
18-
t.Log(client)
16+
///////////////////////////////////////////////////////////////////////////////
17+
// TEST SET-UP
18+
19+
var (
20+
client *ollama.Client
21+
)
22+
23+
func TestMain(m *testing.M) {
24+
var verbose bool
25+
26+
// Verbose output
27+
flag.Parse()
28+
if f := flag.Lookup("test.v"); f != nil {
29+
if v, err := strconv.ParseBool(f.Value.String()); err == nil {
30+
verbose = v
31+
}
1932
}
33+
34+
// Endpoint
35+
endpoint_url := os.Getenv("OLLAMA_URL")
36+
if endpoint_url == "" {
37+
log.Print("OLLAMA_URL not set")
38+
os.Exit(0)
39+
}
40+
41+
// Create client
42+
var err error
43+
client, err = ollama.New(endpoint_url, opts.OptTrace(os.Stderr, verbose))
44+
if err != nil {
45+
log.Println(err)
46+
os.Exit(-1)
47+
}
48+
os.Exit(m.Run())
2049
}
2150

2251
///////////////////////////////////////////////////////////////////////////////
23-
// ENVIRONMENT
52+
// TESTS
2453

25-
func GetEndpoint(t *testing.T) string {
26-
key := os.Getenv("OLLAMA_URL")
27-
if key == "" {
28-
t.Skip("OLLAMA_URL not set, skipping tests")
29-
t.SkipNow()
30-
}
31-
return key
54+
func Test_client_001(t *testing.T) {
55+
assert := assert.New(t)
56+
assert.NotNil(client)
57+
t.Log(client)
3258
}

pkg/ollama/embedding_test.go

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,17 @@ package ollama_test
22

33
import (
44
"context"
5-
"os"
65
"testing"
76

87
// Packages
9-
opts "github.com/mutablelogic/go-client"
10-
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
8+
119
assert "github.com/stretchr/testify/assert"
1210
)
1311

1412
func Test_embed_001(t *testing.T) {
15-
client, err := ollama.New(GetEndpoint(t), opts.OptTrace(os.Stderr, true))
16-
if err != nil {
17-
t.FailNow()
18-
}
19-
2013
t.Run("Embedding", func(t *testing.T) {
2114
assert := assert.New(t)
22-
embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"world"})
15+
embedding, err := client.GenerateEmbedding(context.TODO(), "qwen:0.5b", []string{"hello, world"})
2316
if !assert.NoError(err) {
2417
t.FailNow()
2518
}

pkg/ollama/message.go

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,80 @@
11
package ollama
22

33
import (
4+
"fmt"
5+
6+
// Packages
47
llm "github.com/mutablelogic/go-llm"
8+
tool "github.com/mutablelogic/go-llm/pkg/tool"
59
)
610

711
///////////////////////////////////////////////////////////////////////////////
812
// TYPES
913

10-
// Chat Message
11-
type MessageMeta struct {
12-
Role string `json:"role"`
13-
Content string `json:"content,omitempty"`
14-
FunctionName string `json:"name,omitempty"` // Function name for a tool result
15-
Images []Data `json:"images,omitempty"` // Image attachments
16-
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Tool calls from the assistant
14+
// Message with text or object content
15+
type Message struct {
16+
RoleContent
17+
ToolCallArray `json:"tool_calls,omitempty"`
18+
}
19+
20+
type RoleContent struct {
21+
Role string `json:"role,omitempty"` // assistant, user, tool, system
22+
Content string `json:"content,omitempty"` // string or array of text, reference, image_url
23+
Images []Data `json:"images,omitempty"` // Image attachments
24+
ToolResult
1725
}
1826

27+
// A set of tool calls
28+
type ToolCallArray []ToolCall
29+
1930
type ToolCall struct {
31+
Type string `json:"type"` // function
2032
Function ToolCallFunction `json:"function"`
2133
}
2234

2335
type ToolCallFunction struct {
2436
Index int `json:"index,omitempty"`
2537
Name string `json:"name"`
26-
Arguments map[string]any `json:"arguments"`
38+
Arguments map[string]any `json:"arguments,omitempty"`
2739
}
2840

2941
// Data represents the raw binary data of an image file.
3042
type Data []byte
3143

32-
// ToolFunction
33-
type ToolFunction struct {
34-
Type string `json:"type"` // function
35-
Function llm.Tool `json:"function"`
44+
// ToolResult
45+
type ToolResult struct {
46+
Name string `json:"name,omitempty"` // function name - when role is tool
47+
}
48+
49+
///////////////////////////////////////////////////////////////////////////////
50+
// PUBLIC METHODS - MESSAGE
51+
52+
func (m Message) Num() int {
53+
return 1
54+
}
55+
56+
func (m Message) Role() string {
57+
return m.RoleContent.Role
58+
}
59+
60+
func (m Message) Text(index int) string {
61+
if index != 0 {
62+
return ""
63+
}
64+
return m.Content
65+
}
66+
67+
func (m Message) ToolCalls(index int) []llm.ToolCall {
68+
if index != 0 {
69+
return nil
70+
}
71+
72+
// Make the tool calls
73+
calls := make([]llm.ToolCall, 0, len(m.ToolCallArray))
74+
for _, call := range m.ToolCallArray {
75+
calls = append(calls, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments))
76+
}
77+
78+
// Return success
79+
return calls
3680
}

0 commit comments

Comments (0)