1
1
package mistral
2
2
3
3
import (
4
- "bufio"
5
- "bytes"
6
4
"context"
7
5
"encoding/json"
8
6
"io"
9
7
"reflect"
10
- "strings"
11
8
12
9
// Packages
13
10
client "github.com/mutablelogic/go-client"
@@ -20,37 +17,46 @@ import (
20
17
///////////////////////////////////////////////////////////////////////////////
21
18
// TYPES
22
19
20
+ // A request for a chat completion
23
21
type reqChat struct {
24
22
options
23
+ Tools []reqChatTools `json:"tools,omitempty"`
25
24
Messages []* schema.Message `json:"messages,omitempty"`
26
25
}
27
26
28
- type respChat struct {
29
- Id string `json:"id"`
30
- Created int64 `json:"created"`
31
- Model string `json:"model"`
32
- Choices []schema.MessageChoice `json:"choices,omitempty"`
33
- Usage * respUsage `json:"usage,omitempty"`
34
-
35
- // Private fields
36
- callback Callback `json:"-"`
27
+ type reqChatTools struct {
28
+ Type string `json:"type"`
29
+ Function * schema.Tool `json:"function"`
37
30
}
38
31
39
- type respUsage struct {
40
- PromptTokens int `json:"prompt_tokens"`
41
- CompletionTokens int `json:"completion_tokens"`
42
- TotalTokens int `json:"total_tokens"`
32
+ // A chat completion object
33
+ type respChat struct {
34
+ Id string `json:"id"`
35
+ Created int64 `json:"created"`
36
+ Model string `json:"model"`
37
+ Choices []* schema.MessageChoice `json:"choices,omitempty"`
38
+ TokenUsage schema.TokenUsage `json:"usage,omitempty"`
43
39
}
44
40
45
41
///////////////////////////////////////////////////////////////////////////////
46
42
// GLOBALS
47
43
48
44
const (
49
45
defaultChatCompletionModel = "mistral-small-latest"
50
- contentTypeTextStream = "text/event-stream"
51
- endOfStream = "[DONE]"
46
+ endOfStreamToken = "[DONE]"
52
47
)
53
48
49
+ ///////////////////////////////////////////////////////////////////////////////
50
+ // STRINGIFY
51
+
52
+ func (v respChat ) String () string {
53
+ if data , err := json .MarshalIndent (v , "" , " " ); err != nil {
54
+ return err .Error ()
55
+ } else {
56
+ return string (data )
57
+ }
58
+ }
59
+
54
60
///////////////////////////////////////////////////////////////////////////////
55
61
// API CALLS
56
62
@@ -59,11 +65,6 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
59
65
var request reqChat
60
66
var response respChat
61
67
62
- // Check messages
63
- if len (messages ) == 0 {
64
- return nil , ErrBadParameter .With ("missing messages" )
65
- }
66
-
67
68
// Process options
68
69
request .Model = defaultChatCompletionModel
69
70
request .Messages = messages
@@ -73,13 +74,28 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
73
74
}
74
75
}
75
76
76
- // Set the callback
77
- response .callback = request .callback
77
+ // Append tools
78
+ for _ , tool := range request .options .Tools {
79
+ request .Tools = append (request .Tools , reqChatTools {
80
+ Type : "function" ,
81
+ Function : tool ,
82
+ })
83
+ }
84
+
85
+ // Set up the request
86
+ reqopts := []client.RequestOpt {
87
+ client .OptPath ("chat/completions" ),
88
+ }
89
+ if request .Stream {
90
+ reqopts = append (reqopts , client .OptTextStreamCallback (func (event client.TextStreamEvent ) error {
91
+ return response .streamCallback (event , request .StreamCallback )
92
+ }))
93
+ }
78
94
79
95
// Request->Response
80
96
if payload , err := client .NewJSONRequest (request ); err != nil {
81
97
return nil , err
82
- } else if err := c .DoWithContext (ctx , payload , & response , client . OptPath ( "chat/completions" ) ); err != nil {
98
+ } else if err := c .DoWithContext (ctx , payload , & response , reqopts ... ); err != nil {
83
99
return nil , err
84
100
} else if len (response .Choices ) == 0 {
85
101
return nil , ErrUnexpectedResponse .With ("no choices returned" )
@@ -91,6 +107,9 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
91
107
if choice .Message == nil || choice .Message .Content == nil {
92
108
continue
93
109
}
110
+ for _ , tool := range choice .Message .ToolCalls {
111
+ result = append (result , schema .ToolUse (tool ))
112
+ }
94
113
switch v := choice .Message .Content .(type ) {
95
114
case []string :
96
115
for _ , v := range v {
@@ -108,97 +127,76 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
108
127
}
109
128
110
129
///////////////////////////////////////////////////////////////////////////////
111
- // STRINGIFY
130
+ // PRIVATE METHODS
112
131
113
- func (s respChat ) String () string {
114
- data , _ := json .MarshalIndent (s , "" , " " )
115
- return string (data )
116
- }
132
+ func (response * respChat ) streamCallback (v client.TextStreamEvent , fn Callback ) error {
133
+ var delta schema.MessageChunk
117
134
118
- ///////////////////////////////////////////////////////////////////////////////
119
- // UNMARSHAL TEXT STREAM
120
-
121
- func (m * respChat ) Unmarshal (mimetype string , r io.Reader ) error {
122
- switch mimetype {
123
- case client .ContentTypeJson :
124
- return json .NewDecoder (r ).Decode (m )
125
- case contentTypeTextStream :
126
- return m .decodeTextStream (r )
127
- default :
128
- return ErrUnexpectedResponse .Withf ("%q" , mimetype )
135
+ // [DONE] indicates the end of the stream, return io.EOF
136
+ // or decode the data into a MessageChunk
137
+ if v .Data == endOfStreamToken {
138
+ return io .EOF
139
+ } else if err := v .Json (& delta ); err != nil {
140
+ return err
129
141
}
130
- }
131
142
132
- func (m * respChat ) decodeTextStream (r io.Reader ) error {
133
- var stream respChat
134
- scanner := bufio .NewScanner (r )
135
- buf := new (bytes.Buffer )
136
-
137
- FOR_LOOP:
138
- for scanner .Scan () {
139
- data := scanner .Text ()
140
- switch {
141
- case data == "" :
142
- continue FOR_LOOP
143
- case strings .HasPrefix (data , "data:" ) && strings .HasSuffix (data , endOfStream ):
144
- // [DONE] - Set usage from the stream, break the loop
145
- m .Usage = stream .Usage
146
- break FOR_LOOP
147
- case strings .HasPrefix (data , "data:" ):
148
- // Reset
149
- stream .Choices = nil
150
-
151
- // Decode JSON data
152
- data = data [6 :]
153
- if _ , err := buf .WriteString (data ); err != nil {
154
- return err
155
- } else if err := json .Unmarshal (buf .Bytes (), & stream ); err != nil {
156
- return err
157
- }
143
+ // Set the response fields
144
+ if delta .Id != "" {
145
+ response .Id = delta .Id
146
+ }
147
+ if delta .Model != "" {
148
+ response .Model = delta .Model
149
+ }
150
+ if delta .Created != 0 {
151
+ response .Created = delta .Created
152
+ }
153
+ if delta .TokenUsage != nil {
154
+ response .TokenUsage = * delta .TokenUsage
155
+ }
158
156
159
- // Check for sane data
160
- if len (stream .Choices ) == 0 {
161
- return ErrUnexpectedResponse .With ("no choices returned" )
162
- } else if stream .Choices [0 ].Index != 0 {
163
- return ErrUnexpectedResponse .With ("unexpected choice" , stream .Choices [0 ].Index )
164
- } else if stream .Choices [0 ].Delta == nil {
165
- return ErrUnexpectedResponse .With ("no delta returned" )
166
- }
157
+ // With no choices, return success
158
+ if len (delta .Choices ) == 0 {
159
+ return nil
160
+ }
167
161
168
- // Append the choice
169
- if len (m .Choices ) == 0 {
170
- message := schema .NewMessage (stream .Choices [0 ].Delta .Role , stream .Choices [0 ].Delta .Content )
171
- m .Choices = append (m .Choices , schema.MessageChoice {
172
- Index : stream .Choices [0 ].Index ,
173
- Message : message ,
174
- FinishReason : stream .Choices [0 ].FinishReason ,
175
- })
176
- } else {
177
- // Append text to the message
178
- m .Choices [0 ].Message .Add (stream .Choices [0 ].Delta .Content )
179
-
180
- // If the finish reason is set
181
- if stream .Choices [0 ].FinishReason != "" {
182
- m .Choices [0 ].FinishReason = stream .Choices [0 ].FinishReason
183
- }
162
+ // Append choices
163
+ for _ , choice := range delta .Choices {
164
+ // Sanity check the choice index
165
+ if choice .Index < 0 || choice .Index >= 6 {
166
+ continue
167
+ }
168
+ // Ensure message has the choice
169
+ for {
170
+ if choice .Index < len (response .Choices ) {
171
+ break
184
172
}
185
-
186
- // Set the model and id
187
- m .Id = stream .Id
188
- m .Model = stream .Model
189
-
190
- // Callback
191
- if m .callback != nil {
192
- m .callback (stream .Choices [0 ])
173
+ response .Choices = append (response .Choices , new (schema.MessageChoice ))
174
+ }
175
+ // Append the choice data onto the messahe
176
+ if response .Choices [choice .Index ].Message == nil {
177
+ response .Choices [choice .Index ].Message = new (schema.Message )
178
+ }
179
+ if choice .Index != 0 {
180
+ response .Choices [choice .Index ].Index = choice .Index
181
+ }
182
+ if choice .FinishReason != "" {
183
+ response .Choices [choice .Index ].FinishReason = choice .FinishReason
184
+ }
185
+ if choice .Delta != nil {
186
+ if choice .Delta .Role != "" {
187
+ response .Choices [choice .Index ].Message .Role = choice .Delta .Role
193
188
}
189
+ if choice .Delta .Content != "" {
190
+ response .Choices [choice .Index ].Message .Add (choice .Delta .Content )
191
+ }
192
+ }
194
193
195
- // Reset the buffer
196
- buf .Reset ()
197
- default :
198
- return ErrUnexpectedResponse .Withf ("%q" , data )
194
+ // Callback to the client
195
+ if fn != nil {
196
+ fn (choice )
199
197
}
200
198
}
201
199
202
- // Return any errors from the scanner
203
- return scanner . Err ()
200
+ // Return success
201
+ return nil
204
202
}
0 commit comments