Skip to content

Commit 9247d07

Browse files
committed
Updated
1 parent b56655f commit 9247d07

File tree

15 files changed

+472
-262
lines changed

15 files changed

+472
-262
lines changed

pkg/internal/impl/cache.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package impl
2+
3+
import (
4+
// Packages
5+
"sync"
6+
7+
llm "github.com/mutablelogic/go-llm"
8+
)
9+
10+
///////////////////////////////////////////////////////////////////////////////
11+
// TYPES
12+
13+
type ModelCache struct {
14+
sync.RWMutex
15+
cache map[string]llm.Model
16+
}
17+
18+
type ModelLoadFunc func() ([]llm.Model, error)
19+
20+
///////////////////////////////////////////////////////////////////////////////
21+
// LIFECYCLE
22+
23+
func NewModelCache() *ModelCache {
24+
cache := new(ModelCache)
25+
cache.cache = make(map[string]llm.Model, 20)
26+
return cache
27+
}
28+
29+
///////////////////////////////////////////////////////////////////////////////
30+
// METHODS
31+
32+
// Load models and return them
33+
func (c *ModelCache) Load(fn ModelLoadFunc) ([]llm.Model, error) {
34+
c.Lock()
35+
defer c.Unlock()
36+
37+
// Load models
38+
if len(c.cache) == 0 {
39+
if models, err := fn(); err != nil {
40+
return nil, err
41+
} else {
42+
for _, m := range models {
43+
c.cache[m.Name()] = m
44+
}
45+
}
46+
}
47+
48+
// Return models
49+
result := make([]llm.Model, 0, len(c.cache))
50+
for _, model := range c.cache {
51+
result = append(result, model)
52+
}
53+
return result, nil
54+
}
55+
56+
// Return a model by name
57+
func (c *ModelCache) Get(fn ModelLoadFunc, name string) (llm.Model, error) {
58+
if len(c.cache) == 0 {
59+
if _, err := c.Load(fn); err != nil {
60+
return nil, err
61+
}
62+
}
63+
c.RLock()
64+
defer c.RUnlock()
65+
return c.cache[name], nil
66+
}

pkg/session/session.go renamed to pkg/internal/impl/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package session
1+
package impl
22

33
import (
44
"context"

pkg/mistral/client.go

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
/*
2-
mistral implements an API client for mistral (https://docs.mistral.ai/api/)
2+
mistral implements an API client for mistral
3+
https://docs.mistral.ai/api/
34
*/
45
package mistral
56

67
import (
7-
"context"
8-
98
// Packages
109
client "github.com/mutablelogic/go-client"
1110
llm "github.com/mutablelogic/go-llm"
11+
impl "github.com/mutablelogic/go-llm/pkg/internal/impl"
1212
)
1313

1414
///////////////////////////////////////////////////////////////////////////////
1515
// TYPES
1616

1717
type Client struct {
1818
*client.Client
19-
cache map[string]llm.Model
19+
*impl.ModelCache
2020
}
2121

2222
var _ llm.Agent = (*Client)(nil)
@@ -46,7 +46,7 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
4646
}
4747

4848
// Return the client
49-
return &Client{client, nil}, nil
49+
return &Client{client, impl.NewModelCache()}, nil
5050
}
5151

5252
///////////////////////////////////////////////////////////////////////////////
@@ -56,36 +56,3 @@ func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
5656
func (Client) Name() string {
5757
return defaultName
5858
}
59-
60-
// Return the models
61-
func (c *Client) Models(ctx context.Context) ([]llm.Model, error) {
62-
// Cache models
63-
if c.cache == nil {
64-
models, err := c.ListModels(ctx)
65-
if err != nil {
66-
return nil, err
67-
}
68-
c.cache = make(map[string]llm.Model, len(models))
69-
for _, model := range models {
70-
c.cache[model.Name()] = model
71-
}
72-
}
73-
74-
// Return models
75-
result := make([]llm.Model, 0, len(c.cache))
76-
for _, model := range c.cache {
77-
result = append(result, model)
78-
}
79-
return result, nil
80-
}
81-
82-
// Return a model by name, or nil if not found.
83-
// Panics on error.
84-
func (c *Client) Model(ctx context.Context, name string) llm.Model {
85-
if c.cache == nil {
86-
if _, err := c.Models(ctx); err != nil {
87-
panic(err)
88-
}
89-
}
90-
return c.cache[name]
91-
}

pkg/mistral/completion.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"strings"
77

8+
// Packages
89
"github.com/mutablelogic/go-client"
910
"github.com/mutablelogic/go-llm"
1011
)
@@ -22,6 +23,17 @@ type Response struct {
2223
Metrics `json:"usage,omitempty"`
2324
}
2425

26+
// Possible completions
27+
type Completions []Completion
28+
29+
// Completion Variation
30+
type Completion struct {
31+
Index uint64 `json:"index"`
32+
Message *Message `json:"message"`
33+
Delta *Message `json:"delta,omitempty"` // For streaming
34+
Reason string `json:"finish_reason,omitempty"`
35+
}
36+
2537
// Metrics
2638
type Metrics struct {
2739
InputTokens uint64 `json:"prompt_tokens,omitempty"`
@@ -203,3 +215,45 @@ func appendCompletion(response *Response, c *Completion) {
203215
}
204216
}
205217
}
218+
219+
///////////////////////////////////////////////////////////////////////////////
220+
// PUBLIC METHODS - COMPLETIONS
221+
222+
// Return the number of completions
223+
func (c Completions) Num() int {
224+
return len(c)
225+
}
226+
227+
// Return message for a specific completion
228+
func (c Completions) Message(index int) *Message {
229+
if index < 0 || index >= len(c) {
230+
return nil
231+
}
232+
return c[index].Message
233+
}
234+
235+
// Return the role of the completion
236+
func (c Completions) Role() string {
237+
// The role should be the same for all completions, let's use the first one
238+
if len(c) == 0 {
239+
return ""
240+
}
241+
return c[0].Message.Role()
242+
}
243+
244+
// Return the text content for a specific completion
245+
func (c Completions) Text(index int) string {
246+
if index < 0 || index >= len(c) {
247+
return ""
248+
}
249+
return c[index].Message.Text(0)
250+
}
251+
252+
// Return the current session tool calls given the completion index.
253+
// Will return nil if no tool calls were returned.
254+
func (c Completions) ToolCalls(index int) []llm.ToolCall {
255+
if index < 0 || index >= len(c) {
256+
return nil
257+
}
258+
return c[index].Message.ToolCalls(0)
259+
}

pkg/mistral/content.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package mistral
2+
3+
import (
4+
"net/url"
5+
6+
"github.com/mutablelogic/go-llm"
7+
)
8+
9+
///////////////////////////////////////////////////////////////////////////////
10+
// TYPES
11+
12+
type Content struct {
13+
Type string `json:"type"` // text or content
14+
*Text `json:"text,omitempty"` // text content
15+
*Prediction `json:"content,omitempty"` // prediction
16+
*Image `json:"image_url,omitempty"` // image_url
17+
}
18+
19+
// text content
20+
type Text string
21+
22+
// text content
23+
type Prediction string
24+
25+
// either a URL or "data:image/png;base64," followed by the base64 encoded image
26+
type Image string
27+
28+
///////////////////////////////////////////////////////////////////////////////
29+
// LICECYCLE
30+
31+
func NewPrediction(content Prediction) *Content {
32+
return &Content{Type: "content", Prediction: &content}
33+
}
34+
35+
func NewTextContext(text Text) *Content {
36+
return &Content{Type: "text", Text: &text}
37+
}
38+
39+
func NewImageData(image *llm.Attachment) *Content {
40+
url := Image(image.Url())
41+
return &Content{Type: "image_url", Image: &url}
42+
}
43+
44+
func NewImageUrl(u *url.URL) *Content {
45+
url := Image(u.String())
46+
return &Content{Type: "image_url", Image: &url}
47+
}

0 commit comments

Comments
 (0)