Skip to content

Commit eb5a34e

Browse files
committed
Added download code
1 parent e896bc7 commit eb5a34e

File tree

5 files changed

+97
-108
lines changed

5 files changed

+97
-108
lines changed

cmd/llm/chat.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99

1010
// Packages
1111
llm "github.com/mutablelogic/go-llm"
12-
agent "github.com/mutablelogic/go-llm/pkg/agent"
1312
)
1413

1514
////////////////////////////////////////////////////////////////////////////////
@@ -27,17 +26,7 @@ type ChatCmd struct {
2726
// PUBLIC METHODS
2827

2928
func (cmd *ChatCmd) Run(globals *Globals) error {
30-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
31-
// Get the model
32-
a, ok := client.(*agent.Agent)
33-
if !ok {
34-
return fmt.Errorf("No agents found")
35-
}
36-
model, err := a.GetModel(ctx, cmd.Model)
37-
if err != nil {
38-
return err
39-
}
40-
29+
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
4130
// Current buffer
4231
var buf string
4332

@@ -67,6 +56,7 @@ func (cmd *ChatCmd) Run(globals *Globals) error {
6756
input = cmd.Prompt
6857
cmd.Prompt = ""
6958
} else {
59+
var err error
7060
input, err = globals.term.ReadLine(model.Name() + "> ")
7161
if errors.Is(err, io.EOF) {
7262
return nil

cmd/llm/complete.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99

1010
// Packages
1111
llm "github.com/mutablelogic/go-llm"
12-
agent "github.com/mutablelogic/go-llm/pkg/agent"
1312
)
1413

1514
////////////////////////////////////////////////////////////////////////////////
@@ -29,15 +28,9 @@ type CompleteCmd struct {
2928
// PUBLIC METHODS
3029

3130
func (cmd *CompleteCmd) Run(globals *Globals) error {
32-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
31+
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
3332
var prompt []byte
3433

35-
// Load the model
36-
model, err := client.(*agent.Agent).GetModel(ctx, cmd.Model)
37-
if err != nil {
38-
return err
39-
}
40-
4134
// If we are pipeline content in via stdin
4235
fileInfo, err := os.Stdin.Stat()
4336
if err != nil {

cmd/llm/embedding.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
// Packages
8+
llm "github.com/mutablelogic/go-llm"
9+
)
10+
11+
////////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type EmbeddingCmd struct {
15+
Model string `arg:"" help:"Model name"`
16+
Prompt string `arg:"" help:"Prompt"`
17+
}
18+
19+
////////////////////////////////////////////////////////////////////////////////
20+
// PUBLIC METHODS
21+
22+
func (cmd *EmbeddingCmd) Run(globals *Globals) error {
23+
return run(globals, cmd.Model, func(ctx context.Context, model llm.Model) error {
24+
fmt.Println(model)
25+
return nil
26+
})
27+
}

cmd/llm/main.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type Globals struct {
3838

3939
// Context
4040
ctx context.Context
41-
agent llm.Agent
41+
agent *agent.Agent
4242
toolkit *tool.ToolKit
4343
term *Term
4444
}
@@ -76,9 +76,10 @@ type CLI struct {
7676
Tools ListToolsCmd `cmd:"" help:"Return a list of tools"`
7777

7878
// Commands
79-
Download DownloadModelCmd `cmd:"" help:"Download a model"`
80-
Chat ChatCmd `cmd:"" help:"Start a chat session"`
81-
Complete CompleteCmd `cmd:"" help:"Complete a prompt"`
79+
Download DownloadModelCmd `cmd:"" help:"Download a model"`
80+
Chat ChatCmd `cmd:"" help:"Start a chat session"`
81+
Complete CompleteCmd `cmd:"" help:"Complete a prompt"`
82+
Embedding EmbeddingCmd `cmd:"" help:"Generate an embedding"`
8283
}
8384

8485
////////////////////////////////////////////////////////////////////////////////
@@ -186,3 +187,16 @@ func clientOpts(cli *CLI) []client.ClientOpt {
186187
}
187188
return result
188189
}
190+
191+
////////////////////////////////////////////////////////////////////////////////
192+
// PRIVATE METHODS
193+
194+
func run(globals *Globals, name string, fn func(ctx context.Context, model llm.Model) error) error {
195+
model, err := globals.agent.GetModel(globals.ctx, name)
196+
if err != nil {
197+
return err
198+
}
199+
200+
// Get the model
201+
return fn(globals.ctx, model)
202+
}

cmd/llm/models.go

Lines changed: 49 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package main
22

33
import (
4-
"context"
54
"encoding/json"
65
"fmt"
76
"os"
@@ -10,7 +9,6 @@ import (
109

1110
// Packages
1211
tablewriter "github.com/djthorpe/go-tablewriter"
13-
llm "github.com/mutablelogic/go-llm"
1412
agent "github.com/mutablelogic/go-llm/pkg/agent"
1513
"github.com/mutablelogic/go-llm/pkg/ollama"
1614
)
@@ -35,104 +33,71 @@ type DownloadModelCmd struct {
3533
// PUBLIC METHODS
3634

3735
func (cmd *ListToolsCmd) Run(globals *Globals) error {
38-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
39-
tools := globals.toolkit.Tools(client)
40-
fmt.Println(tools)
41-
return nil
42-
})
36+
tools := globals.toolkit.Tools(globals.agent)
37+
fmt.Println(tools)
38+
return nil
4339
}
4440

4541
func (cmd *ListModelsCmd) Run(globals *Globals) error {
46-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
47-
agent_, ok := client.(*agent.Agent)
48-
if !ok {
49-
return fmt.Errorf("No agents found")
50-
}
51-
models, err := agent_.ListModels(ctx, cmd.Agent...)
52-
if err != nil {
53-
return err
54-
}
55-
result := make(ModelList, 0, len(models))
56-
for _, model := range models {
57-
result = append(result, Model{
58-
Agent: model.(*agent.Model).Agent,
59-
Model: model.Name(),
60-
Description: model.Description(),
61-
Aliases: strings.Join(model.Aliases(), ", "),
62-
})
63-
}
64-
// Sort models by name
65-
sort.Sort(result)
42+
models, err := globals.agent.ListModels(globals.ctx, cmd.Agent...)
43+
if err != nil {
44+
return err
45+
}
46+
result := make(ModelList, 0, len(models))
47+
for _, model := range models {
48+
result = append(result, Model{
49+
Agent: model.(*agent.Model).Agent,
50+
Model: model.Name(),
51+
Description: model.Description(),
52+
Aliases: strings.Join(model.Aliases(), ", "),
53+
})
54+
}
55+
56+
// Sort models by name
57+
sort.Sort(result)
6658

67-
// Write out the models
68-
return tablewriter.New(os.Stdout).Write(result, tablewriter.OptOutputText(), tablewriter.OptHeader())
69-
})
59+
// Write out the models
60+
return tablewriter.New(os.Stdout).Write(result, tablewriter.OptOutputText(), tablewriter.OptHeader())
7061
}
7162

7263
func (*ListAgentsCmd) Run(globals *Globals) error {
73-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
74-
agent, ok := client.(*agent.Agent)
75-
if !ok {
76-
return fmt.Errorf("No agents found")
77-
}
78-
79-
agents := make([]string, 0, len(agent.Agents()))
80-
for _, agent := range agent.Agents() {
81-
agents = append(agents, agent.Name())
82-
}
83-
84-
data, err := json.MarshalIndent(agents, "", " ")
85-
if err != nil {
86-
return err
87-
}
88-
fmt.Println(string(data))
89-
90-
return nil
91-
})
64+
agents := globals.agent.AgentNames()
65+
data, err := json.MarshalIndent(agents, "", " ")
66+
if err != nil {
67+
return err
68+
}
69+
fmt.Println(string(data))
70+
return nil
9271
}
9372

9473
func (cmd *DownloadModelCmd) Run(globals *Globals) error {
95-
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
96-
agent := getagent(client, cmd.Agent)
97-
if agent == nil {
98-
return fmt.Errorf("No agents found with name %q", cmd.Agent)
99-
}
100-
// Download the model
101-
switch agent.Name() {
102-
case "ollama":
103-
model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model)
104-
if err != nil {
105-
return err
106-
}
107-
fmt.Println(model)
108-
default:
109-
return fmt.Errorf("Agent %q does not support model download", agent.Name())
110-
}
111-
return nil
112-
})
113-
}
114-
115-
////////////////////////////////////////////////////////////////////////////////
116-
// PRIVATE METHODS
117-
118-
func runagent(globals *Globals, fn func(ctx context.Context, agent llm.Agent) error) error {
119-
return fn(globals.ctx, globals.agent)
120-
}
121-
122-
func getagent(client llm.Agent, name string) llm.Agent {
123-
agent, ok := client.(*agent.Agent)
124-
if !ok {
125-
return nil
74+
agents := globals.agent.AgentsWithName(cmd.Agent)
75+
if len(agents) == 0 {
76+
return fmt.Errorf("No agents found with name %q", cmd.Agent)
12677
}
127-
for _, agent := range agent.Agents() {
128-
if agent.Name() == name {
129-
return agent
78+
switch agents[0].Name() {
79+
case "ollama":
80+
model, err := agents[0].(*ollama.Client).PullModel(globals.ctx, cmd.Model, ollama.WithPullStatus(func(status *ollama.PullStatus) {
81+
var pct int64
82+
if status.TotalBytes > 0 {
83+
pct = status.CompletedBytes * 100 / status.TotalBytes
84+
}
85+
fmt.Print("\r", status.Status, " ", pct, "%")
86+
if status.Status == "success" {
87+
fmt.Println("")
88+
}
89+
}))
90+
if err != nil {
91+
return err
13092
}
93+
fmt.Println(model)
94+
default:
95+
return fmt.Errorf("Agent %q does not support model download", agents[0].Name())
13196
}
13297
return nil
13398
}
13499

135-
// //////////////////////////////////////////////////////////////////////////////
100+
////////////////////////////////////////////////////////////////////////////////
136101
// MODEL LIST
137102

138103
type Model struct {

0 commit comments

Comments
 (0)