1
1
package mistral
2
2
3
3
import (
4
- // Packages
4
+ "bufio"
5
+ "bytes"
5
6
"context"
7
+ "encoding/json"
8
+ "io"
6
9
"reflect"
10
+ "strings"
7
11
12
+ // Packages
8
13
client "github.com/mutablelogic/go-client"
9
14
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
10
15
@@ -25,18 +30,25 @@ type respChat struct {
25
30
Created int64 `json:"created"`
26
31
Model string `json:"model"`
27
32
Choices []schema.MessageChoice `json:"choices,omitempty"`
28
- Usage struct {
29
- PromptTokens int `json:"prompt_tokens"`
30
- CompletionTokens int `json:"completion_tokens"`
31
- TotalTokens int `json:"total_tokens"`
32
- } `json:"usage"`
33
+ Usage * respUsage `json:"usage,omitempty"`
34
+
35
+ // Private fields
36
+ callback Callback `json:"-"`
37
+ }
38
+
39
+ type respUsage struct {
40
+ PromptTokens int `json:"prompt_tokens"`
41
+ CompletionTokens int `json:"completion_tokens"`
42
+ TotalTokens int `json:"total_tokens"`
33
43
}
34
44
35
45
///////////////////////////////////////////////////////////////////////////////
36
46
// GLOBALS
37
47
38
48
const (
39
49
defaultChatCompletionModel = "mistral-small-latest"
50
+ contentTypeTextStream = "text/event-stream"
51
+ endOfStream = "[DONE]"
40
52
)
41
53
42
54
///////////////////////////////////////////////////////////////////////////////
@@ -61,6 +73,9 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
61
73
}
62
74
}
63
75
76
+ // Set the callback
77
+ response .callback = request .callback
78
+
64
79
// Request->Response
65
80
if payload , err := client .NewJSONRequest (request ); err != nil {
66
81
return nil , err
@@ -73,13 +88,117 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
73
88
// Return all choices
74
89
var result []* schema.Content
75
90
for _ , choice := range response .Choices {
76
- if str , ok := choice .Content .(string ); ok {
77
- result = append (result , schema .Text (str ))
78
- } else {
79
- return nil , ErrUnexpectedResponse .With ("unexpected content type" , reflect .TypeOf (choice .Content ))
91
+ if choice .Message == nil || choice .Message .Content == nil {
92
+ continue
93
+ }
94
+ switch v := choice .Message .Content .(type ) {
95
+ case []string :
96
+ for _ , v := range v {
97
+ result = append (result , schema .Text (v ))
98
+ }
99
+ case string :
100
+ result = append (result , schema .Text (v ))
101
+ default :
102
+ return nil , ErrUnexpectedResponse .With ("unexpected content type " , reflect .TypeOf (choice .Message .Content ))
80
103
}
81
104
}
82
105
83
106
// Return success
84
107
return result , nil
85
108
}
109
+
110
+ ///////////////////////////////////////////////////////////////////////////////
111
+ // STRINGIFY
112
+
113
+ func (s respChat ) String () string {
114
+ data , _ := json .MarshalIndent (s , "" , " " )
115
+ return string (data )
116
+ }
117
+
118
+ ///////////////////////////////////////////////////////////////////////////////
119
+ // UNMARSHAL TEXT STREAM
120
+
121
+ func (m * respChat ) Unmarshal (mimetype string , r io.Reader ) error {
122
+ switch mimetype {
123
+ case client .ContentTypeJson :
124
+ return json .NewDecoder (r ).Decode (m )
125
+ case contentTypeTextStream :
126
+ return m .decodeTextStream (r )
127
+ default :
128
+ return ErrUnexpectedResponse .Withf ("%q" , mimetype )
129
+ }
130
+ }
131
+
132
+ func (m * respChat ) decodeTextStream (r io.Reader ) error {
133
+ var stream respChat
134
+ scanner := bufio .NewScanner (r )
135
+ buf := new (bytes.Buffer )
136
+
137
+ FOR_LOOP:
138
+ for scanner .Scan () {
139
+ data := scanner .Text ()
140
+ switch {
141
+ case data == "" :
142
+ continue FOR_LOOP
143
+ case strings .HasPrefix (data , "data:" ) && strings .HasSuffix (data , endOfStream ):
144
+ // [DONE] - Set usage from the stream, break the loop
145
+ m .Usage = stream .Usage
146
+ break FOR_LOOP
147
+ case strings .HasPrefix (data , "data:" ):
148
+ // Reset
149
+ stream .Choices = nil
150
+
151
+ // Decode JSON data
152
+ data = data [6 :]
153
+ if _ , err := buf .WriteString (data ); err != nil {
154
+ return err
155
+ } else if err := json .Unmarshal (buf .Bytes (), & stream ); err != nil {
156
+ return err
157
+ }
158
+
159
+ // Check for sane data
160
+ if len (stream .Choices ) == 0 {
161
+ return ErrUnexpectedResponse .With ("no choices returned" )
162
+ } else if stream .Choices [0 ].Index != 0 {
163
+ return ErrUnexpectedResponse .With ("unexpected choice" , stream .Choices [0 ].Index )
164
+ } else if stream .Choices [0 ].Delta == nil {
165
+ return ErrUnexpectedResponse .With ("no delta returned" )
166
+ }
167
+
168
+ // Append the choice
169
+ if len (m .Choices ) == 0 {
170
+ message := schema .NewMessage (stream .Choices [0 ].Delta .Role , stream .Choices [0 ].Delta .Content )
171
+ m .Choices = append (m .Choices , schema.MessageChoice {
172
+ Index : stream .Choices [0 ].Index ,
173
+ Message : message ,
174
+ FinishReason : stream .Choices [0 ].FinishReason ,
175
+ })
176
+ } else {
177
+ // Append text to the message
178
+ m .Choices [0 ].Message .Add (stream .Choices [0 ].Delta .Content )
179
+
180
+ // If the finish reason is set
181
+ if stream .Choices [0 ].FinishReason != "" {
182
+ m .Choices [0 ].FinishReason = stream .Choices [0 ].FinishReason
183
+ }
184
+ }
185
+
186
+ // Set the model and id
187
+ m .Id = stream .Id
188
+ m .Model = stream .Model
189
+
190
+ // Callback
191
+ if m .callback != nil {
192
+ m .callback (stream .Choices [0 ])
193
+ }
194
+
195
+ // Reset the buffer
196
+ buf .Reset ()
197
+ default :
198
+ return ErrUnexpectedResponse .Withf ("%q" , data )
199
+ }
200
+ }
201
+
202
+ // Return any errors from the scanner
203
+ return scanner .Err ()
204
+ }
0 commit comments