Skip to content

Commit e896bc7

Browse files
committed
Updated
1 parent 0c84bb6 commit e896bc7

File tree

11 files changed

+139
-21
lines changed

11 files changed

+139
-21
lines changed

cmd/llm/complete.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type CompleteCmd struct {
2121
File []string `type:"file" short:"f" help:"Files to attach"`
2222
System string `flag:"system" help:"Set the system prompt"`
2323
NoStream bool `flag:"no-stream" help:"Do not stream output"`
24-
Format string `flag:"format" enum:"text,json" default:"text" help:"Output format. You may also need to specify the output in the system or user prompt."`
24+
Format string `flag:"format" enum:"text,markdown,json" default:"text" help:"Output format"`
2525
Temperature *float64 `flag:"temperature" short:"t" help:"Temperature for sampling"`
2626
}
2727

@@ -97,14 +97,30 @@ func (cmd *CompleteCmd) Run(globals *Globals) error {
9797

9898
func (cmd *CompleteCmd) opts() []llm.Opt {
9999
opts := []llm.Opt{}
100+
101+
// Set system prompt
102+
var system []string
103+
if cmd.Format == "markdown" {
104+
system = append(system, "Return the completion in markdown format.")
105+
} else if cmd.Format == "json" {
106+
system = append(system, "Return the completion in JSON format.")
107+
}
100108
if cmd.System != "" {
101-
opts = append(opts, llm.WithSystemPrompt(cmd.System))
109+
system = append(system, cmd.System)
110+
}
111+
if len(system) > 0 {
112+
opts = append(opts, llm.WithSystemPrompt(strings.Join(system, "\n")))
102113
}
114+
115+
// Set format
103116
if cmd.Format == "json" {
104117
opts = append(opts, llm.WithFormat("json"))
105118
}
119+
120+
// Set temperature
106121
if cmd.Temperature != nil {
107122
opts = append(opts, llm.WithTemperature(*cmd.Temperature))
108123
}
124+
109125
return opts
110126
}

cmd/llm/models.go

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"os"
8+
"sort"
9+
"strings"
710

811
// Packages
12+
tablewriter "github.com/djthorpe/go-tablewriter"
913
llm "github.com/mutablelogic/go-llm"
1014
agent "github.com/mutablelogic/go-llm/pkg/agent"
15+
"github.com/mutablelogic/go-llm/pkg/ollama"
1116
)
1217

1318
////////////////////////////////////////////////////////////////////////////////
@@ -39,16 +44,28 @@ func (cmd *ListToolsCmd) Run(globals *Globals) error {
3944

4045
func (cmd *ListModelsCmd) Run(globals *Globals) error {
4146
return runagent(globals, func(ctx context.Context, client llm.Agent) error {
42-
agent, ok := client.(*agent.Agent)
47+
agent_, ok := client.(*agent.Agent)
4348
if !ok {
4449
return fmt.Errorf("No agents found")
4550
}
46-
models, err := agent.ListModels(ctx, cmd.Agent...)
51+
models, err := agent_.ListModels(ctx, cmd.Agent...)
4752
if err != nil {
4853
return err
4954
}
50-
fmt.Println(models)
51-
return nil
55+
result := make(ModelList, 0, len(models))
56+
for _, model := range models {
57+
result = append(result, Model{
58+
Agent: model.(*agent.Model).Agent,
59+
Model: model.Name(),
60+
Description: model.Description(),
61+
Aliases: strings.Join(model.Aliases(), ", "),
62+
})
63+
}
64+
// Sort models by name
65+
sort.Sort(result)
66+
67+
// Write out the models
68+
return tablewriter.New(os.Stdout).Write(result, tablewriter.OptOutputText(), tablewriter.OptHeader())
5269
})
5370
}
5471

@@ -82,14 +99,12 @@ func (cmd *DownloadModelCmd) Run(globals *Globals) error {
8299
}
83100
// Download the model
84101
switch agent.Name() {
85-
/*
86-
case "ollama":
87-
model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model)
88-
if err != nil {
89-
return err
90-
}
91-
fmt.Println(model)
92-
*/
102+
case "ollama":
103+
model, err := agent.(*ollama.Client).PullModel(ctx, cmd.Model)
104+
if err != nil {
105+
return err
106+
}
107+
fmt.Println(model)
93108
default:
94109
return fmt.Errorf("Agent %q does not support model download", agent.Name())
95110
}
@@ -116,3 +131,27 @@ func getagent(client llm.Agent, name string) llm.Agent {
116131
}
117132
return nil
118133
}
134+
135+
// //////////////////////////////////////////////////////////////////////////////
136+
// MODEL LIST
137+
138+
type Model struct {
139+
Agent string `json:"agent" writer:"Agent,width:10"`
140+
Model string `json:"model" writer:"Model,wrap,width:40"`
141+
Description string `json:"description" writer:"Description,wrap,width:60"`
142+
Aliases string `json:"aliases" writer:"Aliases,wrap,width:30"`
143+
}
144+
145+
type ModelList []Model
146+
147+
func (models ModelList) Len() int {
148+
return len(models)
149+
}
150+
151+
func (models ModelList) Less(a, b int) bool {
152+
return strings.Compare(models[a].Model, models[b].Model) < 0
153+
}
154+
155+
func (models ModelList) Swap(a, b int) {
156+
models[a], models[b] = models[b], models[a]
157+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/MichaelMure/go-term-text v0.3.1
77
github.com/alecthomas/kong v1.7.0
88
github.com/djthorpe/go-errors v1.0.3
9+
github.com/djthorpe/go-tablewriter v0.0.7
910
github.com/fatih/color v1.9.0
1011
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
1112
github.com/mutablelogic/go-client v1.0.10

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
1111
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
1212
github.com/djthorpe/go-errors v1.0.3 h1:GZeMPkC1mx2vteXLI/gvxZS0Ee9zxzwD1mcYyKU5jD0=
1313
github.com/djthorpe/go-errors v1.0.3/go.mod h1:HtfrZnMd6HsX75Mtbv9Qcnn0BqOrrFArvCaj3RMnZhY=
14+
github.com/djthorpe/go-tablewriter v0.0.7 h1:jnNsJDjjLLCt0OAqB5DzGZN7V3beT1IpNMQ8GcOwZDU=
15+
github.com/djthorpe/go-tablewriter v0.0.7/go.mod h1:NVBvytpL+6fHfCKn0+3lSi15/G3A1HWf2cLNeHg6YBg=
1416
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
1517
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
1618
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=

model.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ type Model interface {
1111
// Return the name of the model
1212
Name() string
1313

14+
// Return the description of the model
15+
Description() string
16+
17+
// Return any model aliases
18+
Aliases() []string
19+
1420
// Return am empty session context object for the model, setting
1521
// session options
1622
Context(...Opt) Context

pkg/agent/agent.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type Agent struct {
1818
*llm.Opts
1919
}
2020

21-
type model struct {
21+
type Model struct {
2222
Agent string `json:"agent"`
2323
llm.Model `json:"model"`
2424
}
@@ -44,7 +44,7 @@ func New(opts ...llm.Opt) (*Agent, error) {
4444
///////////////////////////////////////////////////////////////////////////////
4545
// STRINGIFY
4646

47-
func (m model) String() string {
47+
func (m Model) String() string {
4848
data, err := json.MarshalIndent(m, "", " ")
4949
if err != nil {
5050
return err.Error()
@@ -168,13 +168,27 @@ func modelsForAgent(ctx context.Context, agent llm.Agent, names ...string) ([]ll
168168
return nil, err
169169
}
170170

171+
match_model := func(model llm.Model, names ...string) bool {
172+
if len(names) == 0 {
173+
return true
174+
}
175+
if slices.Contains(names, model.Name()) {
176+
return true
177+
}
178+
for _, alias := range model.Aliases() {
179+
if slices.Contains(names, alias) {
180+
return true
181+
}
182+
}
183+
return false
184+
}
185+
171186
// Filter models
172187
result := make([]llm.Model, 0, len(models))
173188
for _, agentmodel := range models {
174-
if len(names) > 0 && !slices.Contains(names, agentmodel.Name()) {
175-
continue
189+
if match_model(agentmodel, names...) {
190+
result = append(result, &Model{Agent: agent.Name(), Model: agentmodel})
176191
}
177-
result = append(result, &model{Agent: agent.Name(), Model: agentmodel})
178192
}
179193

180194
// Return success

pkg/gemini/model.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ func (m model) Name() string {
5555
return m.meta.Name
5656
}
5757

58+
// Return model aliases
59+
func (model model) Aliases() []string {
60+
return nil
61+
}
62+
63+
// Return model description
64+
func (model model) Description() string {
65+
return model.meta.Description
66+
}
67+
5868
// Return the models
5969
func (gemini *Client) Models(ctx context.Context) ([]llm.Model, error) {
6070
// Cache models

pkg/mistral/completion_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@ import (
77
"testing"
88

99
// Packages
10-
1110
llm "github.com/mutablelogic/go-llm"
12-
"github.com/mutablelogic/go-llm/pkg/tool"
11+
tool "github.com/mutablelogic/go-llm/pkg/tool"
1312
assert "github.com/stretchr/testify/assert"
1413
)
1514

pkg/mistral/model.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ func (model model) Name() string {
9797
return model.meta.Name
9898
}
9999

100+
// Return model aliases
101+
func (model model) Aliases() []string {
102+
return model.meta.Aliases
103+
}
104+
105+
// Return model description
106+
func (model model) Description() string {
107+
return model.meta.Description
108+
}
109+
100110
// Return a new empty session
101111
func (model *model) Context(opts ...llm.Opt) llm.Context {
102112
return impl.NewSession(model, &messagefactory{}, opts...)

pkg/ollama/model.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"net/http"
7+
"strings"
78
"time"
89

910
// Packages
@@ -88,6 +89,16 @@ func (m model) Name() string {
8889
return m.ModelMeta.Name
8990
}
9091

92+
// Return model name
93+
func (model) Aliases() []string {
94+
return nil
95+
}
96+
97+
// Return model description
98+
func (model model) Description() string {
99+
return strings.Join(model.ModelMeta.Details.Families, ", ")
100+
}
101+
91102
// Agent interface
92103
func (ollama *Client) Models(ctx context.Context) ([]llm.Model, error) {
93104
// We don't explicitly cache models

0 commit comments

Comments
 (0)