Skip to content

Commit 772de80

Browse files
committed
Added mistralai
1 parent 82ed6f6 commit 772de80

File tree

12 files changed

+303
-57
lines changed

12 files changed

+303
-57
lines changed

cmd/cli/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313

1414
func main() {
1515
name := path.Base(os.Args[0])
16-
flags, err := NewFlags(name, os.Args[1:], OpenAIFlags, ElevenlabsFlags, HomeAssistantFlags)
16+
flags, err := NewFlags(name, os.Args[1:], OpenAIFlags, MistralFlags, ElevenlabsFlags, HomeAssistantFlags)
1717
if err != nil {
1818
if err != flag.ErrHelp {
1919
fmt.Fprintln(os.Stderr, err)
@@ -53,6 +53,12 @@ func main() {
5353
os.Exit(1)
5454
}
5555

56+
cmd, err = MistralRegister(cmd, opts, flags)
57+
if err != nil {
58+
fmt.Fprintln(os.Stderr, err)
59+
os.Exit(1)
60+
}
61+
5662
cmd, err = HomeAssistantRegister(cmd, opts, flags)
5763
if err != nil {
5864
fmt.Fprintln(os.Stderr, err)

cmd/cli/mistral.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package main
2+
3+
import (
4+
// Package imports
5+
"github.com/mutablelogic/go-client/pkg/client"
6+
"github.com/mutablelogic/go-client/pkg/mistral"
7+
)
8+
9+
/////////////////////////////////////////////////////////////////////
10+
// REGISTER FUNCTIONS
11+
12+
func MistralFlags(flags *Flags) {
13+
flags.String("mistral-api-key", "${MISTRAL_API_KEY}", "Mistral API key")
14+
}
15+
16+
func MistralRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Client, error) {
17+
mistral, err := mistral.New(flags.GetString("mistral-api-key"), opts...)
18+
if err != nil {
19+
return nil, err
20+
}
21+
22+
// Register commands
23+
cmd = append(cmd, Client{
24+
ns: "mistral",
25+
cmd: []Command{
26+
{Name: "models", Description: "Return registered models", MinArgs: 2, MaxArgs: 2, Fn: mistralModels(mistral, flags)},
27+
},
28+
})
29+
30+
// Return success
31+
return cmd, nil
32+
}
33+
34+
/////////////////////////////////////////////////////////////////////
35+
// API CALL FUNCTIONS
36+
37+
func mistralModels(client *mistral.Client, flags *Flags) CommandFn {
38+
return func() error {
39+
if models, err := client.ListModels(); err != nil {
40+
return err
41+
} else {
42+
return flags.Write(models)
43+
}
44+
}
45+
}

pkg/mistral/client.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
mistral implements an API client for mistral (https://docs.mistral.ai/api/)
3+
*/
4+
package mistral
5+
6+
import (
7+
// Packages
8+
"github.com/mutablelogic/go-client/pkg/client"
9+
)
10+
11+
///////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type Client struct {
15+
*client.Client
16+
}
17+
18+
///////////////////////////////////////////////////////////////////////////////
19+
// GLOBALS
20+
21+
const (
22+
endPoint = "https://api.mistral.ai/v1"
23+
)
24+
25+
///////////////////////////////////////////////////////////////////////////////
26+
// LIFECYCLE
27+
28+
// Create a new client
29+
func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
30+
// Create client
31+
opts = append(opts, client.OptEndpoint(endPoint))
32+
opts = append(opts, client.OptReqToken(client.Token{
33+
Scheme: client.Bearer,
34+
Value: ApiKey,
35+
}))
36+
client, err := client.New(opts...)
37+
if err != nil {
38+
return nil, err
39+
}
40+
41+
// Return the client
42+
return &Client{client}, nil
43+
}

pkg/mistral/client_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package mistral_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
// Packages
8+
opts "github.com/mutablelogic/go-client/pkg/client"
9+
mistral "github.com/mutablelogic/go-client/pkg/mistral"
10+
assert "github.com/stretchr/testify/assert"
11+
)
12+
13+
func Test_client_001(t *testing.T) {
14+
assert := assert.New(t)
15+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
16+
assert.NoError(err)
17+
assert.NotNil(client)
18+
t.Log(client)
19+
}
20+
21+
///////////////////////////////////////////////////////////////////////////////
22+
// ENVIRONMENT
23+
24+
func GetApiKey(t *testing.T) string {
25+
key := os.Getenv("MISTRAL_API_KEY")
26+
if key == "" {
27+
t.Skip("MISTRAL_API_KEY not set")
28+
t.SkipNow()
29+
}
30+
return key
31+
}

pkg/mistral/embedding.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package mistral
2+
3+
import (
4+
// Packages
5+
"github.com/mutablelogic/go-client/pkg/client"
6+
7+
// Namespace imports
8+
. "github.com/djthorpe/go-errors"
9+
. "github.com/mutablelogic/go-client/pkg/openai/schema"
10+
)
11+
12+
///////////////////////////////////////////////////////////////////////////////
13+
// TYPES
14+
15+
// A request to create embeddings
16+
type reqCreateEmbedding struct {
17+
Input []string `json:"input"`
18+
Model string `json:"model"`
19+
EncodingFormat string `json:"encoding_format,omitempty"`
20+
}
21+
22+
///////////////////////////////////////////////////////////////////////////////
23+
// API CALLS
24+
25+
// CreateEmbedding creates an embedding from a string or array of strings
26+
func (c *Client) CreateEmbedding(content any) (Embeddings, error) {
27+
var request reqCreateEmbedding
28+
29+
// Set the input, which is either a string or array of strings
30+
switch v := content.(type) {
31+
case string:
32+
request.Input = []string{v}
33+
case []string:
34+
request.Input = v
35+
default:
36+
return Embeddings{}, ErrBadParameter
37+
}
38+
39+
// Return the response
40+
var response Embeddings
41+
if payload, err := client.NewJSONRequest(request, client.ContentTypeJson); err != nil {
42+
return Embeddings{}, err
43+
} else if err := c.Do(payload.Post(), &response, client.OptPath("embeddings")); err != nil {
44+
return Embeddings{}, err
45+
}
46+
47+
// Return success
48+
return response, nil
49+
}

pkg/mistral/model.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package mistral
2+
3+
import (
4+
// Packages
5+
"github.com/mutablelogic/go-client/pkg/client"
6+
7+
// Namespace imports
8+
. "github.com/mutablelogic/go-client/pkg/openai/schema"
9+
)
10+
11+
///////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type responseListModels struct {
15+
Data []Model `json:"data"`
16+
}
17+
18+
///////////////////////////////////////////////////////////////////////////////
19+
// API CALLS
20+
21+
// ListModels returns all the models
22+
func (c *Client) ListModels() ([]Model, error) {
23+
var response responseListModels
24+
25+
// Request the models, populate the response
26+
payload := client.NewRequest(client.ContentTypeJson)
27+
if err := c.Do(payload, &response, client.OptPath("models")); err != nil {
28+
return nil, err
29+
}
30+
31+
// Return success
32+
return response.Data, nil
33+
}

pkg/mistral/model_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package mistral_test
2+
3+
import (
4+
"encoding/json"
5+
"os"
6+
"testing"
7+
8+
// Packages
9+
opts "github.com/mutablelogic/go-client/pkg/client"
10+
mistral "github.com/mutablelogic/go-client/pkg/mistral"
11+
assert "github.com/stretchr/testify/assert"
12+
)
13+
14+
func Test_models_001(t *testing.T) {
15+
assert := assert.New(t)
16+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
17+
assert.NoError(err)
18+
assert.NotNil(client)
19+
response, err := client.ListModels()
20+
assert.NoError(err)
21+
assert.NotEmpty(response)
22+
data, err := json.MarshalIndent(response, "", " ")
23+
assert.NoError(err)
24+
t.Log(string(data))
25+
}

pkg/openai/embedding.go

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,14 @@
11
package openai
22

33
import (
4-
"math"
5-
64
// Packages
75
"github.com/mutablelogic/go-client/pkg/client"
86

97
// Namespace imports
108
. "github.com/djthorpe/go-errors"
9+
. "github.com/mutablelogic/go-client/pkg/openai/schema"
1110
)
1211

13-
///////////////////////////////////////////////////////////////////////////////
14-
// PUBLIC METHODS
15-
16-
func (e Embedding) CosineDistance(other Embedding) float64 {
17-
count := 0
18-
length_a := len(e.Embedding)
19-
length_b := len(other.Embedding)
20-
if length_a > length_b {
21-
count = length_a
22-
} else {
23-
count = length_b
24-
}
25-
sumA := 0.0
26-
s1 := 0.0
27-
s2 := 0.0
28-
for k := 0; k < count; k++ {
29-
if k >= length_a {
30-
s2 += math.Pow(other.Embedding[k], 2)
31-
continue
32-
}
33-
if k >= length_b {
34-
s1 += math.Pow(e.Embedding[k], 2)
35-
continue
36-
}
37-
sumA += e.Embedding[k] * other.Embedding[k]
38-
s1 += math.Pow(e.Embedding[k], 2)
39-
s2 += math.Pow(other.Embedding[k], 2)
40-
}
41-
return sumA / (math.Sqrt(s1) * math.Sqrt(s2))
42-
}
43-
4412
///////////////////////////////////////////////////////////////////////////////
4513
// API CALLS
4614

pkg/openai/model.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package openai
33
import (
44
// Packages
55
"github.com/mutablelogic/go-client/pkg/client"
6+
7+
// Namespace imports
8+
. "github.com/mutablelogic/go-client/pkg/openai/schema"
69
)
710

811
///////////////////////////////////////////////////////////////////////////////

pkg/openai/schema.go

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,12 @@ import (
55

66
// Namespace imports
77
. "github.com/djthorpe/go-errors"
8+
. "github.com/mutablelogic/go-client/pkg/openai/schema"
89
)
910

1011
///////////////////////////////////////////////////////////////////////////////
1112
// TYPES
1213

13-
// A model object
14-
type Model struct {
15-
Id string `json:"id"`
16-
Created int64 `json:"created"`
17-
Owner string `json:"owned_by"`
18-
}
19-
20-
// An embedding object
21-
type Embedding struct {
22-
Embedding []float64 `json:"embedding"`
23-
Index int `json:"index"`
24-
}
25-
26-
// An set of created embeddings
27-
type Embeddings struct {
28-
Data []Embedding `json:"data"`
29-
Model string `json:"model"`
30-
Usage struct {
31-
PromptTokerns int `json:"prompt_tokens"`
32-
TotalTokens int `json:"total_tokens"`
33-
} `json:"usage"`
34-
}
35-
3614
// A chat completion object
3715
type Chat struct {
3816
Id string `json:"id"`

0 commit comments

Comments
 (0)