diff --git a/cmd/frontend/internal/httpapi/completions/handler.go b/cmd/frontend/internal/httpapi/completions/handler.go index 7e79afaa688a..e65e914b81a7 100644 --- a/cmd/frontend/internal/httpapi/completions/handler.go +++ b/cmd/frontend/internal/httpapi/completions/handler.go @@ -326,14 +326,19 @@ func newStreamingResponseHandler(logger log.Logger, db database.DB, feature type Version: version, Parameters: requestParams, } - sendEventFn := func(event types.CompletionResponse) error { + + var responseMetadataCapture types.ResponseMetadataCapture + responseMetadataCapture = types.NewResponseMetadataCapture(func(event types.CompletionResponse) error { if !firstEventObserved { + responseMetadataCapture.ApplyCapturedMetadata(w) firstEventObserved = true timeToFirstEventMetrics.Observe(time.Since(start).Seconds(), 1, nil, requestParams.Model) } return f.Send(ctx, event) - } - err := cc.Stream(ctx, logger, compReq, sendEventFn) + }) + + err := cc.Stream(ctx, logger, compReq, &responseMetadataCapture) + if err != nil { l := trace.Logger(ctx, logger) diff --git a/internal/completions/client/anthropic/anthropic.go b/internal/completions/client/anthropic/anthropic.go index 1a8bbe1e01ce..91df601dca49 100644 --- a/internal/completions/client/anthropic/anthropic.go +++ b/internal/completions/client/anthropic/anthropic.go @@ -84,7 +84,7 @@ func (a *anthropicClient) Stream( ctx context.Context, logger log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent) error { + responseMetadataCapture *types.ResponseMetadataCapture) error { feature := request.Feature version := request.Version @@ -94,6 +94,10 @@ func (a *anthropicClient) Stream( if err != nil { return err } + + responseMetadataCapture.CaptureHeaders(resp.Header) + responseMetadataCapture.CaptureStatusCode(resp.StatusCode) + defer resp.Body.Close() dec := NewDecoder(resp.Body) @@ -139,7 +143,7 @@ func (a *anthropicClient) Stream( continue } - err = sendEvent(types.CompletionResponse{ + err = responseMetadataCapture.SendEvent(types.CompletionResponse{ Completion: completedString, StopReason: stopReason, }) diff --git a/internal/completions/client/awsbedrock/bedrock.go b/internal/completions/client/awsbedrock/bedrock.go index 86fbf5bdca2f..6e767fe5b356 100644 --- a/internal/completions/client/awsbedrock/bedrock.go +++ b/internal/completions/client/awsbedrock/bedrock.go @@ -86,7 +86,7 @@ func (a *awsBedrockAnthropicCompletionStreamClient) Stream( ctx context.Context, logger log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent) error { + responseMetadataCapture *types.ResponseMetadataCapture) error { feature := request.Feature version := request.Version @@ -169,7 +169,7 @@ func (a *awsBedrockAnthropicCompletionStreamClient) Stream( continue } sentEvent = true - err = sendEvent(types.CompletionResponse{ + err = responseMetadataCapture.SendEvent(types.CompletionResponse{ Completion: totalCompletion, StopReason: stopReason, }) diff --git a/internal/completions/client/azureopenai/openai.go b/internal/completions/client/azureopenai/openai.go index 5749f43fbd73..a5be830b428f 100644 --- a/internal/completions/client/azureopenai/openai.go +++ b/internal/completions/client/azureopenai/openai.go @@ -190,16 +190,16 @@ func (c *azureCompletionClient) Stream( ctx context.Context, log log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent, + responseMetadataCapture *types.ResponseMetadataCapture, ) error { feature := request.Feature requestParams := request.Parameters switch feature { case types.CompletionsFeatureCode: - return streamAutocomplete(ctx, c.client, requestParams, sendEvent, log) + return streamAutocomplete(ctx, c.client, requestParams, responseMetadataCapture, log) case types.CompletionsFeatureChat: - return streamChat(ctx, c.client, requestParams, sendEvent, log) + return streamChat(ctx, c.client, requestParams, responseMetadataCapture, log) default: return errors.New("invalid completions feature") } @@ -209,7 +209,7 @@ func streamAutocomplete( ctx context.Context, client CompletionsClient, requestParams types.CompletionRequestParameters, - sendEvent types.SendCompletionEvent, + responseMetadataCapture *types.ResponseMetadataCapture, logger log.Logger, ) error { options, err := getCompletionsOptions(requestParams) @@ -252,7 +252,7 @@ func streamAutocomplete( Completion: content, StopReason: finish, } - err := sendEvent(ev) + err := responseMetadataCapture.SendEvent(ev) if err != nil { return err } @@ -264,7 +264,7 @@ func streamChat( ctx context.Context, client CompletionsClient, requestParams types.CompletionRequestParameters, - sendEvent types.SendCompletionEvent, + responseMetadataCapture *types.ResponseMetadataCapture, logger log.Logger, ) error { @@ -308,7 +308,7 @@ func streamChat( Completion: content, StopReason: finish, } - err := sendEvent(ev) + err := responseMetadataCapture.SendEvent(ev) if err != nil { return err } diff --git a/internal/completions/client/codygateway/codygateway.go b/internal/completions/client/codygateway/codygateway.go index 9dc9e87e245f..c3c8beabe865 100644 --- a/internal/completions/client/codygateway/codygateway.go +++ b/internal/completions/client/codygateway/codygateway.go @@ -47,13 +47,13 @@ type codyGatewayClient struct { } func (c *codyGatewayClient) Stream( - ctx context.Context, logger log.Logger, request types.CompletionRequest, sendEvent types.SendCompletionEvent) error { + ctx context.Context, logger log.Logger, request types.CompletionRequest, responseMetadataCapture *types.ResponseMetadataCapture) error { cc, err := c.clientForParams(request.Feature, &request.Parameters) if err != nil { return err } - err = cc.Stream(ctx, logger, request, sendEvent) + err = cc.Stream(ctx, logger, request, responseMetadataCapture) return overwriteErrSource(err) } @@ -137,7 +137,7 @@ func gatewayDoer(upstream httpcli.Doer, feature types.CompletionsFeature, gatewa }), }).RoundTrip(req) - // If we get a repsonse, record Cody Gateway's x-trace response header, + // If we get a response, record Cody Gateway's x-trace response header, // so that we can link up to an event on our end if needed. if resp != nil && resp.Header != nil { if span := trace.SpanFromContext(req.Context()); span.SpanContext().IsValid() { diff --git a/internal/completions/client/fireworks/fireworks.go b/internal/completions/client/fireworks/fireworks.go index 88f8678d2105..6d71310608df 100644 --- a/internal/completions/client/fireworks/fireworks.go +++ b/internal/completions/client/fireworks/fireworks.go @@ -116,7 +116,7 @@ func (c *fireworksClient) Stream( ctx context.Context, logger log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent) error { + responseMetadataCapture *types.ResponseMetadataCapture) error { feature := request.Feature requestParams := request.Parameters logprobsInclude := uint8(0) @@ -132,6 +132,10 @@ func (c *fireworksClient) Stream( if err != nil { return err } + + responseMetadataCapture.CaptureHeaders(resp.Header) + responseMetadataCapture.CaptureStatusCode(resp.StatusCode) + defer resp.Body.Close() dec := NewDecoder(resp.Body) @@ -166,7 +170,7 @@ func (c *fireworksClient) Stream( StopReason: event.Choices[0].FinishReason, Logprobs: accumulatedLogprobs, } - err = sendEvent(ev) + err = responseMetadataCapture.SendEvent(ev) if err != nil { return err } @@ -232,10 +236,10 @@ func (c *fireworksClient) makeRequest(ctx context.Context, feature types.Complet Role: role, Content: m.Text, }) - // HACK: Replace the ending part of the endpint from `/completions` to `/chat/completions` + // HACK: Replace the ending part of the endpoint from `/completions` to `/chat/completions` // // This is _only_ used when running the Fireworks API directly from the SG instance - // (without Cody Gateway) and is neccessary because every client can only have one + // (without Cody Gateway) and is necessary because every client can only have one // endpoint configured at the moment. If the request is routed to Cody Gateway, the // endpoint will not have `inference/v1/completions` in the URL endpoint = strings.Replace(c.endpoint, "/inference/v1/completions", "/inference/v1/chat/completions", 1) @@ -249,6 +253,7 @@ func (c *fireworksClient) makeRequest(ctx context.Context, feature types.Complet } req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(reqBody)) + if err != nil { return nil, err } @@ -257,6 +262,7 @@ func (c *fireworksClient) makeRequest(ctx context.Context, feature types.Complet req.Header.Set("Authorization", "Bearer "+c.accessToken) resp, err := c.cli.Do(req) + if err != nil { return nil, err } diff --git a/internal/completions/client/google/google.go b/internal/completions/client/google/google.go index 00818584c08d..330fed477b52 100644 --- a/internal/completions/client/google/google.go +++ b/internal/completions/client/google/google.go @@ -169,18 +169,18 @@ func (c *googleCompletionStreamClient) Stream( ctx context.Context, logger log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent) error { + responseMetadataCapture *types.ResponseMetadataCapture) error { if c.apiFamily == VertexAnthropic { - return c.handleVertexAnthropicStream(ctx, request.Parameters, sendEvent) + return c.handleVertexAnthropicStream(ctx, request.Parameters, responseMetadataCapture) } else { - return c.handleGeminiStream(ctx, request.Parameters, sendEvent) + return c.handleGeminiStream(ctx, request.Parameters, responseMetadataCapture) } } func (c *googleCompletionStreamClient) handleGeminiStream( ctx context.Context, requestParams types.CompletionRequestParameters, - sendEvent types.SendCompletionEvent, + responseMetadataCapture *types.ResponseMetadataCapture, ) error { resp, err := c.makeGeminiRequest(ctx, requestParams, true) if err != nil { @@ -215,7 +215,7 @@ func (c *googleCompletionStreamClient) handleGeminiStream( Completion: content, StopReason: event.Candidates[0].FinishReason, } - err = sendEvent(ev) + err = responseMetadataCapture.SendEvent(ev) if err != nil { return err } @@ -231,7 +231,7 @@ func (c *googleCompletionStreamClient) handleGeminiStream( func (c *googleCompletionStreamClient) handleVertexAnthropicStream( ctx context.Context, requestParams types.CompletionRequestParameters, - sendEvent types.SendCompletionEvent, + responseMetadataCapture *types.ResponseMetadataCapture, ) error { var resp *http.Response var err error @@ -294,7 +294,7 @@ func (c *googleCompletionStreamClient) handleVertexAnthropicStream( } totalCompletion += d.Delta.Text sentEvent = true - err = sendEvent(types.CompletionResponse{ + err = responseMetadataCapture.SendEvent(types.CompletionResponse{ Completion: totalCompletion, }) if err != nil { diff --git a/internal/completions/client/observe.go b/internal/completions/client/observe.go index 2907b31d8fff..60ea0147f94c 100644 --- a/internal/completions/client/observe.go +++ b/internal/completions/client/observe.go @@ -33,7 +33,7 @@ type observedClient struct { var _ types.CompletionsClient = (*observedClient)(nil) -func (o *observedClient) Stream(ctx context.Context, logger log.Logger, request types.CompletionRequest, send types.SendCompletionEvent) (err error) { +func (o *observedClient) Stream(ctx context.Context, logger log.Logger, request types.CompletionRequest, responseMetadataCapture *types.ResponseMetadataCapture) (err error) { feature := request.Feature version := request.Version params := request.Parameters @@ -47,17 +47,18 @@ func (o *observedClient) Stream(ctx context.Context, logger log.Logger, request }) defer endObservation(1, observation.Args{}) - tracedSend := func(event types.CompletionResponse) error { + originalCaptureEvent := responseMetadataCapture.SendEvent + responseMetadataCapture.SendEvent = func(event types.CompletionResponse) error { if event.StopReason != "" { tr.AddEvent("stopped", attribute.String("reason", event.StopReason)) } else { tr.AddEvent("completion", attribute.Int("len", len(event.Completion))) } - return send(event) + return originalCaptureEvent(event) } - return o.inner.Stream(ctx, logger, request, tracedSend) + return o.inner.Stream(ctx, logger, request, responseMetadataCapture) } func (o *observedClient) Complete(ctx context.Context, logger log.Logger, request types.CompletionRequest) (resp *types.CompletionResponse, err error) { diff --git a/internal/completions/client/openai/openai.go b/internal/completions/client/openai/openai.go index 2fdd3b248271..de03b5b3503e 100644 --- a/internal/completions/client/openai/openai.go +++ b/internal/completions/client/openai/openai.go @@ -86,7 +86,7 @@ func (c *openAIChatCompletionStreamClient) Stream( ctx context.Context, logger log.Logger, request types.CompletionRequest, - sendEvent types.SendCompletionEvent) error { + responseMetadataCapture *types.ResponseMetadataCapture) error { feature := request.Feature requestParams := request.Parameters @@ -145,7 +145,7 @@ func (c *openAIChatCompletionStreamClient) Stream( Completion: content, StopReason: event.Choices[0].FinishReason, } - err = sendEvent(ev) + err = responseMetadataCapture.SendEvent(ev) if err != nil { return err } diff --git a/internal/completions/client/openai/openai_test.go b/internal/completions/client/openai/openai_test.go index f79e1cb0e4c6..eb446958c2b3 100644 --- a/internal/completions/client/openai/openai_test.go +++ b/internal/completions/client/openai/openai_test.go @@ -54,8 +54,8 @@ func TestErrStatusNotOK(t *testing.T) { t.Run("Stream", func(t *testing.T) { logger := log.Scoped("completions") - sendEventFn := func(event types.CompletionResponse) error { return nil } - err := mockClient.Stream(context.Background(), logger, compRequest, sendEventFn) + responseMetadataCapture := types.NewResponseMetadataCapture(func(types.CompletionResponse) error { return nil }) + err := mockClient.Stream(context.Background(), logger, compRequest, &responseMetadataCapture) require.Error(t, err) autogold.Expect("OpenAI: unexpected status code 429: oh no, please slow down!").Equal(t, err.Error()) diff --git a/internal/completions/types/types.go b/internal/completions/types/types.go index 1add02943f02..4f46776b5550 100644 --- a/internal/completions/types/types.go +++ b/internal/completions/types/types.go @@ -3,6 +3,7 @@ package types import ( "context" "fmt" + "net/http" "strings" "go.opentelemetry.io/otel/attribute" @@ -197,7 +198,7 @@ type CompletionRequest struct { type CompletionsClient interface { // Stream executions a completions request, streaming results to the callback. // Callers should check for ErrStatusNotOK and handle the error appropriately. - Stream(context.Context, log.Logger, CompletionRequest, SendCompletionEvent) error + Stream(context.Context, log.Logger, CompletionRequest, *ResponseMetadataCapture) error // Complete executions a completions request until done. Callers should check // for ErrStatusNotOK and handle the error appropriately. Complete(context.Context, log.Logger, CompletionRequest) (*CompletionResponse, error) @@ -252,3 +253,47 @@ func ConvertFromLegacyMessages(messages []Message) []Message { return filteredMessages } + +// ResponseMetadataCapture holds metadata from an HTTP response, including headers, +// status code, and a function to send completion events. +type ResponseMetadataCapture struct { + headers http.Header + statusCode *int + SendEvent SendCompletionEvent +} + +// NewResponseMetadataCapture creates and initializes a new ResponseMetadataCapture +// with empty headers and the provided SendCompletionEvent function. +func NewResponseMetadataCapture(sendEvent SendCompletionEvent) ResponseMetadataCapture { + return ResponseMetadataCapture{ + headers: make(http.Header), + SendEvent: sendEvent, + } +} + +// CaptureHeaders copies the provided headers into the ResponseMetadataCapture. +func (rc *ResponseMetadataCapture) CaptureHeaders(headers http.Header) { + for key, values := range headers { + rc.headers[key] = values + } +} + +// CaptureStatusCode stores the provided status code in the ResponseMetadataCapture. +func (rc *ResponseMetadataCapture) CaptureStatusCode(statusCode int) { + rc.statusCode = &statusCode +} + +// ApplyCapturedMetadata applies the captured headers and status code to the provided +// http.ResponseWriter. This is typically used to propagate upstream response metadata +// to the client. +func (rc *ResponseMetadataCapture) ApplyCapturedMetadata(w http.ResponseWriter) { + for key, values := range rc.headers { + for _, value := range values { + w.Header().Add(key, value) + } + } + + if rc.statusCode != nil { + w.WriteHeader(*rc.statusCode) + } +}