Skip to content

Commit a05b1f8

Browse files
committed
Updated tool
1 parent 8b1151e commit a05b1f8

File tree

9 files changed

+148
-28
lines changed

9 files changed

+148
-28
lines changed

pkg/anthropic/message.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"strings"
1010

11+
// Packages
1112
llm "github.com/mutablelogic/go-llm"
1213
)
1314

pkg/anthropic/messages_test.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77

88
// Packages
99
opts "github.com/mutablelogic/go-client"
10+
"github.com/mutablelogic/go-llm"
1011
anthropic "github.com/mutablelogic/go-llm/pkg/anthropic"
12+
"github.com/mutablelogic/go-llm/pkg/tool"
1113
assert "github.com/stretchr/testify/assert"
1214
)
1315

@@ -107,14 +109,12 @@ func Test_messages_004(t *testing.T) {
107109
t.FailNow()
108110
}
109111

110-
weather, err := anthropic.NewTool("weather_in_location", "Get the weather in a location", struct {
111-
Location string `name:"location" help:"The location to get the weather for" required:"true"`
112-
}{})
113-
if !assert.NoError(err) {
112+
toolkit := tool.NewToolKit()
113+
if err := toolkit.Register(new(weather)); !assert.NoError(err) {
114114
t.FailNow()
115115
}
116116

117-
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithTool(weather))
117+
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithToolKit(toolkit))
118118
if assert.NoError(err) {
119119
t.Log(response)
120120
}
@@ -136,17 +136,34 @@ func Test_messages_005(t *testing.T) {
136136
t.FailNow()
137137
}
138138

139-
weather, err := anthropic.NewTool("weather_in_location", "Get the weather in a location", struct {
140-
Location string `name:"location" help:"The location to get the weather for" required:"true"`
141-
}{})
142-
if !assert.NoError(err) {
139+
toolkit := tool.NewToolKit()
140+
if err := toolkit.Register(new(weather)); !assert.NoError(err) {
143141
t.FailNow()
144142
}
145143

146144
response, err := client.Messages(context.TODO(), model.UserPrompt("why is the sky blue"), anthropic.WithStream(func(r *anthropic.Response) {
147145
t.Log(r)
148-
}), anthropic.WithTool(weather))
146+
}), anthropic.WithToolKit(toolkit))
149147
if assert.NoError(err) {
150148
t.Log(response)
151149
}
152150
}
151+
152+
////////////////////////////////////////////////////////////////////////////////
153+
// TOOLS
154+
155+
type weather struct {
156+
Location string `name:"location" help:"The location to get the weather for" required:"true"`
157+
}
158+
159+
func (*weather) Name() string {
160+
return "weather_in_location"
161+
}
162+
163+
func (*weather) Description() string {
164+
return "Get the weather in a location"
165+
}
166+
167+
func (*weather) Run(ctx context.Context) (any, error) {
168+
return nil, llm.ErrNotImplemented
169+
}

pkg/anthropic/session.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
// Packages
88
llm "github.com/mutablelogic/go-llm"
9+
tool "github.com/mutablelogic/go-llm/pkg/tool"
910
)
1011

1112
//////////////////////////////////////////////////////////////////
@@ -103,7 +104,7 @@ func (session *session) ToolCalls() []llm.ToolCall {
103104
var result []llm.ToolCall
104105
for _, content := range meta.Content {
105106
if content.Type == "tool_use" {
106-
result = append(result, NewToolCall(content))
107+
result = append(result, tool.NewCall(content.ContentTool.Id, content.ContentTool.Name, content.ContentTool.Input))
107108
}
108109
}
109110
return result

pkg/anthropic/session_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
// Packages
99
opts "github.com/mutablelogic/go-client"
1010
anthropic "github.com/mutablelogic/go-llm/pkg/anthropic"
11+
"github.com/mutablelogic/go-llm/pkg/tool"
1112
assert "github.com/stretchr/testify/assert"
1213
)
1314

@@ -68,12 +69,12 @@ func Test_session_002(t *testing.T) {
6869
t.Run("toolcall", func(t *testing.T) {
6970
assert := assert.New(t)
7071

71-
tool, err := anthropic.NewTool("get_weather", "Return the current weather", nil)
72-
if !assert.NoError(err) {
72+
toolkit := tool.NewToolKit()
73+
if err := toolkit.Register(new(weather)); !assert.NoError(err) {
7374
t.FailNow()
7475
}
7576

76-
session := model.Context(anthropic.WithTool(tool))
77+
session := model.Context(anthropic.WithToolKit(toolkit))
7778
assert.NotNil(session)
7879

7980
err = session.FromUser(context.TODO(), "What is today's weather?")
@@ -83,6 +84,6 @@ func Test_session_002(t *testing.T) {
8384

8485
toolcalls := session.ToolCalls()
8586
assert.NotEmpty(toolcalls)
86-
t.Log(toolcalls)
87+
t.Log("TOOLCALLS", toolcalls)
8788
})
8889
}

pkg/tool/call.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package tool
2+
3+
import (
4+
// Packages
5+
"bytes"
6+
"encoding/json"
7+
)
8+
9+
///////////////////////////////////////////////////////////////////////////////
10+
// TYPES
11+
12+
type CallMeta struct {
13+
Name string `json:"name"`
14+
Id string `json:"id,omitempty"`
15+
Input map[string]any `json:"input,omitempty"`
16+
}
17+
18+
type call struct {
19+
meta CallMeta
20+
}
21+
22+
///////////////////////////////////////////////////////////////////////////////
23+
// LIFECYCLE
24+
25+
func NewCall(name, id string, input map[string]any) *call {
26+
return &call{
27+
meta: CallMeta{
28+
Name: name,
29+
Id: id,
30+
Input: input,
31+
},
32+
}
33+
}
34+
35+
///////////////////////////////////////////////////////////////////////////////
36+
// STRINGIFY
37+
38+
func (t *call) String() string {
39+
data, err := json.MarshalIndent(t.meta, "", " ")
40+
if err != nil {
41+
return err.Error()
42+
}
43+
return string(data)
44+
}
45+
46+
///////////////////////////////////////////////////////////////////////////////
47+
// PUBLIC METHODS
48+
49+
func (t *call) Name() string {
50+
return t.meta.Name
51+
}
52+
53+
func (t *call) Id() string {
54+
return t.meta.Id
55+
}
56+
57+
func (t *call) Decode(v any) error {
58+
var buf bytes.Buffer
59+
if data, err := json.Marshal(t.meta.Input); err != nil {
60+
return err
61+
} else if err := json.Unmarshal(data, &buf); err != nil {
62+
return err
63+
}
64+
// Return success
65+
return nil
66+
}

pkg/tool/tollcall.go

Lines changed: 0 additions & 9 deletions
This file was deleted.

pkg/tool/tool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type ToolParameter struct {
2323

2424
type tool struct {
2525
ToolMeta
26-
proto reflect.Type // Prototype for parameter return
26+
proto reflect.Type
2727
}
2828

2929
type ToolMeta struct {

pkg/tool/toolkit.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package tool
22

33
import (
4+
// Packages
5+
"context"
6+
"errors"
7+
"fmt"
48
"reflect"
9+
"sync"
510

6-
// Packages
711
llm "github.com/mutablelogic/go-llm"
812
)
913

@@ -84,3 +88,41 @@ func (kit *ToolKit) Register(v llm.Tool) error {
8488
// Return success
8589
return nil
8690
}
91+
92+
// Run calls a tool in the toolkit
93+
func (kit *ToolKit) Run(ctx context.Context, calls []llm.ToolCall) error {
94+
var wg sync.WaitGroup
95+
var result error
96+
97+
for _, call := range calls {
98+
wg.Add(1)
99+
go func(call llm.ToolCall) {
100+
defer wg.Done()
101+
102+
// Get the tool
103+
name := call.Name()
104+
t, exists := kit.functions[name]
105+
if !exists {
106+
result = errors.Join(result, llm.ErrNotFound.Withf("tool %q not found", name))
107+
}
108+
109+
// Make a new object to decode into
110+
v := reflect.New(t.proto).Interface()
111+
112+
// Decode the input and run the tool
113+
if err := call.Decode(&v); err != nil {
114+
result = errors.Join(result, err)
115+
} else if out, err := t.Run(ctx, v); err != nil {
116+
result = errors.Join(result, err)
117+
} else {
118+
fmt.Println("result of calling", call, "is", out)
119+
}
120+
}(call)
121+
}
122+
123+
// Wait for all calls to complete
124+
wg.Wait()
125+
126+
// Return any errors
127+
return result
128+
}

tool.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
////////////////////////////////////////////////////////////////////////////////
88
// TYPES
99

10+
// Definition of a tool
1011
type Tool interface {
1112
// The name of the tool
1213
Name() string
@@ -26,6 +27,6 @@ type ToolCall interface {
2627
// The tool identifier
2728
Id() string
2829

29-
// The calling parameters
30-
Params() any
30+
// Decode the calling parameters
31+
Decode(v any) error
3132
}

0 commit comments

Comments
 (0)