Skip to content

Commit d7ace9f

Browse files
committed
Updated embedding call
1 parent d0680d5 commit d7ace9f

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

pkg/mistral/chat_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ func Test_chat_001(t *testing.T) {
1616
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
1717
assert.NoError(err)
1818
assert.NotNil(client)
19-
err = client.Chat([]schema.Message{
19+
_, err = client.Chat([]schema.Message{
2020
{Role: "user", Content: "What is the weather"},
2121
})
2222
assert.NoError(err)

pkg/mistral/embedding.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ package mistral
22

33
import (
44
// Packages
5-
"github.com/mutablelogic/go-client/pkg/client"
5+
client "github.com/mutablelogic/go-client/pkg/client"
6+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
67

78
// Namespace imports
89
. "github.com/djthorpe/go-errors"
9-
. "github.com/mutablelogic/go-client/pkg/openai/schema"
1010
)
1111

1212
///////////////////////////////////////////////////////////////////////////////
@@ -19,12 +19,23 @@ type reqCreateEmbedding struct {
1919
EncodingFormat string `json:"encoding_format,omitempty"`
2020
}
2121

22+
///////////////////////////////////////////////////////////////////////////////
23+
// GLOBALS
24+
25+
const (
26+
defaultEmbeddingModel = "mistral-embed"
27+
)
28+
2229
///////////////////////////////////////////////////////////////////////////////
2330
// API CALLS
2431

2532
// CreateEmbedding creates an embedding from a string or array of strings
26-
func (c *Client) CreateEmbedding(content any) (Embeddings, error) {
33+
func (c *Client) CreateEmbedding(content any) (schema.Embeddings, error) {
2734
var request reqCreateEmbedding
35+
var response schema.Embeddings
36+
37+
// Set default model
38+
request.Model = defaultEmbeddingModel
2839

2940
// Set the input, which is either a string or array of strings
3041
switch v := content.(type) {
@@ -33,15 +44,14 @@ func (c *Client) CreateEmbedding(content any) (Embeddings, error) {
3344
case []string:
3445
request.Input = v
3546
default:
36-
return Embeddings{}, ErrBadParameter
47+
return response, ErrBadParameter
3748
}
3849

39-
// Return the response
40-
var response Embeddings
50+
// Request->Response
4151
if payload, err := client.NewJSONRequest(request, client.ContentTypeJson); err != nil {
42-
return Embeddings{}, err
52+
return response, err
4353
} else if err := c.Do(payload.Post(), &response, client.OptPath("embeddings")); err != nil {
44-
return Embeddings{}, err
54+
return response, err
4555
}
4656

4757
// Return success

pkg/mistral/embedding_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package mistral_test
2+
3+
import (
4+
"encoding/json"
5+
"os"
6+
"testing"
7+
8+
// Packages
9+
opts "github.com/mutablelogic/go-client/pkg/client"
10+
mistral "github.com/mutablelogic/go-client/pkg/mistral"
11+
assert "github.com/stretchr/testify/assert"
12+
)
13+
14+
func Test_embedding_001(t *testing.T) {
15+
assert := assert.New(t)
16+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
17+
assert.NoError(err)
18+
assert.NotNil(client)
19+
response, err := client.CreateEmbedding("test")
20+
assert.NoError(err)
21+
data, _ := json.MarshalIndent(response, "", " ")
22+
t.Log(string(data))
23+
}

0 commit comments

Comments
 (0)