Skip to content

Commit a709077

Browse files
committed
More tool calling work
1 parent a05b1f8 commit a709077

File tree

15 files changed

+171
-99
lines changed

15 files changed

+171
-99
lines changed

pkg/anthropic/messages.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (anthropic *Client) Messages(ctx context.Context, context llm.Context, opts
7878
req, err := client.NewJSONRequest(reqMessages{
7979
Model: context.(*session).model.Name(),
8080
Messages: context.(*session).seq,
81-
Tools: opt.Tools(),
81+
Tools: opt.tools(anthropic),
8282
opt: *opt,
8383
})
8484
if err != nil {

pkg/anthropic/messages_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ package anthropic_test
22

33
import (
44
"context"
5+
"encoding/json"
6+
"log"
57
"os"
68
"testing"
79

810
// Packages
911
opts "github.com/mutablelogic/go-client"
10-
"github.com/mutablelogic/go-llm"
1112
anthropic "github.com/mutablelogic/go-llm/pkg/anthropic"
1213
"github.com/mutablelogic/go-llm/pkg/tool"
1314
assert "github.com/stretchr/testify/assert"
@@ -153,7 +154,7 @@ func Test_messages_005(t *testing.T) {
153154
// TOOLS
154155

155156
type weather struct {
156-
Location string `name:"location" help:"The location to get the weather for" required:"true"`
157+
Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"`
157158
}
158159

159160
func (*weather) Name() string {
@@ -164,6 +165,15 @@ func (*weather) Description() string {
164165
return "Get the weather in a location"
165166
}
166167

167-
func (*weather) Run(ctx context.Context) (any, error) {
168-
return nil, llm.ErrNotImplemented
168+
func (weather *weather) String() string {
169+
data, err := json.MarshalIndent(weather, "", " ")
170+
if err != nil {
171+
return err.Error()
172+
}
173+
return string(data)
174+
}
175+
176+
func (weather *weather) Run(ctx context.Context) (any, error) {
177+
log.Println("weather_in_location", "=>", weather)
178+
return "very sunny today", nil
169179
}

pkg/anthropic/opt.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ func apply(opts ...llm.Opt) (*opt, error) {
4444
}
4545

4646
////////////////////////////////////////////////////////////////////////////////
47-
// PUBLIC METHODS
47+
// PRIVATE METHODS
4848

49-
func (o *opt) Tools() []llm.Tool {
49+
func (o *opt) tools(agent llm.Agent) []llm.Tool {
5050
if o.toolkit == nil {
5151
return nil
5252
} else {
53-
return o.toolkit.Tools()
53+
return o.toolkit.Tools(agent)
5454
}
5555
}
5656

pkg/anthropic/session_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ func Test_session_002(t *testing.T) {
7777
session := model.Context(anthropic.WithToolKit(toolkit))
7878
assert.NotNil(session)
7979

80-
err = session.FromUser(context.TODO(), "What is today's weather?")
80+
err = session.FromUser(context.TODO(), "What is today's weather, in Berlin?")
8181
if !assert.NoError(err) {
8282
t.FailNow()
8383
}
8484

85-
toolcalls := session.ToolCalls()
86-
assert.NotEmpty(toolcalls)
87-
t.Log("TOOLCALLS", toolcalls)
85+
err := toolkit.Run(context.TODO(), session.ToolCalls())
86+
if !assert.NoError(err) {
87+
t.FailNow()
88+
}
8889
})
8990
}

pkg/ollama/chat.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func (r Response) String() string {
5050
type reqChat struct {
5151
Model string `json:"model"`
5252
Messages []*MessageMeta `json:"messages"`
53-
Tools []*Tool `json:"tools,omitempty"`
53+
Tools []ToolFunction `json:"tools,omitempty"`
5454
Format string `json:"format,omitempty"`
5555
Options map[string]interface{} `json:"options,omitempty"`
5656
Stream bool `json:"stream"`
@@ -68,7 +68,7 @@ func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.
6868
req, err := client.NewJSONRequest(reqChat{
6969
Model: prompt.(*session).model.Name(),
7070
Messages: prompt.(*session).seq,
71-
Tools: opt.tools,
71+
Tools: opt.tools(ollama),
7272
Format: opt.format,
7373
Options: opt.options,
7474
Stream: opt.stream,

pkg/ollama/chat_test.go

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package ollama_test
22

33
import (
44
"context"
5+
"encoding/json"
6+
"log"
57
"os"
68
"testing"
79

810
// Packages
911
opts "github.com/mutablelogic/go-client"
1012
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
13+
"github.com/mutablelogic/go-llm/pkg/tool"
1114
assert "github.com/stretchr/testify/assert"
1215
)
1316

@@ -60,13 +63,17 @@ func Test_chat_002(t *testing.T) {
6063
t.FailNow()
6164
}
6265

66+
// Make a toolkit
67+
toolkit := tool.NewToolKit()
68+
if err := toolkit.Register(new(weather)); err != nil {
69+
t.FailNow()
70+
}
71+
6372
t.Run("Tools", func(t *testing.T) {
6473
assert := assert.New(t)
6574
response, err := client.Chat(context.TODO(),
6675
model.UserPrompt("what is the weather in berlin?"),
67-
ollama.WithTool(ollama.MustTool("get_weather", "Return weather conditions in a location", struct {
68-
Location string `help:"Location to get weather for" required:""`
69-
}{})),
76+
ollama.WithToolKit(toolkit),
7077
)
7178
if !assert.NoError(err) {
7279
t.FailNow()
@@ -108,3 +115,31 @@ func Test_chat_003(t *testing.T) {
108115
t.Log(response)
109116
})
110117
}
118+
119+
////////////////////////////////////////////////////////////////////////////////
120+
// TOOLS
121+
122+
type weather struct {
123+
Location string `json:"location" name:"location" help:"The location to get the weather for" required:"true"`
124+
}
125+
126+
func (*weather) Name() string {
127+
return "weather_in_location"
128+
}
129+
130+
func (*weather) Description() string {
131+
return "Get the weather in a location"
132+
}
133+
134+
func (weather *weather) String() string {
135+
data, err := json.MarshalIndent(weather, "", " ")
136+
if err != nil {
137+
return err.Error()
138+
}
139+
return string(data)
140+
}
141+
142+
func (weather *weather) Run(ctx context.Context) (any, error) {
143+
log.Println("weather_in_location", "=>", weather)
144+
return "very sunny today", nil
145+
}

pkg/ollama/message.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ollama
22

3+
import llm "github.com/mutablelogic/go-llm"
4+
35
///////////////////////////////////////////////////////////////////////////////
46
// TYPES
57

@@ -12,5 +14,21 @@ type MessageMeta struct {
1214
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // Tool calls from the assistant
1315
}
1416

17+
type ToolCall struct {
18+
Function ToolCallFunction `json:"function"`
19+
}
20+
21+
type ToolCallFunction struct {
22+
Index int `json:"index,omitempty"`
23+
Name string `json:"name"`
24+
Arguments map[string]any `json:"arguments"`
25+
}
26+
1527
// Data represents the raw binary data of an image file.
1628
type Data []byte
29+
30+
// ToolFunction
31+
type ToolFunction struct {
32+
Type string `json:"type"` // function
33+
Function llm.Tool `json:"function"`
34+
}

pkg/ollama/opt.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
// Packages
88
llm "github.com/mutablelogic/go-llm"
9+
"github.com/mutablelogic/go-llm/pkg/tool"
910
)
1011

1112
////////////////////////////////////////////////////////////////////////////////
@@ -20,8 +21,8 @@ type opt struct {
2021
truncate *bool
2122
keepalive *time.Duration
2223
options map[string]any
23-
tools []*Tool
2424
data []Data
25+
toolkit *tool.ToolKit // Toolkit for tools
2526
}
2627

2728
////////////////////////////////////////////////////////////////////////////////
@@ -38,6 +39,20 @@ func apply(opts ...llm.Opt) (*opt, error) {
3839
return o, nil
3940
}
4041

42+
////////////////////////////////////////////////////////////////////////////////
43+
// PRIVATE METHODS
44+
45+
func (o *opt) tools(agent llm.Agent) []ToolFunction {
46+
if o.toolkit == nil {
47+
return nil
48+
}
49+
var result []ToolFunction
50+
for _, t := range o.toolkit.Tools(agent) {
51+
result = append(result, ToolFunction{Type: "function", Function: t})
52+
}
53+
return result
54+
}
55+
4156
////////////////////////////////////////////////////////////////////////////////
4257
// OPTIONS
4358

@@ -85,24 +100,17 @@ func WithStream(fn func(*Response)) llm.Opt {
85100
if fn == nil {
86101
return llm.ErrBadParameter.With("callback required")
87102
}
88-
if len(o.(*opt).tools) > 0 {
89-
return llm.ErrBadParameter.With("streaming not supported with tools")
90-
}
91103
o.(*opt).stream = true
92104
o.(*opt).chatcallback = fn
93105
return nil
94106
}
95107
}
96108

97-
// Chat: Append a tool to the request.
98-
func WithTool(v *Tool) llm.Opt {
109+
// Chat: Append a toolkit to the request
110+
func WithToolKit(v *tool.ToolKit) llm.Opt {
99111
return func(o any) error {
100-
// We can't use streaming when tools are included
101-
if o.(*opt).stream {
102-
return llm.ErrBadParameter.With("tools not supported with streaming")
103-
}
104112
if v != nil {
105-
o.(*opt).tools = append(o.(*opt).tools, v)
113+
o.(*opt).toolkit = v
106114
}
107115
return nil
108116
}

pkg/ollama/session.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package ollama
33
import (
44
"context"
55
"encoding/json"
6+
"fmt"
67

78
// Packages
89
llm "github.com/mutablelogic/go-llm"
10+
"github.com/mutablelogic/go-llm/pkg/tool"
911
)
1012

1113
///////////////////////////////////////////////////////////////////////////////
@@ -144,7 +146,7 @@ func (session *session) ToolCalls() []llm.ToolCall {
144146
// Gather tool calls
145147
var result []llm.ToolCall
146148
for _, call := range meta.ToolCalls {
147-
result = append(result, NewToolCall(call))
149+
result = append(result, tool.NewCall(fmt.Sprint(call.Function.Index), call.Function.Name, call.Function.Arguments))
148150
}
149151
return result
150152
}

pkg/ollama/session_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
// Packages
99
opts "github.com/mutablelogic/go-client"
1010
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
11+
"github.com/mutablelogic/go-llm/pkg/tool"
1112
assert "github.com/stretchr/testify/assert"
1213
)
1314

@@ -66,16 +67,17 @@ func Test_session_002(t *testing.T) {
6667
t.FailNow()
6768
}
6869

70+
// Make a toolkit
71+
toolkit := tool.NewToolKit()
72+
if err := toolkit.Register(new(weather)); err != nil {
73+
t.FailNow()
74+
}
75+
6976
// Session with a tool call
7077
t.Run("toolcall", func(t *testing.T) {
7178
assert := assert.New(t)
7279

73-
tool, err := ollama.NewTool("get_weather", "Return the current weather", nil)
74-
if !assert.NoError(err) {
75-
t.FailNow()
76-
}
77-
78-
session := model.Context(ollama.WithTool(tool))
80+
session := model.Context(ollama.WithToolKit(toolkit))
7981
assert.NotNil(session)
8082

8183
err = session.FromUser(context.TODO(), "What is today's weather?")

0 commit comments

Comments
 (0)