Skip to content
This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit a426134

Browse files
authored
Gateway: forward X-Fireworks-Genie header from client (#63460)
Previously, there was no way to enable the "tracing" feature from Fireworks https://readme.fireworks.ai/docs/enabling-tracing This PR solves the problem by forwarding the `X-Fireworks-Genie` HTTP header to Fireworks if this HTTP header is set by the Gateway client. Fixes CODY-2555 <!-- 💡 To write a useful PR description, make sure that your description covers: - WHAT this PR is changing: - How was it PREVIOUSLY. - How it will be from NOW on. - WHY this PR is needed. - CONTEXT, i.e. to which initiative, project or RFC it belongs. The structure of the description doesn't matter as much as covering these points, so use your best judgement based on your context. Learn how to write good pull request description: https://www.notion.so/sourcegraph/Write-a-good-pull-request-description-610a7fd3e613496eb76f450db5a49b6e?pvs=4 --> ## Test plan <!-- All pull requests REQUIRE a test plan: https://docs-legacy.sourcegraph.com/dev/background-information/testing_principles --> N/A ## Changelog <!-- 1. Ensure your pull request title is formatted as: $type($domain): $what 2. Add bullet list items for each additional detail you want to cover (see example below) 3. You can edit this after the pull request was merged, as long as release shipping it hasn't been promoted to the public. 4. For more information, please see this how-to https://www.notion.so/sourcegraph/Writing-a-changelog-entry-dd997f411d524caabf0d8d38a24a878c? Audience: TS/CSE > Customers > Teammates (in that order). Cheat sheet: $type = chore|fix|feat $domain: source|search|ci|release|plg|cody|local|... --> <!-- Example: Title: fix(search): parse quotes with the appropriate context Changelog section: ## Changelog - When a quote is used with regexp pattern type, then ... - Refactored underlying code. -->
1 parent c05186e commit a426134

File tree

6 files changed

+37
-32
lines changed

6 files changed

+37
-32
lines changed

cmd/cody-gateway/internal/httpapi/completions/anthropic.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,15 @@ func (a *AnthropicHandlerMethods) getRequestMetadata(body anthropicRequest) (mod
155155
}
156156
}
157157

158-
func (a *AnthropicHandlerMethods) transformRequest(r *http.Request) {
158+
func (a *AnthropicHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
159159
// Mimic headers set by the official Anthropic client:
160160
// https://sourcegraph.com/github.com/anthropics/anthropic-sdk-typescript@493075d70f50f1568a276ed0cb177e297f5fef9f/-/blob/src/index.ts
161-
r.Header.Set("Cache-Control", "no-cache")
162-
r.Header.Set("Accept", "application/json")
163-
r.Header.Set("Content-Type", "application/json")
164-
r.Header.Set("Client", "sourcegraph-cody-gateway/1.0")
165-
r.Header.Set("X-API-Key", a.config.AccessToken)
166-
r.Header.Set("anthropic-version", "2023-01-01")
161+
upstreamRequest.Header.Set("Cache-Control", "no-cache")
162+
upstreamRequest.Header.Set("Accept", "application/json")
163+
upstreamRequest.Header.Set("Content-Type", "application/json")
164+
upstreamRequest.Header.Set("Client", "sourcegraph-cody-gateway/1.0")
165+
upstreamRequest.Header.Set("X-API-Key", a.config.AccessToken)
166+
upstreamRequest.Header.Set("anthropic-version", "2023-01-01")
167167
}
168168

169169
func (a *AnthropicHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody anthropicRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {

cmd/cody-gateway/internal/httpapi/completions/anthropicmessages.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,10 @@ func (a *AnthropicMessagesHandlerMethods) getRequestMetadata(body anthropicMessa
209209
}
210210
}
211211

212-
func (a *AnthropicMessagesHandlerMethods) transformRequest(r *http.Request) {
213-
r.Header.Set("Content-Type", "application/json")
214-
r.Header.Set("X-API-Key", a.config.AccessToken)
215-
r.Header.Set("anthropic-version", "2023-06-01")
212+
func (a *AnthropicMessagesHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
213+
upstreamRequest.Header.Set("Content-Type", "application/json")
214+
upstreamRequest.Header.Set("X-API-Key", a.config.AccessToken)
215+
upstreamRequest.Header.Set("anthropic-version", "2023-06-01")
216216
}
217217

218218
func (a *AnthropicMessagesHandlerMethods) parseResponseAndUsage(logger log.Logger, body anthropicMessagesRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {

cmd/cody-gateway/internal/httpapi/completions/fireworks.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,13 @@ func (f *FireworksHandlerMethods) getRequestMetadata(body fireworksRequest) (mod
143143
return body.Model, map[string]any{"stream": body.Stream}
144144
}
145145

146-
func (f *FireworksHandlerMethods) transformRequest(r *http.Request) {
147-
r.Header.Set("Content-Type", "application/json")
148-
r.Header.Set("Authorization", "Bearer "+f.config.AccessToken)
146+
func (f *FireworksHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
147+
// Enable tracing if the client requests it, see https://readme.fireworks.ai/docs/enabling-tracing
148+
if downstreamRequest.Header.Get("X-Fireworks-Genie") == "true" {
149+
upstreamRequest.Header.Set("X-Fireworks-Genie", "true")
150+
}
151+
upstreamRequest.Header.Set("Content-Type", "application/json")
152+
upstreamRequest.Header.Set("Authorization", "Bearer "+f.config.AccessToken)
149153
}
150154

151155
func (f *FireworksHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody fireworksRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {

cmd/cody-gateway/internal/httpapi/completions/google.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ func (*GoogleHandlerMethods) getRequestMetadata(body googleRequest) (model strin
103103
return body.Model, map[string]any{"stream": body.ShouldStream()}
104104
}
105105

106-
func (o *GoogleHandlerMethods) transformRequest(r *http.Request) {
107-
r.Header.Set("Content-Type", "application/json")
106+
func (o *GoogleHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
107+
upstreamRequest.Header.Set("Content-Type", "application/json")
108108
}
109109

110110
func (*GoogleHandlerMethods) parseResponseAndUsage(logger log.Logger, reqBody googleRequest, r io.Reader, isStreamRequest bool) (promptUsage, completionUsage usageStats) {

cmd/cody-gateway/internal/httpapi/completions/openai.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ func (*OpenAIHandlerMethods) getRequestMetadata(body openaiRequest) (model strin
145145
return body.Model, map[string]any{"stream": body.Stream}
146146
}
147147

148-
func (o *OpenAIHandlerMethods) transformRequest(r *http.Request) {
149-
r.Header.Set("Content-Type", "application/json")
150-
r.Header.Set("Authorization", "Bearer "+o.config.AccessToken)
148+
func (o *OpenAIHandlerMethods) transformRequest(downstreamRequest, upstreamRequest *http.Request) {
149+
upstreamRequest.Header.Set("Content-Type", "application/json")
150+
upstreamRequest.Header.Set("Authorization", "Bearer "+o.config.AccessToken)
151151
if o.config.OrgID != "" {
152-
r.Header.Set("OpenAI-Organization", o.config.OrgID)
152+
upstreamRequest.Header.Set("OpenAI-Organization", o.config.OrgID)
153153
}
154154
}
155155

cmd/cody-gateway/internal/httpapi/completions/upstream.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] interface {
9191
// provided to assist in abuse detection.
9292
transformBody(_ *ReqT, identifier string)
9393
// transformRequest can be used to modify the HTTP request before it is sent
94-
// upstream. To manipulate the body, use transformBody.
95-
transformRequest(*http.Request)
94+
// upstream. The downstreamRequest parameter is the request sent from the Gateway client.
95+
// To manipulate the body, use transformBody.
96+
transformRequest(downstreamRequest, upstreamRequest *http.Request)
9697
// getRequestMetadata should extract details about the request we are sending
9798
// upstream for validation and tracking purposes. Usage data does not need
9899
// to be reported here - instead, use parseResponseAndUsage to extract usage,
@@ -158,15 +159,15 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
158159
// upstreamHandler is the actual HTTP handle that will perform "all of the things"
159160
// in order to call the upstream API. e.g. calling the upstreamHandlerMethods in
160161
// the correct order, enforcing rate limits and anti-abuse mechanisms, etc.
161-
upstreamHandler := func(w http.ResponseWriter, r *http.Request) {
162-
ctx := r.Context()
162+
upstreamHandler := func(w http.ResponseWriter, downstreamRequest *http.Request) {
163+
ctx := downstreamRequest.Context()
163164
act := actor.FromContext(ctx)
164165

165166
// TODO: Investigate using actor propagation handler for extracting
166167
// this. We had some issues before getting that to work, so for now
167168
// just stick with what we've seen working so far.
168-
sgActorID := r.Header.Get("X-Sourcegraph-Actor-UID")
169-
sgActorAnonymousUID := r.Header.Get("X-Sourcegraph-Actor-Anonymous-UID")
169+
sgActorID := downstreamRequest.Header.Get("X-Sourcegraph-Actor-UID")
170+
sgActorAnonymousUID := downstreamRequest.Header.Get("X-Sourcegraph-Actor-Anonymous-UID")
170171

171172
// Build logger for lifecycle of this request with lots of details.
172173
logger := act.Logger(sgtrace.Logger(ctx, baseLogger)).With(
@@ -209,7 +210,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
209210

210211
// Parse the request body.
211212
var body ReqT
212-
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
213+
if err := json.NewDecoder(downstreamRequest.Body).Decode(&body); err != nil {
213214
response.JSONError(logger, w, http.StatusBadRequest, errors.Wrap(err, "failed to parse request body"))
214215
return
215216
}
@@ -295,14 +296,14 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
295296
}
296297

297298
// Create a new request to send upstream, making sure we retain the same context.
298-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(upstreamPayload))
299+
upstreamRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(upstreamPayload))
299300
if err != nil {
300301
response.JSONError(logger, w, http.StatusInternalServerError, errors.Wrap(err, "failed to create request"))
301302
return
302303
}
303304

304305
// Run the request transformer.
305-
methods.transformRequest(req)
306+
methods.transformRequest(downstreamRequest, upstreamRequest)
306307

307308
// Retrieve metadata from the initial request.
308309
model, requestMetadata := methods.getRequestMetadata(body)
@@ -393,11 +394,11 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
393394
logger.Error("failed to log event", log.Error(err))
394395
}
395396
}()
396-
resp, err := httpClient.Do(req)
397+
resp, err := httpClient.Do(upstreamRequest)
397398
if err != nil {
398399
// Ignore reporting errors where client disconnected
399-
if req.Context().Err() == context.Canceled && errors.Is(err, context.Canceled) {
400-
oteltrace.SpanFromContext(req.Context()).
400+
if upstreamRequest.Context().Err() == context.Canceled && errors.Is(err, context.Canceled) {
401+
oteltrace.SpanFromContext(upstreamRequest.Context()).
401402
SetStatus(codes.Error, err.Error())
402403
logger.Info("request canceled", log.Error(err))
403404
return

0 commit comments

Comments
 (0)