Skip to content

Commit b0065fc

Browse files
committed
Added mistral options
1 parent 91822c8 commit b0065fc

File tree

10 files changed

+312
-27
lines changed

10 files changed

+312
-27
lines changed

cmd/api/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func main() {
2424
elRegister(flags)
2525
haRegister(flags)
2626
ipifyRegister(flags)
27+
mistralRegister(flags)
2728
newsapiRegister(flags)
2829
samRegister(flags)
2930
weatherapiRegister(flags)

cmd/api/mistral.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
// Packages
8+
9+
"github.com/djthorpe/go-tablewriter"
10+
"github.com/mutablelogic/go-client"
11+
"github.com/mutablelogic/go-client/pkg/mistral"
12+
"github.com/mutablelogic/go-client/pkg/openai/schema"
13+
)
14+
15+
///////////////////////////////////////////////////////////////////////////////
16+
// GLOBALS
17+
18+
var (
19+
mistralName = "mistral"
20+
mistralClient *mistral.Client
21+
mistralModel string
22+
mistralEncodingFormat string
23+
mistralTemperature *float64
24+
mistralMaxTokens *uint64
25+
mistralStream *bool
26+
mistralSafePrompt bool
27+
mistralSeed *uint64
28+
mistralSystemPrompt string
29+
)
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// LIFECYCLE
33+
34+
func mistralRegister(flags *Flags) {
35+
// Register flags required
36+
flags.String(mistralName, "mistral-api-key", "${MISTRAL_API_KEY}", "API Key")
37+
flags.String(mistralName, "model", "", "Model to use")
38+
flags.String(mistralName, "encoding-format", "", "The format of the output data")
39+
flags.String(mistralName, "system", "", "Provide a system prompt to the model")
40+
flags.Float(mistralName, "temperature", 0, "Sampling temperature to use, between 0.0 and 1.0")
41+
flags.Unsigned(mistralName, "max-tokens", 0, "Maximum number of tokens to generate")
42+
flags.Bool(mistralName, "stream", false, "Stream output")
43+
flags.Bool(mistralName, "safe-prompt", false, "Inject a safety prompt before all conversations.")
44+
flags.Unsigned(mistralName, "seed", 0, "Set random seed")
45+
46+
flags.Register(Cmd{
47+
Name: mistralName,
48+
Description: "Interact with Mistral models, from https://docs.mistral.ai/api/",
49+
Parse: mistralParse,
50+
Fn: []Fn{
51+
{Name: "models", Call: mistralModels, Description: "Gets a list of available models"},
52+
{Name: "embeddings", Call: mistralEmbeddings, Description: "Create embeddings from text", MinArgs: 1, Syntax: "<text>..."},
53+
{Name: "chat", Call: mistralChat, Description: "Create a chat completion", MinArgs: 1, Syntax: "<text>..."},
54+
},
55+
})
56+
}
57+
58+
func mistralParse(flags *Flags, opts ...client.ClientOpt) error {
59+
apiKey := flags.GetString("mistral-api-key")
60+
if apiKey == "" {
61+
return fmt.Errorf("missing -mistral-api-key flag")
62+
} else if client, err := mistral.New(apiKey, opts...); err != nil {
63+
return err
64+
} else {
65+
mistralClient = client
66+
}
67+
68+
// Get the command-line parameters
69+
mistralModel = flags.GetString("model")
70+
mistralEncodingFormat = flags.GetString("encoding-format")
71+
mistralSafePrompt = flags.GetBool("safe-prompt")
72+
mistralSystemPrompt = flags.GetString("system")
73+
if temp, err := flags.GetValue("temperature"); err == nil {
74+
t := temp.(float64)
75+
mistralTemperature = &t
76+
}
77+
if maxtokens, err := flags.GetValue("max-tokens"); err == nil {
78+
t := maxtokens.(uint64)
79+
mistralMaxTokens = &t
80+
}
81+
if stream, err := flags.GetValue("stream"); err == nil {
82+
t := stream.(bool)
83+
mistralStream = &t
84+
}
85+
if seed, err := flags.GetValue("seed"); err == nil {
86+
t := seed.(uint64)
87+
mistralSeed = &t
88+
}
89+
90+
// Return success
91+
return nil
92+
}
93+
94+
///////////////////////////////////////////////////////////////////////////////
95+
// METHODS
96+
97+
func mistralModels(ctx context.Context, writer *tablewriter.Writer, args []string) error {
98+
// Get models
99+
models, err := mistralClient.ListModels()
100+
if err != nil {
101+
return err
102+
}
103+
104+
return writer.Write(models)
105+
}
106+
107+
func mistralEmbeddings(ctx context.Context, writer *tablewriter.Writer, args []string) error {
108+
// Set options
109+
opts := []mistral.Opt{}
110+
if mistralModel != "" {
111+
opts = append(opts, mistral.OptModel(mistralModel))
112+
}
113+
if mistralEncodingFormat != "" {
114+
opts = append(opts, mistral.OptEncodingFormat(mistralEncodingFormat))
115+
}
116+
117+
// Get embeddings
118+
embeddings, err := mistralClient.CreateEmbedding(args, opts...)
119+
if err != nil {
120+
return err
121+
}
122+
return writer.Write(embeddings)
123+
}
124+
125+
func mistralChat(ctx context.Context, w *tablewriter.Writer, args []string) error {
126+
var messages []*schema.Message
127+
128+
// Set options
129+
opts := []mistral.Opt{}
130+
if mistralModel != "" {
131+
opts = append(opts, mistral.OptModel(mistralModel))
132+
}
133+
if mistralTemperature != nil {
134+
opts = append(opts, mistral.OptTemperature(*mistralTemperature))
135+
}
136+
if mistralMaxTokens != nil {
137+
opts = append(opts, mistral.OptMaxTokens(int(*mistralMaxTokens)))
138+
}
139+
if mistralStream != nil {
140+
opts = append(opts, mistral.OptStream(func() {
141+
fmt.Println("STREAM")
142+
}))
143+
}
144+
if mistralSafePrompt {
145+
opts = append(opts, mistral.OptSafePrompt())
146+
}
147+
if mistralSeed != nil {
148+
opts = append(opts, mistral.OptSeed(int(*mistralSeed)))
149+
}
150+
if mistralSystemPrompt != "" {
151+
messages = append(messages, schema.NewMessage("system").Add(schema.Text(mistralSystemPrompt)))
152+
}
153+
154+
// Append user message
155+
message := schema.NewMessage("user")
156+
for _, arg := range args {
157+
message.Add(schema.Text(arg))
158+
}
159+
messages = append(messages, message)
160+
161+
// Request -> Response
162+
responses, err := mistralClient.Chat(ctx, messages, opts...)
163+
if err != nil {
164+
return err
165+
}
166+
167+
// Write table
168+
return w.Write(responses)
169+
}
File renamed without changes.
File renamed without changes.

pkg/anthropic/opts.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package anthropic
22

33
import (
4+
// Package imports
5+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
6+
47
// Namespace imports
58
. "github.com/djthorpe/go-errors"
6-
"github.com/mutablelogic/go-client/pkg/openai/schema"
79
)
810

911
///////////////////////////////////////////////////////////////////////////////

pkg/mistral/chat.go

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package mistral
22

33
import (
44
// Packages
5+
"context"
6+
"reflect"
7+
58
client "github.com/mutablelogic/go-client"
69
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
710

@@ -13,14 +16,8 @@ import (
1316
// TYPES
1417

1518
type reqChat struct {
16-
Model string `json:"model"`
17-
Messages []schema.Message `json:"messages,omitempty"`
18-
Temperature float64 `json:"temperature,omitempty"`
19-
TopP float64 `json:"top_p,omitempty"`
20-
MaxTokens int `json:"max_tokens,omitempty"`
21-
Stream bool `json:"stream,omitempty"`
22-
SafePrompt bool `json:"safe_prompt,omitempty"`
23-
Seed int `json:"random_seed,omitempty"`
19+
options
20+
Messages []*schema.Message `json:"messages,omitempty"`
2421
}
2522

2623
type respChat struct {
@@ -46,21 +43,43 @@ const (
4643
// API CALLS
4744

4845
// Chat creates a model response for the given chat conversation.
49-
func (c *Client) Chat(messages []schema.Message) (schema.Message, error) {
46+
func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...Opt) ([]*schema.Content, error) {
5047
var request reqChat
5148
var response respChat
5249

50+
// Check messages
51+
if len(messages) == 0 {
52+
return nil, ErrBadParameter.With("missing messages")
53+
}
54+
55+
// Process options
5356
request.Model = defaultChatCompletionModel
5457
request.Messages = messages
58+
for _, opt := range opts {
59+
if err := opt(&request.options); err != nil {
60+
return nil, err
61+
}
62+
}
5563

56-
// Return the response
64+
// Request->Response
5765
if payload, err := client.NewJSONRequest(request); err != nil {
58-
return schema.Message{}, err
59-
} else if err := c.Do(payload, &response, client.OptPath("chat/completions")); err != nil {
60-
return schema.Message{}, err
66+
return nil, err
67+
} else if err := c.DoWithContext(ctx, payload, &response, client.OptPath("chat/completions")); err != nil {
68+
return nil, err
6169
} else if len(response.Choices) == 0 {
62-
return schema.Message{}, ErrNotFound
63-
} else {
64-
return response.Choices[0].Message, nil
70+
return nil, ErrUnexpectedResponse.With("no choices returned")
6571
}
72+
73+
// Return all choices
74+
var result []*schema.Content
75+
for _, choice := range response.Choices {
76+
if str, ok := choice.Content.(string); ok {
77+
result = append(result, schema.Text(str))
78+
} else {
79+
return nil, ErrUnexpectedResponse.With("unexpected content type", reflect.TypeOf(choice.Content))
80+
}
81+
}
82+
83+
// Return success
84+
return result, nil
6685
}

pkg/mistral/embedding.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ import (
1414

1515
// A request to create embeddings
1616
type reqCreateEmbedding struct {
17-
Input []string `json:"input"`
18-
Model string `json:"model"`
19-
EncodingFormat string `json:"encoding_format,omitempty"`
17+
Input []string `json:"input"`
18+
options
2019
}
2120

2221
///////////////////////////////////////////////////////////////////////////////
@@ -30,12 +29,17 @@ const (
3029
// API CALLS
3130

3231
// CreateEmbedding creates an embedding from a string or array of strings
33-
func (c *Client) CreateEmbedding(content any) (schema.Embeddings, error) {
32+
func (c *Client) CreateEmbedding(content any, opts ...Opt) (schema.Embeddings, error) {
3433
var request reqCreateEmbedding
3534
var response schema.Embeddings
3635

37-
// Set default model
36+
// Set options
3837
request.Model = defaultEmbeddingModel
38+
for _, opt := range opts {
39+
if err := opt(&request.options); err != nil {
40+
return response, err
41+
}
42+
}
3943

4044
// Set the input, which is either a string or array of strings
4145
switch v := content.(type) {
@@ -44,7 +48,7 @@ func (c *Client) CreateEmbedding(content any) (schema.Embeddings, error) {
4448
case []string:
4549
request.Input = v
4650
default:
47-
return response, ErrBadParameter
51+
return response, ErrBadParameter.With("CreateEmbedding")
4852
}
4953

5054
// Request->Response

0 commit comments

Comments
 (0)