Skip to content

Commit 48e4c42

Browse files
committed
Added mistral
1 parent 56d152d commit 48e4c42

File tree

6 files changed

+243
-0
lines changed

6 files changed

+243
-0
lines changed

cmd/agent/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type Globals struct {
2727
// Agents
2828
Ollama `embed:"" help:"Ollama configuration"`
2929
Anthropic `embed:"" help:"Anthropic configuration"`
30+
Mistral `embed:"" help:"Mistral configuration"`
3031

3132
// Tools
3233
NewsAPI `embed:"" help:"NewsAPI configuration"`
@@ -46,6 +47,10 @@ type Anthropic struct {
4647
AnthropicKey string `env:"ANTHROPIC_API_KEY" help:"Anthropic API Key"`
4748
}
4849

50+
type Mistral struct {
51+
MistralKey string `env:"MISTRAL_API_KEY" help:"Mistral API Key"`
52+
}
53+
4954
type NewsAPI struct {
5055
NewsKey string `env:"NEWSAPI_KEY" help:"News API Key"`
5156
}
@@ -105,6 +110,9 @@ func main() {
105110
if cli.AnthropicKey != "" {
106111
opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...))
107112
}
113+
if cli.MistralKey != "" {
114+
opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...))
115+
}
108116

109117
// Make a toolkit
110118
toolkit := tool.NewToolKit()

pkg/agent/opt.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
client "github.com/mutablelogic/go-client"
66
llm "github.com/mutablelogic/go-llm"
77
anthropic "github.com/mutablelogic/go-llm/pkg/anthropic"
8+
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
89
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
910
)
1011

@@ -32,3 +33,14 @@ func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt {
3233
}
3334
}
3435
}
36+
37+
func WithMistral(key string, opts ...client.ClientOpt) llm.Opt {
38+
return func(o *llm.Opts) error {
39+
client, err := mistral.New(key, opts...)
40+
if err != nil {
41+
return err
42+
} else {
43+
return llm.WithAgent(client)(o)
44+
}
45+
}
46+
}

pkg/mistral/client.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
mistral implements an API client for mistral (https://docs.mistral.ai/api/)
3+
*/
4+
package mistral
5+
6+
import (
7+
// Packages
8+
"context"
9+
10+
"github.com/mutablelogic/go-client"
11+
"github.com/mutablelogic/go-llm"
12+
)
13+
14+
///////////////////////////////////////////////////////////////////////////////
15+
// TYPES
16+
17+
type Client struct {
18+
*client.Client
19+
}
20+
21+
var _ llm.Agent = (*Client)(nil)
22+
23+
///////////////////////////////////////////////////////////////////////////////
24+
// GLOBALS
25+
26+
const (
27+
endPoint = "https://api.mistral.ai/v1"
28+
defaultName = "mistral"
29+
)
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// LIFECYCLE
33+
34+
// Create a new client
35+
func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
36+
// Create client
37+
opts = append(opts, client.OptEndpoint(endPoint))
38+
opts = append(opts, client.OptReqToken(client.Token{
39+
Scheme: client.Bearer,
40+
Value: ApiKey,
41+
}))
42+
client, err := client.New(opts...)
43+
if err != nil {
44+
return nil, err
45+
}
46+
47+
// Return the client
48+
return &Client{client}, nil
49+
}
50+
51+
///////////////////////////////////////////////////////////////////////////////
52+
// PUBLIC METHODS
53+
54+
// Return the name of the agent
55+
func (Client) Name() string {
56+
return defaultName
57+
}
58+
59+
// Return the models
60+
func (c *Client) Models(ctx context.Context) ([]llm.Model, error) {
61+
return c.ListModels(ctx)
62+
}
63+
64+
// Return a model by name, or nil if not found.
65+
// Panics on error.
66+
func (c *Client) Model(ctx context.Context, name string) llm.Model {
67+
return nil
68+
}

pkg/mistral/client_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package mistral_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
// Packages
8+
opts "github.com/mutablelogic/go-client"
9+
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
10+
assert "github.com/stretchr/testify/assert"
11+
)
12+
13+
func Test_client_001(t *testing.T) {
14+
assert := assert.New(t)
15+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
16+
assert.NoError(err)
17+
assert.NotNil(client)
18+
t.Log(client)
19+
}
20+
21+
///////////////////////////////////////////////////////////////////////////////
22+
// ENVIRONMENT
23+
24+
func GetApiKey(t *testing.T) string {
25+
key := os.Getenv("MISTRAL_API_KEY")
26+
if key == "" {
27+
t.Skip("MISTRAL_API_KEY not set")
28+
t.SkipNow()
29+
}
30+
return key
31+
}

pkg/mistral/model.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package mistral
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"github.com/mutablelogic/go-client"
8+
"github.com/mutablelogic/go-llm"
9+
)
10+
11+
///////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type model struct {
15+
meta Model
16+
}
17+
18+
type Model struct {
19+
Name string `json:"id"`
20+
Description string `json:"description,omitempty"`
21+
Type string `json:"type,omitempty"`
22+
CreatedAt *uint64 `json:"created,omitempty"`
23+
OwnedBy string `json:"owned_by,omitempty"`
24+
MaxContextLength uint64 `json:"max_context_length,omitempty"`
25+
Aliases []string `json:"aliases,omitempty"`
26+
Deprecation *string `json:"deprecation,omitempty"`
27+
DefaultModelTemperature *float64 `json:"default_model_temperature,omitempty"`
28+
Capabilities struct {
29+
CompletionChat bool `json:"completion_chat,omitempty"`
30+
CompletionFim bool `json:"completion_fim,omitempty"`
31+
FunctionCalling bool `json:"function_calling,omitempty"`
32+
FineTuning bool `json:"fine_tuning,omitempty"`
33+
Vision bool `json:"vision,omitempty"`
34+
} `json:"capabilities,omitempty"`
35+
}
36+
37+
///////////////////////////////////////////////////////////////////////////////
38+
// STRINGIFY
39+
40+
func (m model) MarshalJSON() ([]byte, error) {
41+
return json.Marshal(m.meta)
42+
}
43+
44+
func (m model) String() string {
45+
data, err := json.MarshalIndent(m, "", " ")
46+
if err != nil {
47+
return err.Error()
48+
}
49+
return string(data)
50+
}
51+
52+
///////////////////////////////////////////////////////////////////////////////
53+
// PUBLIC METHODS - API
54+
55+
// ListModels returns all the models
56+
func (c *Client) ListModels(ctx context.Context) ([]llm.Model, error) {
57+
// Response
58+
var response struct {
59+
Data []Model `json:"data"`
60+
}
61+
if err := c.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil {
62+
return nil, err
63+
}
64+
65+
// Make models
66+
result := make([]llm.Model, 0, len(response.Data))
67+
for _, meta := range response.Data {
68+
result = append(result, &model{meta: meta})
69+
}
70+
71+
// Return models
72+
return result, nil
73+
}
74+
75+
///////////////////////////////////////////////////////////////////////////////
76+
// PUBLIC METHODS - MODEL
77+
78+
// Return the name of the model
79+
func (m model) Name() string {
80+
return m.meta.Name
81+
}
82+
83+
// Return am empty session context object for the model,
84+
// setting session options
85+
func (m model) Context(...llm.Opt) llm.Context {
86+
return nil
87+
}
88+
89+
// Convenience method to create a session context object
90+
// with a user prompt
91+
func (m model) UserPrompt(string, ...llm.Opt) llm.Context {
92+
return nil
93+
}
94+
95+
// Embedding vector generation
96+
func (m model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) {
97+
return nil, llm.ErrNotImplemented
98+
}

pkg/mistral/model_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package mistral_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"os"
7+
"testing"
8+
9+
// Packages
10+
opts "github.com/mutablelogic/go-client"
11+
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
12+
assert "github.com/stretchr/testify/assert"
13+
)
14+
15+
func Test_models_001(t *testing.T) {
16+
assert := assert.New(t)
17+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
18+
assert.NoError(err)
19+
assert.NotNil(client)
20+
response, err := client.ListModels(context.TODO())
21+
assert.NoError(err)
22+
assert.NotEmpty(response)
23+
data, err := json.MarshalIndent(response, "", " ")
24+
assert.NoError(err)
25+
t.Log(string(data))
26+
}

0 commit comments

Comments
 (0)