Skip to content

feat: support for Assistant stream mode and implemented stream event callbacks #713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions assistant_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package openai

import (
"context"
"fmt"
"net/http"
)

type AssistantThreadRunStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Delta MessageDelta `json:"delta,omitempty"`
}

type AssistantThreadRunStream struct {
*streamReader[AssistantThreadRunStreamResponse]
}

func (c *Client) CreateAssistantThreadRunStream(
ctx context.Context,
threadID string,
request RunRequest,
) (stream *AssistantThreadRunStream, err error) {
request.Stream = true
urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID)
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix),
withBody(request),
withBetaAssistantV1(),
)
if err != nil {
return nil, err

Check warning on line 34 in assistant_stream.go

View check run for this annotation

Codecov / codecov/patch

assistant_stream.go#L34

Added line #L34 was not covered by tests
}

resp, err := sendRequestStream[AssistantThreadRunStreamResponse](c, req)
if err != nil {
return nil, err

Check warning on line 39 in assistant_stream.go

View check run for this annotation

Codecov / codecov/patch

assistant_stream.go#L39

Added line #L39 was not covered by tests
}
stream = &AssistantThreadRunStream{
streamReader: resp,
}
return
}

func (c *Client) CreateAssistantThreadRunSubmitToolOutputStream(
ctx context.Context,
threadID string,
runID string,
request SubmitToolOutputsRequest,
) (stream *AssistantThreadRunStream, err error) {
request.Stream = true
urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID)
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix),
withBody(request),
withBetaAssistantV1(),
)
if err != nil {
return

Check warning on line 63 in assistant_stream.go

View check run for this annotation

Codecov / codecov/patch

assistant_stream.go#L63

Added line #L63 was not covered by tests
}

resp, err := sendRequestStream[AssistantThreadRunStreamResponse](c, req)
if err != nil {
return nil, err

Check warning on line 68 in assistant_stream.go

View check run for this annotation

Codecov / codecov/patch

assistant_stream.go#L68

Added line #L68 was not covered by tests
}
stream = &AssistantThreadRunStream{
streamReader: resp,
}
return
}
364 changes: 364 additions & 0 deletions assistant_stream_test.go

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions examples/assistant-streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"encoding/json"
"fmt"
"io"
"os"

openai "github.com/sashabaranov/go-openai"
)

func main() {
c := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
ctx := context.Background()
fmt.Println("Creating new thread")
thread, err := c.CreateThread(ctx, openai.ThreadRequest{
Messages: []openai.ThreadMessage{{Role: openai.ThreadMessageRoleUser, Content: "i want to go home"}},
})

if err != nil {
fmt.Printf("Thread error: %v\n", err)
return
}

fmt.Printf("thread: %s\n", thread.ID)
message, err := c.CreateMessage(ctx, thread.ID, openai.MessageRequest{
Role: "user",
Content: "i want to go home",
})

if err != nil {
fmt.Printf("Message error: %v\n", err)
return
}

fmt.Printf("Message created: %v\n", message.ID)

stream, err := c.CreateAssistantThreadRunStream(ctx, thread.ID, openai.RunRequest{
AssistantID: os.Getenv("ASSISTANT_ID"),
Model: openai.GPT4TurboPreview,
})

if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}

defer stream.Close()

fmt.Printf("Stream response: ")
/*
err = stream.On("thread.run.step.delta", func (resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
fmt.Printf("run.step.delta: %s", rawData)
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}
*/
err = stream.On("thread.message.delta", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
fmt.Printf("%s", rawData)
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}
var requiredActionRuns = []openai.Run{}
err = stream.On("thread.run.requires_action", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
run := openai.Run{}
err := json.Unmarshal(rawData, &run)
if err != nil {
fmt.Printf("run unmarshal error: %v\n", err)
return
}
fmt.Printf("Stream require action: %v\n", run.RequiredAction.SubmitToolOutputs.ToolCalls)
requiredActionRuns = append(requiredActionRuns, run)
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}
err = stream.Wait()
if err != io.EOF {
fmt.Println("\nStream finished with error", err)
return
}
fmt.Println("\nStream finished")

if len(requiredActionRuns) > 0 {
fmt.Println("Action required")
for _, run := range requiredActionRuns {
toolOuputs := []openai.ToolOutput{}
for _, call := range run.RequiredAction.SubmitToolOutputs.ToolCalls {
output := openai.ToolOutput{
ToolCallID: call.ID,
Output: true,
}
toolOuputs = append(toolOuputs, output)
}

fmt.Printf("\nSubmit tool ouputs: %v\n", toolOuputs)
stream, err := c.CreateAssistantThreadRunSubmitToolOutputStream(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{
ToolOutputs: toolOuputs,
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}

defer stream.Close()

fmt.Printf("Stream response: ")
err = stream.On("thread.message.delta", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
fmt.Printf("%s", rawData)
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}
err = stream.Wait()
if err != io.EOF {
fmt.Println("\nStream finished with error", err)
return
}
fmt.Println("\nStream finished")
}
}
}
62 changes: 62 additions & 0 deletions examples/chat-streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package main

import (
"context"
"fmt"
openai "github.com/sashabaranov/go-openai"
"io"
"os"
)

func main() {
c := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
ctx := context.Background()

req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
MaxTokens: 200,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "hi👋",
},
},
}
stream, err := c.CreateChatCompletionStream(ctx, req)
if err != nil {
fmt.Printf("ChatCompletionStream error: %v\n", err)
return
}
defer stream.Close()

fmt.Printf("Stream response: ")
/*
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
fmt.Println("\nStream finished")
return
}

if err != nil {
fmt.Printf("\nStream error: %v\n", err)
return
}

fmt.Printf("%s", response.Choices[0].Delta.Content)
}
*/
err = stream.On("message", func(resp openai.ChatCompletionStreamResponse, rawData []byte) {
fmt.Printf("%s", resp.Choices[0].Delta.Content)
})
if err != nil {
fmt.Printf("Stream error: %v\n", err)
return
}
err = stream.Wait()
if err != io.EOF {
fmt.Println("\nStream finished with error", err)
return
}
fmt.Println("\nStream finished")
}
14 changes: 14 additions & 0 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ type MessageFilesList struct {
httpHeader
}

type MessageDelta struct {
Content []MessageDeltaContent `json:"content"`
Role string `json:"role"`
FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility

}

type MessageDeltaContent struct {
Index int `json:"index"`
Type string `json:"type"`
Text *MessageText `json:"text,omitempty"`
ImageFile *ImageFile `json:"image_file,omitempty"`
}

// CreateMessage creates a new message.
func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) {
urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix)
Expand Down
2 changes: 2 additions & 0 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type RunRequest struct {
AdditionalInstructions string `json:"additional_instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}

type RunModifyRequest struct {
Expand All @@ -93,6 +94,7 @@ type RunList struct {

type SubmitToolOutputsRequest struct {
ToolOutputs []ToolOutput `json:"tool_outputs"`
Stream bool `json:"stream"`
}

type ToolOutput struct {
Expand Down
Loading
Loading