Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ const (

var (
proxy = flag.String("proxy", "", "URL (including scheme) of the inverting proxy")
proxyTimeout = flag.Duration("proxy-timeout", 60*time.Second, "Client timeout when sending requests to the inverting proxy")
proxyTimeout = flag.Duration("proxy-timeout", 60*time.Second, "Timeout for polling the inverting proxy for new requests")
requestForwardingTimeout = flag.Duration("request-forwarding-timeout", 0*time.Second, "Timeout for forwarding individual requests to the backend and returning a response (matches proxy-timeout by default)")
host = flag.String("host", "localhost:8080", "Hostname (including port) of the backend server")
forceHTTP2 = flag.Bool("force-http2", false, "Force connections to the backend host to be performed using HTTP/2")
backendID = flag.String("backend", "", "Unique ID for this backend.")
Expand Down Expand Up @@ -212,7 +213,9 @@ func pollForNewRequests(pollingCtx context.Context, client *http.Client, hostPro
log.Printf("Request polling context completed with ctx err: %v\n", pollingCtx.Err())
return
default:
if requests, err := utils.ListPendingRequests(client, *proxy, backendID, metricHandler); err != nil {
listRequestsCtx, cancel := context.WithTimeout(pollingCtx, *proxyTimeout)
defer cancel()
if requests, err := utils.ListPendingRequests(listRequestsCtx, client, *proxy, backendID, metricHandler); err != nil {
log.Printf("Failed to read pending requests: %q\n", err.Error())
time.Sleep(utils.ExponentialBackoffDuration(retryCount))
retryCount++
Expand Down Expand Up @@ -304,7 +307,12 @@ func runAdapter(ctx context.Context, requestPollingCtx context.Context) error {
if err != nil {
return err
}
client.Timeout = *proxyTimeout

// Request forwarding should use the larger of proxyTimeout and requestForwardingTimeout
effectiveRequestForwardingTimeout := max(*proxyTimeout, *requestForwardingTimeout)
client.Timeout = effectiveRequestForwardingTimeout

log.Printf("Request forwarding timeout is %v; proxy timeout is %v\n", effectiveRequestForwardingTimeout, *proxyTimeout)

hostProxy, err := hostProxy(ctx, *host, *shimPath, *shimWebsockets, *forceHTTP2)
if err != nil {
Expand Down
83 changes: 83 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,89 @@ func TestWithInMemoryProxyAndBackendWithSessions(t *testing.T) {
}
}

func TestProxyTimeoutWithShortTimeout(t *testing.T) {
proxyTimeout := "10ms"
requestForwardingTimeout := "60s"
wantTimeout := true

timeoutTest(t, proxyTimeout, requestForwardingTimeout, wantTimeout)
}

func TestProxyTimeoutWithLongTimeout(t *testing.T) {
proxyTimeout := "60s"
requestForwardingTimeout := "60s"
wantTimeout := false

timeoutTest(t, proxyTimeout, requestForwardingTimeout, wantTimeout)
}

func timeoutTest(t *testing.T, proxyTimeout string, requestForwardingTimeout string, wantTimeout bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

backendHomeDir := filepath.Join(t.TempDir(), "backend-home")
gcloudCfg := filepath.Join(backendHomeDir, ".config", "gcloud")
if err := os.MkdirAll(gcloudCfg, os.ModePerm); err != nil {
t.Fatalf("Failed to set up a temporary home directory for the test: %v", err)
}
backendURL := RunBackend(ctx, t)
fakeMetadataURL := RunFakeMetadataServer(ctx, t)

parsedBackendURL, err := url.Parse(backendURL)
if err != nil {
t.Fatalf("Failed to parse the backend URL: %v", err)
}
proxyPort, err := RunLocalProxy(ctx, t)
proxyURL := fmt.Sprintf("http://localhost:%d", proxyPort)
if err != nil {
t.Fatalf("Failed to run the local inverting proxy: %v", err)
}
t.Logf("Started backend at localhost:%s and proxy at %s", parsedBackendURL.Port(), proxyURL)

// This assumes that "Make build" has been run
args := strings.Join(append(
[]string{"${GOPATH}/bin/proxy-forwarding-agent"},
"--backend=testBackend",
"--proxy", proxyURL+"/",
"--proxy-timeout="+proxyTimeout,
"--request-forwarding-timeout="+requestForwardingTimeout,
"--host=localhost:"+parsedBackendURL.Port()),
" ")
agentCmd := exec.CommandContext(ctx, "/bin/bash", "-c", args)

var out bytes.Buffer
agentCmd.Stdout = &out
agentCmd.Stderr = &out
agentCmd.Env = append(os.Environ(), "PATH=", "HOME="+backendHomeDir, "GCE_METADATA_HOST="+strings.TrimPrefix(fakeMetadataURL, "http://"))
if err := agentCmd.Start(); err != nil {
t.Fatalf("Failed to start the agent binary: %v", err)
}
defer func() {
cancel()
err := agentCmd.Wait()

s := out.String()
t.Logf("Agent result: %v, stdout/stderr: %q", err, s)
timeoutOccurred := strings.Contains(s, "context deadline exceeded")
if timeoutOccurred != wantTimeout {
t.Errorf("Unexpected timeout state: got %v, want %v", timeoutOccurred, wantTimeout)
}
}()

// Send one request through the proxy to make sure the agent has come up.
//
// We give this initial request a long time to complete, as the agent takes
// a long time to start up.
testPath := "/some/request/path"
if err := checkRequest(proxyURL, testPath, testPath, time.Second, backendCookie); err != nil {
t.Fatalf("Failed to send the initial request: %v", err)
}

if err := checkRequest(proxyURL, testPath, testPath, 100*time.Millisecond, backendCookie); err != nil {
t.Fatalf("Failed to send request %v", err)
}
}

func TestGracefulShutdown(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
4 changes: 2 additions & 2 deletions agent/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ func RoundTripperWithVMIdentity(ctx context.Context, wrapped http.RoundTripper,
}

// ListPendingRequests issues a single request to the proxy to ask for the IDs of pending requests.
func ListPendingRequests(client *http.Client, proxyHost, backendID string, metricHandler *metrics.MetricHandler) ([]string, error) {
func ListPendingRequests(ctx context.Context, client *http.Client, proxyHost, backendID string, metricHandler *metrics.MetricHandler) ([]string, error) {
proxyURL := proxyHost + PendingPath
proxyReq, err := http.NewRequest(http.MethodGet, proxyURL, nil)
proxyReq, err := http.NewRequestWithContext(ctx, http.MethodGet, proxyURL, nil)
if err != nil {
return nil, err
}
Expand Down