Skip to content

Commit ae05ed9

Browse files
authored
handle stream completion (#86)
* handle stream completion * fix tests
1 parent 1eb5d62 commit ae05ed9

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

stream.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"errors"
99
"fmt"
10+
"io"
1011
"net/http"
1112
)
1213

@@ -16,12 +17,18 @@ var (
1617

1718
type CompletionStream struct {
1819
emptyMessagesLimit uint
20+
isFinished bool
1921

2022
reader *bufio.Reader
2123
response *http.Response
2224
}
2325

2426
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
27+
if stream.isFinished {
28+
err = io.EOF
29+
return
30+
}
31+
2532
var emptyMessagesCount uint
2633

2734
waitForData:
@@ -44,6 +51,8 @@ waitForData:
4451

4552
line = bytes.TrimPrefix(line, headerData)
4653
if string(line) == "[DONE]" {
54+
stream.isFinished = true
55+
err = io.EOF
4756
return
4857
}
4958

stream_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"github.com/sashabaranov/go-gpt3/internal/test"
66

77
"context"
8+
"errors"
9+
"io"
810
"net/http"
911
"net/http/httptest"
1012
"testing"
@@ -75,7 +77,6 @@ func TestCreateCompletionStream(t *testing.T) {
7577
Model: "text-davinci-002",
7678
Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}},
7779
},
78-
{},
7980
}
8081

8182
for ix, expectedResponse := range expectedResponses {
@@ -87,6 +88,11 @@ func TestCreateCompletionStream(t *testing.T) {
8788
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
8889
}
8990
}
91+
92+
_, streamErr := stream.Recv()
93+
if !errors.Is(streamErr, io.EOF) {
94+
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
95+
}
9096
}
9197

9298
// A "tokenRoundTripper" is a struct that implements the RoundTripper

0 commit comments

Comments
 (0)