Skip to content

Commit 3334010

Browse files
committed
Updated llm
1 parent 31a65ca commit 3334010

File tree

17 files changed

+442
-229
lines changed

17 files changed

+442
-229
lines changed

cmd/agent/main.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ type CLI struct {
4444
Globals
4545

4646
// Agents, Models and Tools
47-
Agents ListAgentsCmd `cmd:"" help:"Return a list of agents"`
48-
Models ListModelsCmd `cmd:"" help:"Return a list of models"`
49-
Generate GenerateCmd `cmd:"" help:"Generate a response"`
47+
Agents ListAgentsCmd `cmd:"" help:"Return a list of agents"`
48+
Models ListModelsCmd `cmd:"" help:"Return a list of models"`
49+
Download DownloadModelCmd `cmd:"" help:"Download a model"`
50+
Generate GenerateCmd `cmd:"" help:"Generate a response"`
5051
}
5152

5253
////////////////////////////////////////////////////////////////////////////////

cmd/agent/models.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
// Packages
88
llm "github.com/mutablelogic/go-llm"
99
agent "github.com/mutablelogic/go-llm/pkg/agent"
10+
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
1011
)
1112

1213
////////////////////////////////////////////////////////////////////////////////
@@ -18,6 +19,11 @@ type ListModelsCmd struct {
1819

1920
type ListAgentsCmd struct{}
2021

22+
type DownloadModelCmd struct {
23+
Agent string `arg:"" help:"Agent name"`
24+
Model string `arg:"" help:"Model name"`
25+
}
26+
2127
////////////////////////////////////////////////////////////////////////////////
2228
// PUBLIC METHODS
2329

@@ -49,9 +55,43 @@ func (*ListAgentsCmd) Run(globals *Globals) error {
4955
})
5056
}
5157

58+
func (cmd *DownloadModelCmd) Run(globals *Globals) error {
59+
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
60+
agent := getagent(client, cmd.Agent)
61+
if agent == nil {
62+
return fmt.Errorf("No agents found with name %q", cmd.Agent)
63+
}
64+
// Download the model
65+
switch agent.Name() {
66+
case "ollama":
67+
model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model)
68+
if err != nil {
69+
return err
70+
}
71+
fmt.Println(model)
72+
default:
73+
return fmt.Errorf("Agent %q does not support model download", agent.Name())
74+
}
75+
return nil
76+
})
77+
}
78+
5279
////////////////////////////////////////////////////////////////////////////////
5380
// PRIVATE METHODS
5481

5582
func runagent(globals *Globals, fn func(ctx context.Context, agent llm.Agent) error) error {
5683
return fn(globals.ctx, globals.agent)
5784
}
85+
86+
func getagent(client llm.Agent, name string) llm.Agent {
87+
agent, ok := client.(*agent.Agent)
88+
if !ok {
89+
return nil
90+
}
91+
for _, agent := range agent.Agents() {
92+
if agent.Name() == name {
93+
return agent
94+
}
95+
}
96+
return nil
97+
}

context.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ type Context interface {
1616
// Generate a response from a user prompt (with attachments)
1717
FromUser(context.Context, string, ...Opt) (Context, error)
1818

19-
// Generate a response from a tool result
20-
FromTool(context.Context, ...any) (Context, error)
19+
// Generate a response from a tool, passing the call identifier or funtion name, and the result
20+
FromTool(context.Context, string, any) (Context, error)
2121
}

go.mod

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ go 1.23.5
44

55
require (
66
github.com/alecthomas/kong v1.7.0
7-
github.com/mutablelogic/go-client v1.0.9
8-
github.com/stretchr/testify v1.9.0
7+
github.com/mutablelogic/go-client v1.0.10
8+
github.com/stretchr/testify v1.10.0
9+
golang.org/x/term v0.28.0
910
)
1011

1112
require (
1213
github.com/davecgh/go-spew v1.1.1 // indirect
1314
github.com/djthorpe/go-errors v1.0.3 // indirect
1415
github.com/pmezard/go-difflib v1.0.0 // indirect
16+
golang.org/x/sys v0.29.0 // indirect
1517
gopkg.in/yaml.v3 v3.0.1 // indirect
1618
)

go.sum

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@ github.com/djthorpe/go-errors v1.0.3 h1:GZeMPkC1mx2vteXLI/gvxZS0Ee9zxzwD1mcYyKU5
1010
github.com/djthorpe/go-errors v1.0.3/go.mod h1:HtfrZnMd6HsX75Mtbv9Qcnn0BqOrrFArvCaj3RMnZhY=
1111
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
1212
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
13-
github.com/mutablelogic/go-client v1.0.9 h1:Eh4sjQOFDldP/L3IizqkcOD3WigZR+u1VaHTUM4ujYw=
14-
github.com/mutablelogic/go-client v1.0.9/go.mod h1:VLyB8j8IBJSK/FXvvqhmq93PRWDKkyLu8R7V2Vudb6A=
13+
github.com/mutablelogic/go-client v1.0.10 h1:d4t8irXlGNQrQS/+FoUht+1RnjL9lBaf1e2UasN3ifE=
14+
github.com/mutablelogic/go-client v1.0.10/go.mod h1:XbG8KGo2Efi7PGxXs7rhYxYhLeXL6aCSo6sz0mVchiw=
1515
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
1616
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
17-
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
18-
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
17+
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
18+
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
19+
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
20+
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
21+
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
22+
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
1923
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
2024
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
2125
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

model.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,8 @@ type Model interface {
1010
// Return am empty session context object for the model,
1111
// setting session options
1212
Context(...Opt) (Context, error)
13+
14+
// Convenience method to create a session context object
15+
// with a user prompt, which panics on error
16+
MustUserPrompt(string, ...Opt) Context
1317
}

pkg/agent/agent.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ func (m model) String() string {
5656
// PUBLIC METHODS
5757

5858
// Return a list of agent names
59-
func (a *Agent) Agents() []string {
60-
var keys []string
61-
for k := range a.agents {
62-
keys = append(keys, k)
59+
func (a *Agent) Agents() []llm.Agent {
60+
var result []llm.Agent
61+
for _, v := range a.agents {
62+
result = append(result, v)
6363
}
64-
return keys
64+
return result
6565
}
6666

6767
// Return a list of tool names
@@ -75,7 +75,11 @@ func (a *Agent) Tools() []string {
7575

7676
// Return a comma-separated list of agent names
7777
func (a *Agent) Name() string {
78-
return strings.Join(a.Agents(), ",")
78+
var keys []string
79+
for key := range a.agents {
80+
keys = append(keys, key)
81+
}
82+
return strings.Join(keys, ",")
7983
}
8084

8185
// Return the models from all agents

pkg/ollama/chat.go

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@ import (
1515

1616
// Chat Response
1717
type Response struct {
18-
Model string `json:"model"`
19-
CreatedAt time.Time `json:"created_at"`
20-
Message MessageMeta `json:"message"`
21-
Done bool `json:"done"`
22-
Reason string `json:"done_reason,omitempty"`
23-
Context []*MessageMeta `json:"-"`
18+
Model string `json:"model"`
19+
CreatedAt time.Time `json:"created_at"`
20+
Message MessageMeta `json:"message"`
21+
Done bool `json:"done"`
22+
Reason string `json:"done_reason,omitempty"`
2423
Metrics
2524
}
2625

@@ -58,20 +57,16 @@ type reqChat struct {
5857
KeepAlive *time.Duration `json:"keep_alive,omitempty"`
5958
}
6059

61-
func (ollama *Client) Chat(ctx context.Context, model string, prompt llm.Context, opts ...llm.Opt) (*Response, error) {
60+
func (ollama *Client) Chat(ctx context.Context, prompt llm.Context, opts ...llm.Opt) (*Response, error) {
6261
// Apply options
6362
opt, err := apply(opts...)
6463
if err != nil {
6564
return nil, err
6665
}
6766

68-
// Make a new sequence of messages
69-
seq := make([]*MessageMeta, len(prompt.(*session).seq))
70-
copy(seq, prompt.(*session).seq)
71-
7267
// Request
7368
req, err := client.NewJSONRequest(reqChat{
74-
Model: model,
69+
Model: prompt.(*session).model.Name(),
7570
Messages: prompt.(*session).seq,
7671
Tools: opt.tools,
7772
Format: opt.format,
@@ -84,20 +79,34 @@ func (ollama *Client) Chat(ctx context.Context, model string, prompt llm.Context
8479
}
8580

8681
// Response
87-
var response Response
88-
if err := ollama.DoWithContext(ctx, req, &response, client.OptPath("chat"), client.OptJsonStreamCallback(func(v any) error {
89-
if v, ok := v.(*Response); ok && opt.chatcallback != nil {
90-
opt.chatcallback(v)
82+
var response, delta Response
83+
if err := ollama.DoWithContext(ctx, req, &delta, client.OptPath("chat"), client.OptJsonStreamCallback(func(v any) error {
84+
if v, ok := v.(*Response); !ok || v == nil {
85+
return llm.ErrConflict.Withf("Invalid stream response: %v", v)
86+
} else {
87+
response.Model = v.Model
88+
response.CreatedAt = v.CreatedAt
89+
response.Message.Role = v.Message.Role
90+
response.Message.Content += v.Message.Content
91+
if v.Done {
92+
response.Done = v.Done
93+
response.Metrics = v.Metrics
94+
response.Reason = v.Reason
95+
}
96+
}
97+
98+
if opt.chatcallback != nil {
99+
opt.chatcallback(&response)
91100
}
92101
return nil
93102
})); err != nil {
94103
return nil, err
95104
}
96105

97-
// Append the response message to the context
98-
prompt.(*session).seq = append(prompt.(*session).seq, &response.Message)
99-
response.Context = prompt.(*session).seq
100-
101-
// Return success
102-
return &response, nil
106+
// We return the delta or the response
107+
if opt.stream {
108+
return &response, nil
109+
} else {
110+
return &delta, nil
111+
}
103112
}

pkg/ollama/chat_test.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ func Test_chat_001(t *testing.T) {
1818
}
1919

2020
// Pull the model
21-
if err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
21+
model, err := client.PullModel(context.TODO(), "qwen:0.5b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
2222
t.Log(status)
23-
})); err != nil {
23+
}))
24+
if err != nil {
2425
t.FailNow()
2526
}
2627

2728
t.Run("ChatStream", func(t *testing.T) {
2829
assert := assert.New(t)
29-
response, err := client.Chat(context.TODO(), "qwen:0.5b", client.UserPrompt("why is the sky blue?"), ollama.WithChatStream(func(stream *ollama.Response) {
30+
response, err := client.Chat(context.TODO(), model.MustUserPrompt("why is the sky blue?"), ollama.WithStream(func(stream *ollama.Response) {
3031
t.Log(stream)
3132
}))
3233
if !assert.NoError(err) {
@@ -37,7 +38,7 @@ func Test_chat_001(t *testing.T) {
3738

3839
t.Run("ChatNoStream", func(t *testing.T) {
3940
assert := assert.New(t)
40-
response, err := client.Chat(context.TODO(), "qwen:0.5b", client.UserPrompt("why is the sky green?"))
41+
response, err := client.Chat(context.TODO(), model.MustUserPrompt("why is the sky green?"))
4142
if !assert.NoError(err) {
4243
t.FailNow()
4344
}
@@ -52,16 +53,17 @@ func Test_chat_002(t *testing.T) {
5253
}
5354

5455
// Pull the model
55-
if err := client.PullModel(context.TODO(), "llama3.2:1b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
56+
model, err := client.PullModel(context.TODO(), "llama3.2:1b", ollama.WithPullStatus(func(status *ollama.PullStatus) {
5657
t.Log(status)
57-
})); err != nil {
58+
}))
59+
if err != nil {
5860
t.FailNow()
5961
}
6062

6163
t.Run("Tools", func(t *testing.T) {
6264
assert := assert.New(t)
63-
response, err := client.Chat(context.TODO(), "llama3.2:1b",
64-
client.UserPrompt("what is the weather in berlin?"),
65+
response, err := client.Chat(context.TODO(),
66+
model.MustUserPrompt("what is the weather in berlin?"),
6567
ollama.WithTool(ollama.MustTool("get_weather", "Return weather conditions in a location", struct {
6668
Location string `help:"Location to get weather for" required:""`
6769
}{})),
@@ -79,16 +81,15 @@ func Test_chat_003(t *testing.T) {
7981
t.FailNow()
8082
}
8183

82-
// Delete model
83-
client.DeleteModel(context.TODO(), "llava")
84-
8584
// Pull the model
86-
if err := client.PullModel(context.TODO(), "llava", ollama.WithPullStatus(func(status *ollama.PullStatus) {
85+
model, err := client.PullModel(context.TODO(), "llava", ollama.WithPullStatus(func(status *ollama.PullStatus) {
8786
t.Log(status)
88-
})); err != nil {
87+
}))
88+
if err != nil {
8989
t.FailNow()
9090
}
9191

92+
// Explain the content of an image
9293
t.Run("Image", func(t *testing.T) {
9394
assert := assert.New(t)
9495

@@ -98,8 +99,8 @@ func Test_chat_003(t *testing.T) {
9899
}
99100
defer f.Close()
100101

101-
response, err := client.Chat(context.TODO(), "llava",
102-
client.UserPrompt("where was this photo taken?", ollama.WithData(f)),
102+
response, err := client.Chat(context.TODO(),
103+
model.MustUserPrompt("describe this photo to me", ollama.WithData(f)),
103104
)
104105
if !assert.NoError(err) {
105106
t.FailNow()

0 commit comments

Comments
 (0)