Skip to content

feat: support for Assistant stream mode and implemented stream event callbacks #698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions assistant_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package openai

import (
"context"
"fmt"
"net/http"
)

// AssistantThreadRunStreamResponse is a single server-sent event received
// while streaming an Assistant thread run. Delta carries the incremental
// message content for delta-type events; for other event types it is the
// zero value (omitted on the wire via omitempty).
type AssistantThreadRunStreamResponse struct {
	ID     string       `json:"id"`
	Object string       `json:"object"` // event object type — presumably e.g. "thread.message.delta"; TODO confirm against API docs
	Delta  MessageDelta `json:"delta,omitempty"`
}

// AssistantThreadRunStream wraps the generic streamReader so callers consume
// typed AssistantThreadRunStreamResponse events (via the embedded reader's
// methods, e.g. On/Wait/Close as used in the examples).
type AssistantThreadRunStream struct {
	*streamReader[AssistantThreadRunStreamResponse]
}

// CreateAssistantThreadRunStream starts a run on the given thread with
// streaming enabled and returns a stream of server-sent events.
//
// The request's Stream field is forced to true before sending. The caller
// owns the returned stream and must Close it when done.
func (c *Client) CreateAssistantThreadRunStream(
	ctx context.Context,
	threadID string,
	request RunRequest,
) (*AssistantThreadRunStream, error) {
	// This endpoint variant only makes sense with streaming on.
	request.Stream = true

	req, err := c.newRequest(
		ctx,
		http.MethodPost,
		c.fullURL(fmt.Sprintf("/threads/%s/runs", threadID)),
		withBody(request),
		withBetaAssistantV1(),
	)
	if err != nil {
		return nil, err
	}

	reader, err := sendRequestStream[AssistantThreadRunStreamResponse](c, req)
	if err != nil {
		return nil, err
	}
	return &AssistantThreadRunStream{streamReader: reader}, nil
}

// CreateAssistantThreadRunSubmitToolOutputStream submits tool outputs for a
// run that is waiting on required actions, and streams the resulting events
// back to the caller.
//
// The request's Stream field is forced to true before sending. The caller
// owns the returned stream and must Close it when done.
func (c *Client) CreateAssistantThreadRunSubmitToolOutputStream(
	ctx context.Context,
	threadID string,
	runID string,
	request SubmitToolOutputsRequest,
) (stream *AssistantThreadRunStream, err error) {
	request.Stream = true
	urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID)
	req, err := c.newRequest(
		ctx,
		http.MethodPost,
		c.fullURL(urlSuffix),
		withBody(request),
		withBetaAssistantV1(),
	)
	if err != nil {
		// Explicit return for consistency with CreateAssistantThreadRunStream
		// (previously a naked return relying on the named results).
		return nil, err
	}

	resp, err := sendRequestStream[AssistantThreadRunStreamResponse](c, req)
	if err != nil {
		return nil, err
	}
	stream = &AssistantThreadRunStream{streamReader: resp}
	return stream, nil
}
360 changes: 360 additions & 0 deletions assistant_stream_test.go

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions examples/assistant-streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"encoding/json"
"fmt"
"io"
"os"

openai "github.com/sashabaranov/go-openai"
)

// main demonstrates the Assistant streaming API: it creates a thread and a
// message, streams the run, and — if the run requires tool-call action —
// submits dummy tool outputs over a second stream.
func main() {
	c := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	ctx := context.Background()

	fmt.Println("Creating new thread")
	thread, err := c.CreateThread(ctx, openai.ThreadRequest{
		Messages: []openai.ThreadMessage{{Role: openai.ThreadMessageRoleUser, Content: "i want to go home"}},
	})
	if err != nil {
		fmt.Printf("Thread error: %v\n", err)
		return
	}
	fmt.Printf("thread: %s\n", thread.ID)

	message, err := c.CreateMessage(ctx, thread.ID, openai.MessageRequest{
		Role:    "user",
		Content: "i want to go home",
	})
	if err != nil {
		fmt.Printf("Message error: %v\n", err)
		return
	}
	fmt.Printf("Message created: %v\n", message.ID)

	requiredActionRuns, err := streamAssistantRun(ctx, c, thread.ID)
	if err != nil {
		return // already reported inside the helper
	}

	if len(requiredActionRuns) == 0 {
		return
	}
	fmt.Println("Action required")
	for _, run := range requiredActionRuns {
		// Each iteration opens its own stream. Using a helper makes the
		// deferred Close fire at the end of the iteration — the original
		// deferred inside the loop, which delays every Close to main's exit.
		if err := submitToolOutputsStream(ctx, c, run); err != nil {
			return
		}
	}
}

// streamAssistantRun starts a streamed run on threadID, prints message deltas
// as they arrive, and collects any runs that end up requiring tool action.
// Errors are printed here; the caller only needs to stop on a non-nil return.
func streamAssistantRun(ctx context.Context, c *openai.Client, threadID string) ([]openai.Run, error) {
	stream, err := c.CreateAssistantThreadRunStream(ctx, threadID, openai.RunRequest{
		AssistantID: os.Getenv("ASSISTANT_ID"),
		Model:       openai.GPT4TurboPreview,
	})
	if err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return nil, err
	}
	defer stream.Close()

	fmt.Printf("Stream response: ")
	err = stream.On("thread.message.delta", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
		fmt.Printf("%s", rawData)
	})
	if err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return nil, err
	}

	var requiredActionRuns []openai.Run
	err = stream.On("thread.run.requires_action", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
		run := openai.Run{}
		if err := json.Unmarshal(rawData, &run); err != nil {
			fmt.Printf("run unmarshal error: %v\n", err)
			return
		}
		fmt.Printf("Stream require action: %v\n", run.RequiredAction.SubmitToolOutputs.ToolCalls)
		requiredActionRuns = append(requiredActionRuns, run)
	})
	if err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return nil, err
	}

	// Wait returns io.EOF on a clean end of stream; review note: errors.Is
	// would be more robust if Wait ever wraps the error — confirm and adjust.
	if err := stream.Wait(); err != io.EOF {
		fmt.Println("\nStream finished with error", err)
		return nil, err
	}
	fmt.Println("\nStream finished")
	return requiredActionRuns, nil
}

// submitToolOutputsStream answers one run's required tool calls with dummy
// outputs and streams the assistant's continuation. The stream is closed
// before returning (helper extracted to avoid defer-in-loop in main).
func submitToolOutputsStream(ctx context.Context, c *openai.Client, run openai.Run) error {
	calls := run.RequiredAction.SubmitToolOutputs.ToolCalls
	toolOutputs := make([]openai.ToolOutput, 0, len(calls))
	for _, call := range calls {
		toolOutputs = append(toolOutputs, openai.ToolOutput{
			ToolCallID: call.ID,
			Output:     true, // placeholder; a real tool would compute this
		})
	}

	// Typo fix: "ouputs" -> "outputs" in the user-facing message.
	fmt.Printf("\nSubmit tool outputs: %v\n", toolOutputs)
	stream, err := c.CreateAssistantThreadRunSubmitToolOutputStream(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{
		ToolOutputs: toolOutputs,
	})
	if err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return err
	}
	defer stream.Close()

	fmt.Printf("Stream response: ")
	err = stream.On("thread.message.delta", func(resp openai.AssistantThreadRunStreamResponse, rawData []byte) {
		fmt.Printf("%s", rawData)
	})
	if err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return err
	}
	if err := stream.Wait(); err != io.EOF {
		fmt.Println("\nStream finished with error", err)
		return err
	}
	fmt.Println("\nStream finished")
	return nil
}
62 changes: 62 additions & 0 deletions examples/chat-streaming/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package main

import (
	"context"
	"fmt"
	"io"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

// main demonstrates callback-based consumption of a chat completion stream.
// The equivalent manual loop would call stream.Recv() until it returns io.EOF.
func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	ctx := context.Background()

	request := openai.ChatCompletionRequest{
		Model:     openai.GPT3Dot5Turbo,
		MaxTokens: 200,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: "hi👋",
			},
		},
	}

	stream, err := client.CreateChatCompletionStream(ctx, request)
	if err != nil {
		fmt.Printf("ChatCompletionStream error: %v\n", err)
		return
	}
	defer stream.Close()

	fmt.Printf("Stream response: ")
	// Register a handler that prints each delta as it arrives.
	if err := stream.On("message", func(resp openai.ChatCompletionStreamResponse, rawData []byte) {
		fmt.Printf("%s", resp.Choices[0].Delta.Content)
	}); err != nil {
		fmt.Printf("Stream error: %v\n", err)
		return
	}

	// Wait blocks until the stream ends; io.EOF signals normal termination.
	if err := stream.Wait(); err != io.EOF {
		fmt.Println("\nStream finished with error", err)
		return
	}
	fmt.Println("\nStream finished")
}
13 changes: 13 additions & 0 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@
httpHeader
}

type MessageDelta struct {
Content []MessageDeltaContent `json:"content"`
Role string `json:"role"`
FileIds []string `json:"file_ids,omitempty"`

Check warning on line 79 in messages.go

View workflow job for this annotation

GitHub Actions / Sanity check

var-naming: struct field FileIds should be FileIDs (revive)

Check warning on line 79 in messages.go

View workflow job for this annotation

GitHub Actions / Sanity check

var-naming: struct field FileIds should be FileIDs (revive)
}

// MessageDeltaContent is one element of a streamed message delta. Text and
// ImageFile are pointers with omitempty — presumably at most one is set,
// discriminated by Type; TODO confirm against the Assistants API reference.
type MessageDeltaContent struct {
	Index     int          `json:"index"` // position of this content part within the message
	Type      string       `json:"type"`
	Text      *MessageText `json:"text,omitempty"`
	ImageFile *ImageFile   `json:"image_file,omitempty"`
}

// CreateMessage creates a new message.
func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) {
urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix)
Expand Down
2 changes: 2 additions & 0 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type RunRequest struct {
AdditionalInstructions string `json:"additional_instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}

type RunModifyRequest struct {
Expand All @@ -93,6 +94,7 @@ type RunList struct {

// SubmitToolOutputsRequest is the request body for submitting the outputs of
// required tool calls for a run.
type SubmitToolOutputsRequest struct {
	ToolOutputs []ToolOutput `json:"tool_outputs"`
	// Stream mirrors RunRequest.Stream. omitempty keeps non-streaming request
	// bodies unchanged (field absent instead of "stream":false), consistent
	// with the RunRequest.Stream tag added in this change.
	Stream bool `json:"stream,omitempty"`
}

type ToolOutput struct {
Expand Down
Loading
Loading