Skip to content

Commit 813e808

Browse files
committed
Updated mistral
1 parent 15e772f commit 813e808

File tree

7 files changed

+176
-55
lines changed

7 files changed

+176
-55
lines changed

pkg/mistral/chat_completion_test.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"testing"
99

1010
// Packages
11-
opts "github.com/mutablelogic/go-client"
11+
1212
"github.com/mutablelogic/go-llm"
1313
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
1414
"github.com/mutablelogic/go-llm/pkg/tool"
@@ -17,11 +17,8 @@ import (
1717

1818
func Test_chat_001(t *testing.T) {
1919
assert := assert.New(t)
20-
21-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
22-
assert.NoError(err)
23-
2420
model := client.Model(context.TODO(), "mistral-small-latest")
21+
2522
if assert.NotNil(model) {
2623
response, err := client.ChatCompletion(context.TODO(), model.UserPrompt("Hello, how are you?"))
2724
assert.NoError(err)
@@ -32,8 +29,6 @@ func Test_chat_001(t *testing.T) {
3229

3330
func Test_chat_002(t *testing.T) {
3431
assert := assert.New(t)
35-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
36-
assert.NoError(err)
3732
model := client.Model(context.TODO(), "mistral-large-latest")
3833
if !assert.NotNil(model) {
3934
t.FailNow()
@@ -181,8 +176,6 @@ func Test_chat_002(t *testing.T) {
181176

182177
func Test_chat_003(t *testing.T) {
183178
assert := assert.New(t)
184-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
185-
assert.NoError(err)
186179
model := client.Model(context.TODO(), "pixtral-12b-2409")
187180
if !assert.NotNil(model) {
188181
t.FailNow()
@@ -206,8 +199,6 @@ func Test_chat_003(t *testing.T) {
206199

207200
func Test_chat_004(t *testing.T) {
208201
assert := assert.New(t)
209-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
210-
assert.NoError(err)
211202
model := client.Model(context.TODO(), "mistral-small-latest")
212203
if !assert.NotNil(model) {
213204
t.FailNow()

pkg/mistral/client_test.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package mistral_test
22

33
import (
4+
"flag"
5+
"log"
46
"os"
7+
"strconv"
58
"testing"
69

710
// Packages
@@ -10,22 +13,46 @@ import (
1013
assert "github.com/stretchr/testify/assert"
1114
)
1215

13-
func Test_client_001(t *testing.T) {
14-
assert := assert.New(t)
15-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
16-
assert.NoError(err)
17-
assert.NotNil(client)
18-
t.Log(client)
16+
///////////////////////////////////////////////////////////////////////////////
17+
// TEST SET-UP
18+
19+
var (
20+
client *mistral.Client
21+
)
22+
23+
func TestMain(m *testing.M) {
24+
var verbose bool
25+
26+
// Verbose output
27+
flag.Parse()
28+
if f := flag.Lookup("test.v"); f != nil {
29+
if v, err := strconv.ParseBool(f.Value.String()); err == nil {
30+
verbose = v
31+
}
32+
}
33+
34+
// API KEY
35+
api_key := os.Getenv("MISTRAL_API_KEY")
36+
if api_key == "" {
37+
log.Print("MISTRAL_API_KEY not set")
38+
os.Exit(0)
39+
}
40+
41+
// Create client
42+
var err error
43+
client, err = mistral.New(api_key, opts.OptTrace(os.Stderr, verbose))
44+
if err != nil {
45+
log.Println(err)
46+
os.Exit(-1)
47+
}
48+
os.Exit(m.Run())
1949
}
2050

2151
///////////////////////////////////////////////////////////////////////////////
22-
// ENVIRONMENT
52+
// TESTS
2353

24-
func GetApiKey(t *testing.T) string {
25-
key := os.Getenv("MISTRAL_API_KEY")
26-
if key == "" {
27-
t.Skip("MISTRAL_API_KEY not set")
28-
t.SkipNow()
29-
}
30-
return key
54+
func Test_client_001(t *testing.T) {
55+
assert := assert.New(t)
56+
assert.NotNil(client)
57+
t.Log(client)
3158
}

pkg/mistral/embeddings.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package mistral
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
// Packages
8+
client "github.com/mutablelogic/go-client"
9+
llm "github.com/mutablelogic/go-llm"
10+
)
11+
12+
///////////////////////////////////////////////////////////////////////////////
13+
// TYPES
14+
15+
// embeddings is the implementation of the llm.Embedding interface
16+
type embeddings struct {
17+
Embeddings
18+
}
19+
20+
// Embeddings is the metadata for a generated embedding vector
21+
type Embeddings struct {
22+
Id string `json:"id"`
23+
Type string `json:"object"`
24+
Model string `json:"model"`
25+
Data []Embedding `json:"data"`
26+
Metrics
27+
}
28+
29+
// Embedding is a single vector
30+
type Embedding struct {
31+
Type string `json:"object"`
32+
Index uint64 `json:"index"`
33+
Vector []float64 `json:"embedding"`
34+
}
35+
36+
///////////////////////////////////////////////////////////////////////////////
37+
// STRINGIFY
38+
39+
func (m Embedding) MarshalJSON() ([]byte, error) {
40+
return json.Marshal(m.Vector)
41+
}
42+
43+
func (m embeddings) MarshalJSON() ([]byte, error) {
44+
return json.Marshal(m.Embeddings)
45+
}
46+
47+
func (m embeddings) String() string {
48+
data, err := json.MarshalIndent(m, "", " ")
49+
if err != nil {
50+
return err.Error()
51+
}
52+
return string(data)
53+
}
54+
55+
///////////////////////////////////////////////////////////////////////////////
56+
// PUBLIC METHODS
57+
58+
type reqEmbedding struct {
59+
Model string `json:"model"`
60+
Input []string `json:"input"`
61+
Format string `json:"encoding_format,omitempty"`
62+
}
63+
64+
func (mistral *Client) GenerateEmbedding(ctx context.Context, name string, prompt []string, _ ...llm.Opt) (*embeddings, error) {
65+
// Options are currently ignored
66+
67+
// Bail out is no prompt
68+
if len(prompt) == 0 {
69+
return nil, llm.ErrBadParameter.With("missing prompt")
70+
}
71+
72+
// Request
73+
req, err := client.NewJSONRequest(reqEmbedding{
74+
Model: name,
75+
Input: prompt,
76+
})
77+
if err != nil {
78+
return nil, err
79+
}
80+
81+
// Response
82+
var response embeddings
83+
if err := mistral.DoWithContext(ctx, req, &response, client.OptPath("embeddings")); err != nil {
84+
return nil, err
85+
}
86+
87+
// Return success
88+
return &response, nil
89+
}
90+
91+
// Generate one vector
92+
func (model *model) Embedding(ctx context.Context, prompt string, opts ...llm.Opt) ([]float64, error) {
93+
response, err := model.GenerateEmbedding(ctx, model.Name(), []string{prompt}, opts...)
94+
if err != nil {
95+
return nil, err
96+
}
97+
if len(response.Embeddings.Data) == 0 {
98+
return nil, llm.ErrNotFound.With("no embeddings returned")
99+
}
100+
return response.Embeddings.Data[0].Vector, nil
101+
}

pkg/mistral/embeddings_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package mistral_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
// Packages
8+
assert "github.com/stretchr/testify/assert"
9+
)
10+
11+
func Test_embeddings_001(t *testing.T) {
12+
assert := assert.New(t)
13+
model := client.Model(context.TODO(), "mistral-embed")
14+
if assert.NotNil(model) {
15+
response, err := model.Embedding(context.TODO(), "Hello, how are you?")
16+
assert.NoError(err)
17+
assert.NotEmpty(response)
18+
t.Log(response)
19+
}
20+
}

pkg/mistral/model.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,3 @@ func (c *Client) ListModels(ctx context.Context) ([]llm.Model, error) {
8080
func (m model) Name() string {
8181
return m.meta.Name
8282
}
83-
84-
// Embedding vector generation
85-
func (m model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) {
86-
return nil, llm.ErrNotImplemented
87-
}

pkg/mistral/model_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,20 @@ package mistral_test
33
import (
44
"context"
55
"encoding/json"
6-
"os"
76
"testing"
87

98
// Packages
10-
opts "github.com/mutablelogic/go-client"
11-
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
9+
1210
assert "github.com/stretchr/testify/assert"
1311
)
1412

1513
func Test_models_001(t *testing.T) {
1614
assert := assert.New(t)
17-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
18-
assert.NoError(err)
19-
assert.NotNil(client)
15+
2016
response, err := client.ListModels(context.TODO())
2117
assert.NoError(err)
2218
assert.NotEmpty(response)
19+
2320
data, err := json.MarshalIndent(response, "", " ")
2421
assert.NoError(err)
2522
t.Log(string(data))

pkg/mistral/session_test.go

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,16 @@ package mistral_test
22

33
import (
44
"context"
5-
"os"
65
"testing"
76

87
// Packages
9-
opts "github.com/mutablelogic/go-client"
10-
"github.com/mutablelogic/go-llm"
11-
mistral "github.com/mutablelogic/go-llm/pkg/mistral"
12-
"github.com/mutablelogic/go-llm/pkg/tool"
8+
llm "github.com/mutablelogic/go-llm"
9+
tool "github.com/mutablelogic/go-llm/pkg/tool"
1310
assert "github.com/stretchr/testify/assert"
1411
)
1512

1613
func Test_session_001(t *testing.T) {
1714
assert := assert.New(t)
18-
19-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
20-
assert.NoError(err)
21-
2215
model := client.Model(context.TODO(), "mistral-small-latest")
2316
if !assert.NotNil(model) {
2417
t.FailNow()
@@ -34,10 +27,6 @@ func Test_session_001(t *testing.T) {
3427

3528
func Test_session_002(t *testing.T) {
3629
assert := assert.New(t)
37-
38-
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
39-
assert.NoError(err)
40-
4130
model := client.Model(context.TODO(), "mistral-small-latest")
4231
if !assert.NotNil(model) {
4332
t.FailNow()
@@ -53,14 +42,15 @@ func Test_session_002(t *testing.T) {
5342

5443
assert.NoError(session.FromUser(context.TODO(), "What is the weather like in London today?"))
5544
calls := session.ToolCalls(0)
56-
assert.Len(calls, 1)
57-
assert.Equal("weather_in_city", calls[0].Name())
45+
if assert.Len(calls, 1) {
46+
assert.Equal("weather_in_city", calls[0].Name())
5847

59-
result, err := toolkit.Run(context.TODO(), calls...)
60-
assert.NoError(err)
61-
assert.Len(result, 1)
48+
result, err := toolkit.Run(context.TODO(), calls...)
49+
assert.NoError(err)
50+
assert.Len(result, 1)
6251

63-
assert.NoError(session.FromTool(context.TODO(), result...))
52+
assert.NoError(session.FromTool(context.TODO(), result...))
53+
}
6454

6555
t.Log(session)
6656
}

0 commit comments

Comments
 (0)