Skip to content

Commit 8a48a43

Browse files
committed
Added chat completions
1 parent 772de80 commit 8a48a43

File tree

7 files changed

+135
-18
lines changed

7 files changed

+135
-18
lines changed

cmd/cli/mistral.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
// Package imports
55
"github.com/mutablelogic/go-client/pkg/client"
66
"github.com/mutablelogic/go-client/pkg/mistral"
7+
"github.com/mutablelogic/go-client/pkg/openai/schema"
78
)
89

910
/////////////////////////////////////////////////////////////////////
@@ -24,6 +25,7 @@ func MistralRegister(cmd []Client, opts []client.ClientOpt, flags *Flags) ([]Cli
2425
ns: "mistral",
2526
cmd: []Command{
2627
{Name: "models", Description: "Return registered models", MinArgs: 2, MaxArgs: 2, Fn: mistralModels(mistral, flags)},
28+
{Name: "chat", Description: "Chat", Syntax: "<prompt>", MinArgs: 3, MaxArgs: 3, Fn: mistralChat(mistral, flags)},
2729
},
2830
})
2931

@@ -43,3 +45,16 @@ func mistralModels(client *mistral.Client, flags *Flags) CommandFn {
4345
}
4446
}
4547
}
48+
49+
func mistralChat(client *mistral.Client, flags *Flags) CommandFn {
50+
return func() error {
51+
if message, err := client.Chat([]schema.Message{
52+
{Role: "user", Content: flags.Arg(2)},
53+
}); err != nil {
54+
return err
55+
} else if err := flags.Write(message); err != nil {
56+
return err
57+
}
58+
return nil
59+
}
60+
}

pkg/mistral/chat.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package mistral
2+
3+
import (
4+
client "github.com/mutablelogic/go-client/pkg/client"
5+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
6+
7+
// Namespace imports
8+
. "github.com/djthorpe/go-errors"
9+
)
10+
11+
///////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type reqChat struct {
15+
Model string `json:"model"`
16+
Messages []schema.Message `json:"messages,omitempty"`
17+
Temperature float64 `json:"temperature,omitempty"`
18+
TopP float64 `json:"top_p,omitempty"`
19+
MaxTokens int `json:"max_tokens,omitempty"`
20+
Stream bool `json:"stream,omitempty"`
21+
SafePrompt bool `json:"safe_prompt,omitempty"`
22+
Seed int `json:"random_seed,omitempty"`
23+
}
24+
25+
type respChat struct {
26+
Id string `json:"id"`
27+
Created int64 `json:"created"`
28+
Model string `json:"model"`
29+
Choices []schema.MessageChoice `json:"choices,omitempty"`
30+
Usage struct {
31+
PromptTokens int `json:"prompt_tokens"`
32+
CompletionTokens int `json:"completion_tokens"`
33+
TotalTokens int `json:"total_tokens"`
34+
} `json:"usage"`
35+
}
36+
37+
///////////////////////////////////////////////////////////////////////////////
38+
// GLOBALS
39+
40+
const (
41+
defaultChatCompletionModel = "mistral-small-latest"
42+
)
43+
44+
///////////////////////////////////////////////////////////////////////////////
45+
// API CALLS
46+
47+
// Chat creates a model response for the given chat conversation.
48+
func (c *Client) Chat(messages []schema.Message) (schema.Message, error) {
49+
var request reqChat
50+
var response respChat
51+
52+
request.Model = defaultChatCompletionModel
53+
request.Messages = messages
54+
55+
// Return the response
56+
if payload, err := client.NewJSONRequest(request, client.ContentTypeJson); err != nil {
57+
return schema.Message{}, err
58+
} else if err := c.Do(payload.Post(), &response, client.OptPath("chat/completions")); err != nil {
59+
return schema.Message{}, err
60+
} else if len(response.Choices) == 0 {
61+
return schema.Message{}, ErrNotFound
62+
} else {
63+
return response.Choices[0].Message, nil
64+
}
65+
}

pkg/mistral/chat_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package mistral_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
// Packages
8+
opts "github.com/mutablelogic/go-client/pkg/client"
9+
mistral "github.com/mutablelogic/go-client/pkg/mistral"
10+
"github.com/mutablelogic/go-client/pkg/openai/schema"
11+
assert "github.com/stretchr/testify/assert"
12+
)
13+
14+
func Test_chat_001(t *testing.T) {
15+
assert := assert.New(t)
16+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
17+
assert.NoError(err)
18+
assert.NotNil(client)
19+
err = client.Chat([]schema.Message{
20+
{Role: "user", Content: "What is the weather"},
21+
})
22+
assert.NoError(err)
23+
}

pkg/openai/embedding.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@ package openai
22

33
import (
44
// Packages
5-
"github.com/mutablelogic/go-client/pkg/client"
5+
client "github.com/mutablelogic/go-client/pkg/client"
6+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
67

78
// Namespace imports
89
. "github.com/djthorpe/go-errors"
9-
. "github.com/mutablelogic/go-client/pkg/openai/schema"
1010
)
1111

1212
///////////////////////////////////////////////////////////////////////////////
1313
// API CALLS
1414

1515
// CreateEmbedding creates an embedding from a string or array of strings
16-
func (c *Client) CreateEmbedding(content any, opts ...Opt) (Embeddings, error) {
16+
func (c *Client) CreateEmbedding(content any, opts ...Opt) (schema.Embeddings, error) {
1717

1818
// Apply the options
1919
var request reqCreateEmbedding
2020
for _, opt := range opts {
2121
if err := opt(&request); err != nil {
22-
return Embeddings{}, err
22+
return schema.Embeddings{}, err
2323
}
2424
}
2525

@@ -30,15 +30,15 @@ func (c *Client) CreateEmbedding(content any, opts ...Opt) (Embeddings, error) {
3030
case []string:
3131
request.Input = v
3232
default:
33-
return Embeddings{}, ErrBadParameter
33+
return schema.Embeddings{}, ErrBadParameter
3434
}
3535

3636
// Return the response
37-
var response Embeddings
37+
var response schema.Embeddings
3838
if payload, err := client.NewJSONRequest(request, client.ContentTypeJson); err != nil {
39-
return Embeddings{}, err
39+
return schema.Embeddings{}, err
4040
} else if err := c.Do(payload.Post(), &response, client.OptPath("embeddings")); err != nil {
41-
return Embeddings{}, err
41+
return schema.Embeddings{}, err
4242
}
4343

4444
// Return success

pkg/openai/model.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@ package openai
22

33
import (
44
// Packages
5-
"github.com/mutablelogic/go-client/pkg/client"
6-
7-
// Namespace imports
8-
. "github.com/mutablelogic/go-client/pkg/openai/schema"
5+
client "github.com/mutablelogic/go-client/pkg/client"
6+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
97
)
108

119
///////////////////////////////////////////////////////////////////////////////
1210
// API CALLS
1311

1412
// ListModels returns all the models
15-
func (c *Client) ListModels() ([]Model, error) {
13+
func (c *Client) ListModels() ([]schema.Model, error) {
1614
// Return the response
1715
var response responseListModels
1816
payload := client.NewRequest(client.ContentTypeJson)
@@ -25,12 +23,12 @@ func (c *Client) ListModels() ([]Model, error) {
2523
}
2624

2725
// GetModel returns one model
28-
func (c *Client) GetModel(model string) (Model, error) {
26+
func (c *Client) GetModel(model string) (schema.Model, error) {
2927
// Return the response
30-
var response Model
28+
var response schema.Model
3129
payload := client.NewRequest(client.ContentTypeJson)
3230
if err := c.Do(payload, &response, client.OptPath("models", model)); err != nil {
33-
return Model{}, err
31+
return schema.Model{}, err
3432
}
3533

3634
// Return success

pkg/openai/schema.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package openai
33
import (
44
"encoding/json"
55

6+
// Packages
7+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
8+
69
// Namespace imports
710
. "github.com/djthorpe/go-errors"
8-
. "github.com/mutablelogic/go-client/pkg/openai/schema"
911
)
1012

1113
///////////////////////////////////////////////////////////////////////////////
@@ -174,7 +176,7 @@ type reqImage struct {
174176
// RESPONSES
175177

176178
type responseListModels struct {
177-
Data []Model `json:"data"`
179+
Data []schema.Model `json:"data"`
178180
}
179181

180182
///////////////////////////////////////////////////////////////////////////////

pkg/openai/schema/chat.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package schema
2+
3+
// A chat completion message
4+
type Message struct {
5+
Role string `json:"role"`
6+
Content string `json:"content"`
7+
}
8+
9+
// One choice of chat completion messages
10+
type MessageChoice struct {
11+
Message `json:"message"`
12+
Index int `json:"index"`
13+
FinishReason string `json:"finish_reason"`
14+
}

0 commit comments

Comments
 (0)