Skip to content

Commit fa2019f

Browse files
committed
Finished the streaming version
1 parent b0065fc commit fa2019f

File tree

5 files changed

+160
-21
lines changed

5 files changed

+160
-21
lines changed

cmd/api/mistral.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ func mistralChat(ctx context.Context, w *tablewriter.Writer, args []string) erro
137137
opts = append(opts, mistral.OptMaxTokens(int(*mistralMaxTokens)))
138138
}
139139
if mistralStream != nil {
140-
opts = append(opts, mistral.OptStream(func() {
141-
fmt.Println("STREAM")
140+
opts = append(opts, mistral.OptStream(func(choice schema.MessageChoice) {
141+
fmt.Printf("%v\n", choice)
142142
}))
143143
}
144144
if mistralSafePrompt {

pkg/mistral/chat.go

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
package mistral
22

33
import (
4-
// Packages
4+
"bufio"
5+
"bytes"
56
"context"
7+
"encoding/json"
8+
"io"
69
"reflect"
10+
"strings"
711

12+
// Packages
813
client "github.com/mutablelogic/go-client"
914
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
1015

@@ -25,18 +30,25 @@ type respChat struct {
2530
Created int64 `json:"created"`
2631
Model string `json:"model"`
2732
Choices []schema.MessageChoice `json:"choices,omitempty"`
28-
Usage struct {
29-
PromptTokens int `json:"prompt_tokens"`
30-
CompletionTokens int `json:"completion_tokens"`
31-
TotalTokens int `json:"total_tokens"`
32-
} `json:"usage"`
33+
Usage *respUsage `json:"usage,omitempty"`
34+
35+
// Private fields
36+
callback Callback `json:"-"`
37+
}
38+
39+
type respUsage struct {
40+
PromptTokens int `json:"prompt_tokens"`
41+
CompletionTokens int `json:"completion_tokens"`
42+
TotalTokens int `json:"total_tokens"`
3343
}
3444

3545
///////////////////////////////////////////////////////////////////////////////
3646
// GLOBALS
3747

3848
const (
3949
defaultChatCompletionModel = "mistral-small-latest"
50+
contentTypeTextStream = "text/event-stream"
51+
endOfStream = "[DONE]"
4052
)
4153

4254
///////////////////////////////////////////////////////////////////////////////
@@ -61,6 +73,9 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
6173
}
6274
}
6375

76+
// Set the callback
77+
response.callback = request.callback
78+
6479
// Request->Response
6580
if payload, err := client.NewJSONRequest(request); err != nil {
6681
return nil, err
@@ -73,13 +88,117 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
7388
// Return all choices
7489
var result []*schema.Content
7590
for _, choice := range response.Choices {
76-
if str, ok := choice.Content.(string); ok {
77-
result = append(result, schema.Text(str))
78-
} else {
79-
return nil, ErrUnexpectedResponse.With("unexpected content type", reflect.TypeOf(choice.Content))
91+
if choice.Message == nil || choice.Message.Content == nil {
92+
continue
93+
}
94+
switch v := choice.Message.Content.(type) {
95+
case []string:
96+
for _, v := range v {
97+
result = append(result, schema.Text(v))
98+
}
99+
case string:
100+
result = append(result, schema.Text(v))
101+
default:
102+
return nil, ErrUnexpectedResponse.With("unexpected content type ", reflect.TypeOf(choice.Message.Content))
80103
}
81104
}
82105

83106
// Return success
84107
return result, nil
85108
}
109+
110+
///////////////////////////////////////////////////////////////////////////////
111+
// STRINGIFY
112+
113+
func (s respChat) String() string {
114+
data, _ := json.MarshalIndent(s, "", " ")
115+
return string(data)
116+
}
117+
118+
///////////////////////////////////////////////////////////////////////////////
119+
// UNMARSHAL TEXT STREAM
120+
121+
func (m *respChat) Unmarshal(mimetype string, r io.Reader) error {
122+
switch mimetype {
123+
case client.ContentTypeJson:
124+
return json.NewDecoder(r).Decode(m)
125+
case contentTypeTextStream:
126+
return m.decodeTextStream(r)
127+
default:
128+
return ErrUnexpectedResponse.Withf("%q", mimetype)
129+
}
130+
}
131+
132+
func (m *respChat) decodeTextStream(r io.Reader) error {
133+
var stream respChat
134+
scanner := bufio.NewScanner(r)
135+
buf := new(bytes.Buffer)
136+
137+
FOR_LOOP:
138+
for scanner.Scan() {
139+
data := scanner.Text()
140+
switch {
141+
case data == "":
142+
continue FOR_LOOP
143+
case strings.HasPrefix(data, "data:") && strings.HasSuffix(data, endOfStream):
144+
// [DONE] - Set usage from the stream, break the loop
145+
m.Usage = stream.Usage
146+
break FOR_LOOP
147+
case strings.HasPrefix(data, "data:"):
148+
// Reset
149+
stream.Choices = nil
150+
151+
// Decode JSON data
152+
data = data[6:]
153+
if _, err := buf.WriteString(data); err != nil {
154+
return err
155+
} else if err := json.Unmarshal(buf.Bytes(), &stream); err != nil {
156+
return err
157+
}
158+
159+
// Check for sane data
160+
if len(stream.Choices) == 0 {
161+
return ErrUnexpectedResponse.With("no choices returned")
162+
} else if stream.Choices[0].Index != 0 {
163+
return ErrUnexpectedResponse.With("unexpected choice", stream.Choices[0].Index)
164+
} else if stream.Choices[0].Delta == nil {
165+
return ErrUnexpectedResponse.With("no delta returned")
166+
}
167+
168+
// Append the choice
169+
if len(m.Choices) == 0 {
170+
message := schema.NewMessage(stream.Choices[0].Delta.Role, stream.Choices[0].Delta.Content)
171+
m.Choices = append(m.Choices, schema.MessageChoice{
172+
Index: stream.Choices[0].Index,
173+
Message: message,
174+
FinishReason: stream.Choices[0].FinishReason,
175+
})
176+
} else {
177+
// Append text to the message
178+
m.Choices[0].Message.Add(stream.Choices[0].Delta.Content)
179+
180+
// If the finish reason is set
181+
if stream.Choices[0].FinishReason != "" {
182+
m.Choices[0].FinishReason = stream.Choices[0].FinishReason
183+
}
184+
}
185+
186+
// Set the model and id
187+
m.Id = stream.Id
188+
m.Model = stream.Model
189+
190+
// Callback
191+
if m.callback != nil {
192+
m.callback(stream.Choices[0])
193+
}
194+
195+
// Reset the buffer
196+
buf.Reset()
197+
default:
198+
return ErrUnexpectedResponse.Withf("%q", data)
199+
}
200+
}
201+
202+
// Return any errors from the scanner
203+
return scanner.Err()
204+
}

pkg/mistral/chat_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mistral_test
22

33
import (
4+
"context"
45
"os"
56
"testing"
67

@@ -16,7 +17,7 @@ func Test_chat_001(t *testing.T) {
1617
client, err := mistral.New(GetApiKey(t), opts.OptTrace(os.Stderr, true))
1718
assert.NoError(err)
1819
assert.NotNil(client)
19-
_, err = client.Chat([]schema.Message{
20+
_, err = client.Chat(context.Background(), []*schema.Message{
2021
{Role: "user", Content: "What is the weather"},
2122
})
2223
assert.NoError(err)

pkg/mistral/opts.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package mistral
22

33
import (
4+
// Packages
5+
schema "github.com/mutablelogic/go-client/pkg/openai/schema"
6+
47
// Namespace imports
58
. "github.com/djthorpe/go-errors"
69
)
@@ -24,7 +27,8 @@ type options struct {
2427
// Opt is a function which can be used to set options on a request
2528
type Opt func(*options) error
2629

27-
type Callback func()
30+
// Callback when new stream data is received
31+
type Callback func(schema.MessageChoice)
2832

2933
///////////////////////////////////////////////////////////////////////////////
3034
// OPTIONS

pkg/openai/schema/message.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ import (
1818

1919
// A chat completion message
2020
type Message struct {
21-
// user or assistant
22-
Role string `json:"role"`
21+
// user, system or assistant
22+
Role string `json:"role,omitempty"`
2323

2424
// Message Id
2525
Id string `json:"id,omitempty"`
@@ -30,20 +30,30 @@ type Message struct {
3030
// Content can be a string, array of strings, content
3131
// object or an array of content objects
3232
Content any `json:"content,omitempty"`
33+
34+
// Time the message was created, in unix seconds
35+
Created int64 `json:"created,omitempty"`
3336
}
3437

3538
// One choice of chat completion messages
3639
type MessageChoice struct {
37-
Message `json:"message"`
38-
Index int `json:"index"`
39-
FinishReason string `json:"finish_reason"`
40+
Message *Message `json:"message,omitempty"`
41+
Delta *MessageDelta `json:"delta,omitempty"`
42+
Index int `json:"index"`
43+
FinishReason string `json:"finish_reason,omitempty"`
44+
}
45+
46+
// Delta between messages (for streaming responses)
47+
type MessageDelta struct {
48+
Role string `json:"role,omitempty"`
49+
Content string `json:"content,omitempty"`
4050
}
4151

4252
// Message Content
4353
type Content struct {
4454
Id string `json:"id,omitempty"`
45-
Type string `json:"type,width:4"`
46-
Text string `json:"text,omitempty,wrap,width:60"`
55+
Type string `json:"type" writer:",width:4"`
56+
Text string `json:"text,omitempty" writer:",width:60,wrap"`
4757
Source *contentSource `json:"source,omitempty"`
4858
toolUse
4959

@@ -134,6 +144,11 @@ func (m Message) String() string {
134144
return string(data)
135145
}
136146

147+
func (m MessageChoice) String() string {
148+
data, _ := json.MarshalIndent(m, "", " ")
149+
return string(data)
150+
}
151+
137152
func (c Content) String() string {
138153
data, _ := json.MarshalIndent(c, "", " ")
139154
return string(data)

0 commit comments

Comments
 (0)