Skip to content

Commit 72a5b56

Browse files
committed
Added tool calls to mistral
1 parent d26484c commit 72a5b56

File tree

4 files changed

+167
-123
lines changed

4 files changed

+167
-123
lines changed

pkg/mistral/chat.go

Lines changed: 105 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
package mistral
22

33
import (
4-
"bufio"
5-
"bytes"
64
"context"
75
"encoding/json"
86
"io"
97
"reflect"
10-
"strings"
118

129
// Packages
1310
client "github.com/mutablelogic/go-client"
@@ -20,37 +17,46 @@ import (
2017
///////////////////////////////////////////////////////////////////////////////
2118
// TYPES
2219

20+
// A request for a chat completion
2321
type reqChat struct {
2422
options
23+
Tools []reqChatTools `json:"tools,omitempty"`
2524
Messages []*schema.Message `json:"messages,omitempty"`
2625
}
2726

28-
type respChat struct {
29-
Id string `json:"id"`
30-
Created int64 `json:"created"`
31-
Model string `json:"model"`
32-
Choices []schema.MessageChoice `json:"choices,omitempty"`
33-
Usage *respUsage `json:"usage,omitempty"`
34-
35-
// Private fields
36-
callback Callback `json:"-"`
27+
type reqChatTools struct {
28+
Type string `json:"type"`
29+
Function *schema.Tool `json:"function"`
3730
}
3831

39-
type respUsage struct {
40-
PromptTokens int `json:"prompt_tokens"`
41-
CompletionTokens int `json:"completion_tokens"`
42-
TotalTokens int `json:"total_tokens"`
32+
// A chat completion object
33+
type respChat struct {
34+
Id string `json:"id"`
35+
Created int64 `json:"created"`
36+
Model string `json:"model"`
37+
Choices []*schema.MessageChoice `json:"choices,omitempty"`
38+
TokenUsage schema.TokenUsage `json:"usage,omitempty"`
4339
}
4440

4541
///////////////////////////////////////////////////////////////////////////////
4642
// GLOBALS
4743

4844
const (
4945
defaultChatCompletionModel = "mistral-small-latest"
50-
contentTypeTextStream = "text/event-stream"
51-
endOfStream = "[DONE]"
46+
endOfStreamToken = "[DONE]"
5247
)
5348

49+
///////////////////////////////////////////////////////////////////////////////
50+
// STRINGIFY
51+
52+
func (v respChat) String() string {
53+
if data, err := json.MarshalIndent(v, "", " "); err != nil {
54+
return err.Error()
55+
} else {
56+
return string(data)
57+
}
58+
}
59+
5460
///////////////////////////////////////////////////////////////////////////////
5561
// API CALLS
5662

@@ -59,11 +65,6 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
5965
var request reqChat
6066
var response respChat
6167

62-
// Check messages
63-
if len(messages) == 0 {
64-
return nil, ErrBadParameter.With("missing messages")
65-
}
66-
6768
// Process options
6869
request.Model = defaultChatCompletionModel
6970
request.Messages = messages
@@ -73,13 +74,28 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
7374
}
7475
}
7576

76-
// Set the callback
77-
response.callback = request.callback
77+
// Append tools
78+
for _, tool := range request.options.Tools {
79+
request.Tools = append(request.Tools, reqChatTools{
80+
Type: "function",
81+
Function: tool,
82+
})
83+
}
84+
85+
// Set up the request
86+
reqopts := []client.RequestOpt{
87+
client.OptPath("chat/completions"),
88+
}
89+
if request.Stream {
90+
reqopts = append(reqopts, client.OptTextStreamCallback(func(event client.TextStreamEvent) error {
91+
return response.streamCallback(event, request.StreamCallback)
92+
}))
93+
}
7894

7995
// Request->Response
8096
if payload, err := client.NewJSONRequest(request); err != nil {
8197
return nil, err
82-
} else if err := c.DoWithContext(ctx, payload, &response, client.OptPath("chat/completions")); err != nil {
98+
} else if err := c.DoWithContext(ctx, payload, &response, reqopts...); err != nil {
8399
return nil, err
84100
} else if len(response.Choices) == 0 {
85101
return nil, ErrUnexpectedResponse.With("no choices returned")
@@ -91,6 +107,9 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
91107
if choice.Message == nil || choice.Message.Content == nil {
92108
continue
93109
}
110+
for _, tool := range choice.Message.ToolCalls {
111+
result = append(result, schema.ToolUse(tool))
112+
}
94113
switch v := choice.Message.Content.(type) {
95114
case []string:
96115
for _, v := range v {
@@ -108,97 +127,76 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
108127
}
109128

110129
///////////////////////////////////////////////////////////////////////////////
111-
// STRINGIFY
130+
// PRIVATE METHODS
112131

113-
func (s respChat) String() string {
114-
data, _ := json.MarshalIndent(s, "", " ")
115-
return string(data)
116-
}
132+
func (response *respChat) streamCallback(v client.TextStreamEvent, fn Callback) error {
133+
var delta schema.MessageChunk
117134

118-
///////////////////////////////////////////////////////////////////////////////
119-
// UNMARSHAL TEXT STREAM
120-
121-
func (m *respChat) Unmarshal(mimetype string, r io.Reader) error {
122-
switch mimetype {
123-
case client.ContentTypeJson:
124-
return json.NewDecoder(r).Decode(m)
125-
case contentTypeTextStream:
126-
return m.decodeTextStream(r)
127-
default:
128-
return ErrUnexpectedResponse.Withf("%q", mimetype)
135+
// [DONE] indicates the end of the stream, return io.EOF
136+
// or decode the data into a MessageChunk
137+
if v.Data == endOfStreamToken {
138+
return io.EOF
139+
} else if err := v.Json(&delta); err != nil {
140+
return err
129141
}
130-
}
131142

132-
func (m *respChat) decodeTextStream(r io.Reader) error {
133-
var stream respChat
134-
scanner := bufio.NewScanner(r)
135-
buf := new(bytes.Buffer)
136-
137-
FOR_LOOP:
138-
for scanner.Scan() {
139-
data := scanner.Text()
140-
switch {
141-
case data == "":
142-
continue FOR_LOOP
143-
case strings.HasPrefix(data, "data:") && strings.HasSuffix(data, endOfStream):
144-
// [DONE] - Set usage from the stream, break the loop
145-
m.Usage = stream.Usage
146-
break FOR_LOOP
147-
case strings.HasPrefix(data, "data:"):
148-
// Reset
149-
stream.Choices = nil
150-
151-
// Decode JSON data
152-
data = data[6:]
153-
if _, err := buf.WriteString(data); err != nil {
154-
return err
155-
} else if err := json.Unmarshal(buf.Bytes(), &stream); err != nil {
156-
return err
157-
}
143+
// Set the response fields
144+
if delta.Id != "" {
145+
response.Id = delta.Id
146+
}
147+
if delta.Model != "" {
148+
response.Model = delta.Model
149+
}
150+
if delta.Created != 0 {
151+
response.Created = delta.Created
152+
}
153+
if delta.TokenUsage != nil {
154+
response.TokenUsage = *delta.TokenUsage
155+
}
158156

159-
// Check for sane data
160-
if len(stream.Choices) == 0 {
161-
return ErrUnexpectedResponse.With("no choices returned")
162-
} else if stream.Choices[0].Index != 0 {
163-
return ErrUnexpectedResponse.With("unexpected choice", stream.Choices[0].Index)
164-
} else if stream.Choices[0].Delta == nil {
165-
return ErrUnexpectedResponse.With("no delta returned")
166-
}
157+
// With no choices, return success
158+
if len(delta.Choices) == 0 {
159+
return nil
160+
}
167161

168-
// Append the choice
169-
if len(m.Choices) == 0 {
170-
message := schema.NewMessage(stream.Choices[0].Delta.Role, stream.Choices[0].Delta.Content)
171-
m.Choices = append(m.Choices, schema.MessageChoice{
172-
Index: stream.Choices[0].Index,
173-
Message: message,
174-
FinishReason: stream.Choices[0].FinishReason,
175-
})
176-
} else {
177-
// Append text to the message
178-
m.Choices[0].Message.Add(stream.Choices[0].Delta.Content)
179-
180-
// If the finish reason is set
181-
if stream.Choices[0].FinishReason != "" {
182-
m.Choices[0].FinishReason = stream.Choices[0].FinishReason
183-
}
162+
// Append choices
163+
for _, choice := range delta.Choices {
164+
// Sanity check the choice index
165+
if choice.Index < 0 || choice.Index >= 6 {
166+
continue
167+
}
168+
// Ensure message has the choice
169+
for {
170+
if choice.Index < len(response.Choices) {
171+
break
184172
}
185-
186-
// Set the model and id
187-
m.Id = stream.Id
188-
m.Model = stream.Model
189-
190-
// Callback
191-
if m.callback != nil {
192-
m.callback(stream.Choices[0])
173+
response.Choices = append(response.Choices, new(schema.MessageChoice))
174+
}
175+
// Append the choice data onto the messahe
176+
if response.Choices[choice.Index].Message == nil {
177+
response.Choices[choice.Index].Message = new(schema.Message)
178+
}
179+
if choice.Index != 0 {
180+
response.Choices[choice.Index].Index = choice.Index
181+
}
182+
if choice.FinishReason != "" {
183+
response.Choices[choice.Index].FinishReason = choice.FinishReason
184+
}
185+
if choice.Delta != nil {
186+
if choice.Delta.Role != "" {
187+
response.Choices[choice.Index].Message.Role = choice.Delta.Role
193188
}
189+
if choice.Delta.Content != "" {
190+
response.Choices[choice.Index].Message.Add(choice.Delta.Content)
191+
}
192+
}
194193

195-
// Reset the buffer
196-
buf.Reset()
197-
default:
198-
return ErrUnexpectedResponse.Withf("%q", data)
194+
// Callback to the client
195+
if fn != nil {
196+
fn(choice)
199197
}
200198
}
201199

202-
// Return any errors from the scanner
203-
return scanner.Err()
200+
// Return success
201+
return nil
204202
}

pkg/mistral/chat_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mistral_test
33
import (
44
"context"
55
"os"
6+
"reflect"
67
"testing"
78

89
// Packages
@@ -22,3 +23,31 @@ func Test_chat_001(t *testing.T) {
2223
})
2324
assert.NoError(err)
2425
}
26+
27+
func Test_chat_002(t *testing.T) {
28+
assert := assert.New(t)
29+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
30+
assert.NoError(err)
31+
assert.NotNil(client)
32+
_, err = client.Chat(context.Background(), []*schema.Message{
33+
{Role: "user", Content: "What is the weather"},
34+
}, mistral.OptStream(func(message schema.MessageChoice) {
35+
t.Log(message)
36+
}))
37+
assert.NoError(err)
38+
}
39+
40+
func Test_chat_003(t *testing.T) {
41+
assert := assert.New(t)
42+
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
43+
assert.NoError(err)
44+
assert.NotNil(client)
45+
46+
tool := schema.NewTool("weather", "get weather in a specific city")
47+
tool.Add("city", "name of the city, if known", false, reflect.TypeOf(""))
48+
49+
_, err = client.Chat(context.Background(), []*schema.Message{
50+
{Role: "user", Content: "What is the weather in Berlin"},
51+
}, mistral.OptTool(tool))
52+
assert.NoError(err)
53+
}

pkg/mistral/opts.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@ import (
1212
// TYPES
1313

1414
type options struct {
15+
// Common options
1516
Model string `json:"model,omitempty"`
1617
EncodingFormat string `json:"encoding_format,omitempty"`
17-
Temperature *float64 `json:"temperature,omitempty"`
18+
Temperature *float32 `json:"temperature,omitempty"`
1819
MaxTokens int `json:"max_tokens,omitempty"`
19-
Stream bool `json:"stream,omitempty"`
2020
SafePrompt bool `json:"safe_prompt,omitempty"`
2121
Seed int `json:"random_seed,omitempty"`
2222

23-
// Private methods
24-
callback Callback `json:"-"`
23+
// Options for chat
24+
Stream bool `json:"stream,omitempty"`
25+
StreamCallback Callback `json:"-"`
26+
Tools []*schema.Tool `json:"-"`
2527
}
2628

2729
// Opt is a function which can be used to set options on a request
@@ -61,7 +63,7 @@ func OptMaxTokens(v int) Opt {
6163
func OptStream(fn Callback) Opt {
6264
return func(o *options) error {
6365
o.Stream = true
64-
o.callback = fn
66+
o.StreamCallback = fn
6567
return nil
6668
}
6769
}
@@ -83,7 +85,7 @@ func OptSeed(v int) Opt {
8385
}
8486

8587
// Amount of randomness injected into the response.
86-
func OptTemperature(v float64) Opt {
88+
func OptTemperature(v float32) Opt {
8789
return func(o *options) error {
8890
if v < 0.0 || v > 1.0 {
8991
return ErrBadParameter.With("OptTemperature")
@@ -92,3 +94,21 @@ func OptTemperature(v float64) Opt {
9294
return nil
9395
}
9496
}
97+
98+
// A list of tools the model may call.
99+
func OptTool(value ...*schema.Tool) Opt {
100+
return func(o *options) error {
101+
// Check tools
102+
for _, tool := range value {
103+
if tool == nil {
104+
return ErrBadParameter.With("OptTool")
105+
}
106+
}
107+
108+
// Append tools
109+
o.Tools = append(o.Tools, value...)
110+
111+
// Return success
112+
return nil
113+
}
114+
}

0 commit comments

Comments
 (0)