Skip to content

Commit ff3c9bd

Browse files
committed
Added streaming chat for openai
1 parent 20053e7 commit ff3c9bd

File tree

9 files changed

+330
-57
lines changed

9 files changed

+330
-57
lines changed

client.go

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"encoding/xml"
7-
"errors"
87
"fmt"
98
"io"
109
"mime"
@@ -57,7 +56,6 @@ const (
5756
PathSeparator = string(os.PathSeparator)
5857
ContentTypeAny = "*/*"
5958
ContentTypeJson = "application/json"
60-
ContentTypeJsonStream = "application/x-ndjson"
6159
ContentTypeTextXml = "text/xml"
6260
ContentTypeApplicationXml = "application/xml"
6361
ContentTypeTextPlain = "text/plain"
@@ -301,31 +299,20 @@ func do(client *http.Client, req *http.Request, accept string, strict bool, out
301299
return nil
302300
}
303301

304-
// Decode the body - and call any callback once the body has been decoded
302+
// Decode the body
305303
switch mimetype {
306-
case ContentTypeJson, ContentTypeJsonStream:
307-
dec := json.NewDecoder(response.Body)
308-
for {
309-
if err := dec.Decode(out); errors.Is(err, io.EOF) {
310-
break
311-
} else if err != nil {
312-
return err
313-
}
314-
if reqopts.callback != nil {
315-
if err := reqopts.callback(); err != nil {
316-
return err
317-
}
318-
}
304+
case ContentTypeJson:
305+
if err := json.NewDecoder(response.Body).Decode(out); err != nil {
306+
return err
307+
}
308+
case ContentTypeTextStream:
309+
if err := NewTextStream().Decode(response.Body, reqopts.textStreamCallback); err != nil {
310+
return err
319311
}
320312
case ContentTypeTextXml, ContentTypeApplicationXml:
321313
if err := xml.NewDecoder(response.Body).Decode(out); err != nil {
322314
return err
323315
}
324-
if reqopts.callback != nil {
325-
if err := reqopts.callback(); err != nil {
326-
return err
327-
}
328-
}
329316
default:
330317
if v, ok := out.(Unmarshaler); ok {
331318
return v.Unmarshal(mimetype, response.Body)
@@ -336,11 +323,6 @@ func do(client *http.Client, req *http.Request, accept string, strict bool, out
336323
} else {
337324
return ErrInternalAppError.Withf("do: response does not implement Unmarshaler for %q", mimetype)
338325
}
339-
if reqopts.callback != nil {
340-
if err := reqopts.callback(); err != nil {
341-
return err
342-
}
343-
}
344326
}
345327

346328
// Return success

clientopts.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func OptUserAgent(value string) ClientOpt {
5656
// Setting verbose to true also displays the JSON response
5757
func OptTrace(w io.Writer, verbose bool) ClientOpt {
5858
return func(client *Client) error {
59-
client.Client.Transport = NewLogTransport(w, client.Client.Transport, verbose)
59+
client.Client.Transport = newLogTransport(w, client.Client.Transport, verbose)
6060
return nil
6161
}
6262
}

cmd/api/openai.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,21 @@ func openaiChat(ctx context.Context, w *tablewriter.Writer, args []string) error
221221
opts = append(opts, openai.OptResponseFormat(openaiResponseFormat))
222222
}
223223
if openaiStream {
224-
opts = append(opts, openai.OptStream())
224+
opts = append(opts, openai.OptStream(func(choice schema.MessageChoice) {
225+
w := w.Output()
226+
if choice.Delta == nil {
227+
return
228+
}
229+
if choice.Delta.Role != "" {
230+
fmt.Fprintf(w, "\nrole: %q\n", choice.Delta.Role)
231+
}
232+
if choice.Delta.Content != "" {
233+
fmt.Fprintf(w, "%v", choice.Delta.Content)
234+
}
235+
if choice.FinishReason != "" {
236+
fmt.Printf("\nfinish_reason: %q\n", choice.FinishReason)
237+
}
238+
}))
225239
}
226240
if openaiUser != "" {
227241
opts = append(opts, openai.OptUser(openaiUser))
@@ -243,7 +257,12 @@ func openaiChat(ctx context.Context, w *tablewriter.Writer, args []string) error
243257
return err
244258
}
245259

246-
return w.Write(responses)
260+
// Write table (if not streaming)
261+
if !openaiStream {
262+
return w.Write(responses)
263+
} else {
264+
return nil
265+
}
247266
}
248267

249268
func openaiTranscribe(ctx context.Context, w *tablewriter.Writer, args []string) error {

pkg/openai/chat.go

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai
33
import (
44
"context"
55
"encoding/json"
6+
"io"
67
"reflect"
78

89
// Packages
@@ -30,19 +31,15 @@ type respChat struct {
3031
Model string `json:"model"`
3132
Choices []*schema.MessageChoice `json:"choices"`
3233
SystemFingerprint string `json:"system_fingerprint,omitempty"`
33-
34-
Usage struct {
35-
PromptTokens int `json:"prompt_tokens"`
36-
CompletionTokens int `json:"completion_tokens"`
37-
TotalTokens int `json:"total_tokens"`
38-
} `json:"usage"`
34+
TokenUsage schema.TokenUsage `json:"usage,omitempty"`
3935
}
4036

4137
///////////////////////////////////////////////////////////////////////////////
4238
// GLOBALS
4339

4440
const (
4541
defaultChatCompletion = "gpt-3.5-turbo"
42+
endOfStreamToken = "[DONE]"
4643
)
4744

4845
///////////////////////////////////////////////////////////////////////////////
@@ -73,10 +70,20 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
7370
}
7471
}
7572

76-
// Return the response
73+
// Set up the request
74+
reqopts := []client.RequestOpt{
75+
client.OptPath("chat/completions"),
76+
}
77+
if request.Stream {
78+
reqopts = append(reqopts, client.OptTextStreamCallback(func(event client.TextStreamEvent) error {
79+
return response.streamCallback(event, request.StreamCallback)
80+
}))
81+
}
82+
83+
// Request->Response
7784
if payload, err := client.NewJSONRequest(request); err != nil {
7885
return nil, err
79-
} else if err := c.DoWithContext(ctx, payload, &response, client.OptPath("chat/completions")); err != nil {
86+
} else if err := c.DoWithContext(ctx, payload, &response, reqopts...); err != nil {
8087
return nil, err
8188
}
8289

@@ -101,3 +108,81 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O
101108
// Return success
102109
return result, nil
103110
}
111+
112+
///////////////////////////////////////////////////////////////////////////////
113+
// PRIVATE METHODS
114+
115+
func (response *respChat) streamCallback(v client.TextStreamEvent, fn Callback) error {
116+
var delta schema.MessageChunk
117+
118+
// [DONE] indicates the end of the stream, return io.EOF
119+
// or decode the data into a MessageChunk
120+
if v.Data == endOfStreamToken {
121+
return io.EOF
122+
} else if err := v.Json(&delta); err != nil {
123+
return err
124+
}
125+
126+
// Set the response fields
127+
if delta.Id != "" {
128+
response.Id = delta.Id
129+
}
130+
if delta.Model != "" {
131+
response.Model = delta.Model
132+
}
133+
if delta.Created != 0 {
134+
response.Created = delta.Created
135+
}
136+
if delta.SystemFingerprint != "" {
137+
response.SystemFingerprint = delta.SystemFingerprint
138+
}
139+
if delta.TokenUsage != nil {
140+
response.TokenUsage = *delta.TokenUsage
141+
}
142+
143+
// With no choices, return success
144+
if len(delta.Choices) == 0 {
145+
return nil
146+
}
147+
148+
// Append choices
149+
for _, choice := range delta.Choices {
150+
// Sanity check the choice index
151+
if choice.Index < 0 || choice.Index >= 6 {
152+
continue
153+
}
154+
// Ensure message has the choice
155+
for {
156+
if choice.Index < len(response.Choices) {
157+
break
158+
}
159+
response.Choices = append(response.Choices, new(schema.MessageChoice))
160+
}
161+
// Append the choice data onto the messahe
162+
if response.Choices[choice.Index].Message == nil {
163+
response.Choices[choice.Index].Message = new(schema.Message)
164+
}
165+
if choice.Index != 0 {
166+
response.Choices[choice.Index].Index = choice.Index
167+
}
168+
if choice.FinishReason != "" {
169+
response.Choices[choice.Index].FinishReason = choice.FinishReason
170+
}
171+
if choice.Delta != nil {
172+
if choice.Delta.Role != "" {
173+
response.Choices[choice.Index].Message.Role = choice.Delta.Role
174+
}
175+
if choice.Delta.Content != "" {
176+
response.Choices[choice.Index].Message.Add(choice.Delta.Content)
177+
}
178+
}
179+
180+
// Callback to the client
181+
if fn != nil {
182+
fn(choice)
183+
}
184+
}
185+
186+
// Return success
187+
return nil
188+
}

pkg/openai/opts.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type options struct {
2727
User string `json:"user,omitempty"`
2828
Stream bool `json:"stream,omitempty"`
2929
StreamOptions *streamoptions `json:"stream_options,omitempty"`
30+
StreamCallback Callback `json:"-"`
3031

3132
// Options for audio
3233
Speed *float32 `json:"speed,omitempty"`
@@ -45,6 +46,9 @@ type streamoptions struct {
4546

4647
type Opt func(*options) error
4748

49+
// Callback when new stream data is received
50+
type Callback func(schema.MessageChoice)
51+
4852
///////////////////////////////////////////////////////////////////////////////
4953
// PUBLIC METHODS
5054

@@ -127,14 +131,14 @@ func OptStop(value ...string) Opt {
127131
}
128132
}
129133

130-
// Partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
131-
// server-sent events as they become available, with the stream terminated by a data: [DONE]
132-
func OptStream() Opt {
134+
// Stream the response, which will be returned as a series of message chunks.
135+
func OptStream(fn Callback) Opt {
133136
return func(o *options) error {
134137
o.Stream = true
135138
o.StreamOptions = &streamoptions{
136139
IncludeUsage: true,
137140
}
141+
o.StreamCallback = fn
138142
return nil
139143
}
140144
}

pkg/openai/schema/message.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ type Message struct {
3535
Created int64 `json:"created,omitempty"`
3636
}
3737

38+
// Chat completion chunk
39+
type MessageChunk struct {
40+
Id string `json:"id,omitempty"`
41+
Model string `json:"model,omitempty"`
42+
Created int64 `json:"created,omitempty"`
43+
SystemFingerprint string `json:"system_fingerprint,omitempty"`
44+
TokenUsage *TokenUsage `json:"usage,omitempty"`
45+
Choices []MessageChoice `json:"choices,omitempty"`
46+
}
47+
48+
// Token usage
49+
type TokenUsage struct {
50+
PromptTokens int `json:"prompt_tokens,omitempty"`
51+
CompletionTokens int `json:"completion_tokens,omitempty"`
52+
TotalTokens int `json:"total_tokens,omitempty"`
53+
}
54+
3855
// One choice of chat completion messages
3956
type MessageChoice struct {
4057
Message *Message `json:"message,omitempty"`
@@ -149,6 +166,11 @@ func (m MessageChoice) String() string {
149166
return string(data)
150167
}
151168

169+
func (m MessageChunk) String() string {
170+
data, _ := json.MarshalIndent(m, "", " ")
171+
return string(data)
172+
}
173+
152174
func (m MessageDelta) String() string {
153175
data, _ := json.MarshalIndent(m, "", " ")
154176
return string(data)

requestopts.go

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@ import (
1515

1616
type requestOpts struct {
1717
*http.Request
18-
19-
// OptResponse
20-
callback func() error
21-
22-
// OptNoTimeout
23-
noTimeout bool
18+
noTimeout bool // OptNoTimeout
19+
textStreamCallback TextStreamCallback // OptTextStreamCallback
2420
}
2521

2622
type RequestOpt func(*requestOpts) error
@@ -91,21 +87,19 @@ func OptReqHeader(name, value string) RequestOpt {
9187
}
9288
}
9389

94-
// OptResponse calls a function when a response has been decoded,
95-
// used for streaming responses. The function can return an error to
96-
// stop the request immediately
97-
func OptResponse(fn func() error) RequestOpt {
90+
// OptNoTimeout disables the timeout for this request, useful for long-running
91+
// requests. The context can be used instead for cancelling requests
92+
func OptNoTimeout() RequestOpt {
9893
return func(r *requestOpts) error {
99-
r.callback = fn
94+
r.noTimeout = true
10095
return nil
10196
}
10297
}
10398

104-
// OptNoTimeout disables the timeout for this request, useful for long-running
105-
// requests. The context can be used instead for cancelling requests
106-
func OptNoTimeout() RequestOpt {
99+
// OptTextStreamCallback is called for each event in a text stream
100+
func OptTextStreamCallback(fn TextStreamCallback) RequestOpt {
107101
return func(r *requestOpts) error {
108-
r.noTimeout = true
102+
r.textStreamCallback = fn
109103
return nil
110104
}
111105
}

0 commit comments

Comments
 (0)