Skip to content

Commit 5096045

Browse files
committed
Updated tool calling
1 parent 25a30d0 commit 5096045

22 files changed

+371
-209
lines changed

cmd/agent/chat.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@ import (
1515
////////////////////////////////////////////////////////////////////////////////
1616
// TYPES
1717

18-
type GenerateCmd struct {
18+
type ChatCmd struct {
1919
Model string `arg:"" help:"Model name"`
2020
NoStream bool `flag:"nostream" help:"Disable streaming"`
21+
System string `flag:"system" help:"Set the system prompt"`
2122
}
2223

2324
////////////////////////////////////////////////////////////////////////////////
2425
// PUBLIC METHODS
2526

26-
func (cmd *GenerateCmd) Run(globals *Globals) error {
27+
func (cmd *ChatCmd) Run(globals *Globals) error {
2728
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
2829
// Get the model
2930
a, ok := client.(*agent.Agent)
@@ -35,8 +36,22 @@ func (cmd *GenerateCmd) Run(globals *Globals) error {
3536
return err
3637
}
3738

39+
// Set the options
40+
opts := []llm.Opt{}
41+
if !cmd.NoStream {
42+
opts = append(opts, llm.WithStream(func(cc llm.ContextContent) {
43+
fmt.Println(cc)
44+
}))
45+
}
46+
if cmd.System != "" {
47+
opts = append(opts, llm.WithSystemPrompt(cmd.System))
48+
}
49+
if globals.toolkit != nil {
50+
opts = append(opts, llm.WithToolKit(globals.toolkit))
51+
}
52+
3853
// Create a session
39-
session := model.Context()
54+
session := model.Context(opts...)
4055

4156
// Continue looping until end of input
4257
for {
@@ -57,7 +72,16 @@ func (cmd *GenerateCmd) Run(globals *Globals) error {
5772
if err := session.FromUser(ctx, input); err != nil {
5873
return err
5974
}
75+
6076
fmt.Println(session.Text())
77+
78+
// If there are tool calls, then do these
79+
calls := session.ToolCalls()
80+
if results, err := globals.toolkit.Run(ctx, calls...); err != nil {
81+
return err
82+
} else {
83+
fmt.Println(results)
84+
}
6185
}
6286
})
6387
}

cmd/agent/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ type CLI struct {
6060

6161
// Commands
6262
Download DownloadModelCmd `cmd:"" help:"Download a model"`
63-
Generate GenerateCmd `cmd:"chat" help:"Start a chat session"`
63+
Chat ChatCmd `cmd:"" help:"Start a chat session"`
6464
}
6565

6666
////////////////////////////////////////////////////////////////////////////////

context.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import "context"
55
//////////////////////////////////////////////////////////////////
66
// TYPES
77

8-
// Context is fed to the agent to generate a response
9-
type Context interface {
8+
// ContextContent is the content of the last context message
9+
type ContextContent interface {
1010
// Return the current session role, which can be system, assistant, user, tool, tool_result, ...
1111
Role() string
1212

@@ -15,6 +15,11 @@ type Context interface {
1515

1616
// Return the current session tool calls, or empty if no tool calls were made
1717
ToolCalls() []ToolCall
18+
}
19+
20+
// Context is fed to the agent to generate a response
21+
type Context interface {
22+
ContextContent
1823

1924
// Generate a response from a user prompt (with attachments and
2025
// other empheral options

opt.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ type Opt func(*Opts) error
1313

1414
// set of options
1515
type Opts struct {
16-
agents map[string]Agent // Set of agents
17-
toolkit ToolKit // Toolkit for tools
18-
callback func(Context) // Streaming callback
19-
attachments []*Attachment // Attachments
20-
options map[string]any // Additional options
16+
agents map[string]Agent // Set of agents
17+
toolkit ToolKit // Toolkit for tools
18+
callback func(ContextContent) // Streaming callback
19+
attachments []*Attachment // Attachments
20+
system string // System prompt
21+
options map[string]any // Additional options
2122
}
2223

2324
////////////////////////////////////////////////////////////////////////////////
@@ -45,10 +46,15 @@ func (o *Opts) ToolKit() ToolKit {
4546
}
4647

4748
// Return the stream function
48-
func (o *Opts) StreamFn() func(Context) {
49+
func (o *Opts) StreamFn() func(ContextContent) {
4950
return o.callback
5051
}
5152

53+
// Return the system prompt
54+
func (o *Opts) SystemPrompt() string {
55+
return o.system
56+
}
57+
5258
// Return the array of registered agents
5359
func (o *Opts) Agents() []Agent {
5460
result := make([]Agent, 0, len(o.agents))
@@ -102,6 +108,26 @@ func (o *Opts) GetBool(key string) bool {
102108
return false
103109
}
104110

111+
// Get an option value as an unsigned integer
112+
func (o *Opts) GetUint64(key string) uint64 {
113+
if value, exists := o.options[key]; exists {
114+
if v, ok := value.(uint64); ok {
115+
return v
116+
}
117+
}
118+
return 0
119+
}
120+
121+
// Get an option value as a float64
122+
func (o *Opts) GetFloat64(key string) float64 {
123+
if value, exists := o.options[key]; exists {
124+
if v, ok := value.(float64); ok {
125+
return v
126+
}
127+
}
128+
return 0
129+
}
130+
105131
// Get an option value as a duration
106132
func (o *Opts) GetDuration(key string) time.Duration {
107133
if value, exists := o.options[key]; exists {
@@ -124,7 +150,7 @@ func WithToolKit(toolkit ToolKit) Opt {
124150
}
125151

126152
// Set chat streaming function
127-
func WithStream(fn func(Context)) Opt {
153+
func WithStream(fn func(ContextContent)) Opt {
128154
return func(o *Opts) error {
129155
o.callback = fn
130156
return nil
@@ -196,3 +222,11 @@ func WithTopK(v uint) Opt {
196222
return nil
197223
}
198224
}
225+
226+
// Set system prompt
227+
func WithSystemPrompt(v string) Opt {
228+
return func(o *Opts) error {
229+
o.system = v
230+
return nil
231+
}
232+
}

pkg/agent/agent.go

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (m model) String() string {
5656
// PUBLIC METHODS
5757

5858
// Return a list of tool names
59-
func (a *Agent) Tools() []string {
59+
func (a *Agent) ToolNames() []string {
6060
if a.ToolKit() == nil {
6161
return nil
6262
}
@@ -67,11 +67,35 @@ func (a *Agent) Tools() []string {
6767
return result
6868
}
6969

70+
// Return a list of agent names
71+
func (a *Agent) AgentNames() []string {
72+
var result []string
73+
for _, a := range a.Agents() {
74+
result = append(result, a.Name())
75+
}
76+
return result
77+
}
78+
79+
// Return a list of agents
80+
func (a *Agent) AgentsWithName(name ...string) []llm.Agent {
81+
all := a.Agents()
82+
if len(name) == 0 {
83+
return all
84+
}
85+
result := make([]llm.Agent, 0, len(name))
86+
for _, a := range all {
87+
if slices.Contains(name, a.Name()) {
88+
result = append(result, a)
89+
}
90+
}
91+
return result
92+
}
93+
7094
// Return a comma-separated list of agent names
7195
func (a *Agent) Name() string {
7296
var keys []string
73-
for key := range a.Agents() {
74-
keys = append(keys, key)
97+
for _, agent := range a.Agents() {
98+
keys = append(keys, agent.Name())
7599
}
76100
return strings.Join(keys, ",")
77101
}
@@ -82,22 +106,13 @@ func (a *Agent) Models(ctx context.Context) ([]llm.Model, error) {
82106
}
83107

84108
// Return the models from list of agents
85-
func (a *Agent) ListModels(ctx context.Context, agents ...string) ([]llm.Model, error) {
109+
func (a *Agent) ListModels(ctx context.Context, names ...string) ([]llm.Model, error) {
86110
var result error
87111

88-
// Ensure all agents are valid
112+
// Gather models from agents
113+
agents := a.AgentsWithName(names...)
114+
models := make([]llm.Model, 0, len(agents)*10)
89115
for _, agent := range agents {
90-
if _, exists := a.agents[agent]; !exists {
91-
result = errors.Join(result, llm.ErrNotFound.Withf("agent %q", agent))
92-
}
93-
}
94-
95-
// Gather models from all agents
96-
models := make([]llm.Model, 0, 100)
97-
for _, agent := range a.agents {
98-
if len(agents) > 0 && !slices.Contains(agents, agent.Name()) {
99-
continue
100-
}
101116
agentmodels, err := modelsForAgent(ctx, agent)
102117
if err != nil {
103118
result = errors.Join(result, err)
@@ -113,24 +128,12 @@ func (a *Agent) ListModels(ctx context.Context, agents ...string) ([]llm.Model,
113128

114129
// Return a model by name. If no agents are specified, then all agents are considered.
115130
// If multiple agents are specified, then the first model found is returned.
116-
func (a *Agent) GetModel(ctx context.Context, name string, agents ...string) (llm.Model, error) {
117-
if len(agents) == 0 {
118-
for _, agent := range a.agents {
119-
agents = append(agents, agent.Name())
120-
}
121-
}
122-
123-
// Ensure all agents are valid
131+
func (a *Agent) GetModel(ctx context.Context, name string, agentnames ...string) (llm.Model, error) {
124132
var result error
125-
for _, agent := range agents {
126-
if _, exists := a.agents[agent]; !exists {
127-
result = errors.Join(result, llm.ErrNotFound.Withf("agent %q", agent))
128-
}
129-
}
130133

131-
// Gather models from agents
134+
agents := a.AgentsWithName(agentnames...)
132135
for _, agent := range agents {
133-
models, err := modelsForAgent(ctx, a.agents[agent], name)
136+
models, err := modelsForAgent(ctx, agent, name)
134137
if err != nil {
135138
result = errors.Join(result, err)
136139
continue

pkg/agent/opt.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package agent
22

33
import (
4-
54
// Packages
65
client "github.com/mutablelogic/go-client"
76
llm "github.com/mutablelogic/go-llm"
@@ -13,23 +12,23 @@ import (
1312
// PUBLIC METHODS
1413

1514
func WithOllama(endpoint string, opts ...client.ClientOpt) llm.Opt {
16-
return func(o any) error {
15+
return func(o *llm.Opts) error {
1716
client, err := ollama.New(endpoint, opts...)
1817
if err != nil {
1918
return err
2019
} else {
21-
return llm.WithAgent(client)
20+
return llm.WithAgent(client)(o)
2221
}
2322
}
2423
}
2524

2625
func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt {
27-
return func(o any) error {
26+
return func(o *llm.Opts) error {
2827
client, err := anthropic.New(key, opts...)
2928
if err != nil {
3029
return err
3130
} else {
32-
return llm.WithAgent(client)
31+
return llm.WithAgent(client)(o)
3332
}
3433
}
3534
}

pkg/anthropic/client.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ anthropic implements an API client for anthropic (https://docs.anthropic.com/en/
44
package anthropic
55

66
import (
7-
87
// Packages
98
client "github.com/mutablelogic/go-client"
109
llm "github.com/mutablelogic/go-llm"

0 commit comments

Comments
 (0)