diff --git a/pkg/clusters/clusterinfo.go b/pkg/clusters/clusterinfo.go index adc614e..e7fe22a 100644 --- a/pkg/clusters/clusterinfo.go +++ b/pkg/clusters/clusterinfo.go @@ -31,7 +31,6 @@ import ( "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/pkg/util/proxy" "k8s.io/apiserver/pkg/authorization/authorizer" - "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/util/cert" "k8s.io/component-base/featuregate" @@ -118,9 +117,6 @@ type ClusterInfo struct { // server Cluster Cluster string - // serverNames are used to route requests with different hostnames - serverNames sync.Map - // global rate limiter type globalRateLimiter string @@ -144,9 +140,9 @@ type ClusterInfo struct { currentLoggingConfig atomic.Value featuregate featuregate.MutableFeatureGate - healthCheckIntervalSeconds time.Duration - endpointHeathCheck EndpointHealthCheck - skipSyncEndpoints bool + healthCheckInterval time.Duration + endpointHeathCheck EndpointHealthCheck + skipSyncEndpoints bool } type secureServingConfig struct { @@ -170,18 +166,18 @@ func NewEmptyClusterInfo(clusterName string, config *rest.Config, healthCheck En limiter := gatewayflowcontrol.NewUpstreamLimiter(ctx, clusterName, "", clientSets) info := &ClusterInfo{ - ctx: ctx, - cancel: cancel, - Cluster: clusterName, - restConfig: config, - Endpoints: &EndpointInfoMap{data: sync.Map{}}, - healthCheckIntervalSeconds: 5 * time.Second, - globalRateLimiter: rateLimiter, - flowcontrol: limiter, - loadbalancer: sync.Map{}, - endpointHeathCheck: healthCheck, - skipSyncEndpoints: skipEndpoints, - featuregate: features.DefaultMutableFeatureGate.DeepCopy(), + ctx: ctx, + cancel: cancel, + Cluster: clusterName, + restConfig: config, + Endpoints: &EndpointInfoMap{data: sync.Map{}}, + healthCheckInterval: 5 * time.Second, + globalRateLimiter: rateLimiter, + flowcontrol: limiter, + loadbalancer: sync.Map{}, + endpointHeathCheck: healthCheck, + skipSyncEndpoints: skipEndpoints, + featuregate: 
features.DefaultMutableFeatureGate.DeepCopy(), } return info } @@ -506,18 +502,13 @@ func (c *ClusterInfo) addOrUpdateEndpoint(endpoint string, disabled bool) error info, ok := c.Endpoints.Load(endpoint) if ok { info.SetDisabled(disabled) - EnsureGatewayHealthCheck(info, c.healthCheckIntervalSeconds, info.ctx) + EnsureGatewayHealthCheck(info, c.healthCheckInterval, info.ctx) return nil } http2configCopy := *c.restConfig http2configCopy.WrapTransport = transport.NewDynamicImpersonatingRoundTripper http2configCopy.Host = endpoint - ts, err := rest.TransportFor(&http2configCopy) - if err != nil { - klog.Errorf("failed to create http2 transport for , err: %v", c.Cluster, endpoint, err) - return err - } // since http2 doesn't support websocket, we need to disable http2 when using websocket upgradeConfigCopy := http2configCopy @@ -532,12 +523,6 @@ func (c *ClusterInfo) addOrUpdateEndpoint(endpoint string, disabled bool) error klog.Errorf("failed to convert transport to proxy.UpgradeRequestRoundTripper for ", c.Cluster, endpoint) } - client, err := kubernetes.NewForConfig(&http2configCopy) - if err != nil { - klog.Errorf("failed to create clientset for , err: %v", c.Cluster, endpoint, err) - return err - } - // initial endpoint status initStatus := endpointStatus{ Disabled: disabled, @@ -550,19 +535,21 @@ func (c *ClusterInfo) addOrUpdateEndpoint(endpoint string, disabled bool) error cancel: cancel, Cluster: c.Cluster, Endpoint: endpoint, - status: initStatus, + status: &initStatus, proxyConfig: &http2configCopy, - ProxyTransport: ts, proxyUpgradeConfig: &upgradeConfigCopy, PorxyUpgradeTransport: urrt, - clientset: client, healthCheckFun: c.endpointHeathCheck, } + if err := info.ResetTransport(); err != nil { + klog.Errorf("failed to init transport for cluster %s endpoint %s, err: %v", c.Cluster, endpoint, err) + return err + } klog.Infof("[cluster info] new endpoint added, cluster=%q, endpoint=%q", c.Cluster, info.Endpoint) c.Endpoints.Store(endpoint, info) - EnsureGatewayHealthCheck(info, 
c.healthCheckIntervalSeconds, info.ctx) + EnsureGatewayHealthCheck(info, c.healthCheckInterval, info.ctx) return nil } diff --git a/pkg/clusters/endpoint.go b/pkg/clusters/endpoint.go index 3ff9dd2..99c1a10 100644 --- a/pkg/clusters/endpoint.go +++ b/pkg/clusters/endpoint.go @@ -17,24 +17,29 @@ package clusters import ( "context" "fmt" + "net" "net/http" "sync" "time" + utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/pkg/util/proxy" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + kts "k8s.io/client-go/transport" "k8s.io/klog" "github.com/kubewharf/kubegateway/pkg/gateway/metrics" + "github.com/kubewharf/kubegateway/pkg/transport" ) type endpointStatus struct { - Healthy bool - Reason string - Message string - Disabled bool - mux sync.RWMutex + Healthy bool + Reason string + Message string + Disabled bool + UnhealthyCount int + mux sync.RWMutex } func (s *endpointStatus) IsReady() bool { @@ -43,6 +48,12 @@ func (s *endpointStatus) IsReady() bool { return !s.Disabled && s.Healthy } +func (s *endpointStatus) GetUnhealthyCount() int { + s.mux.RLock() + defer s.mux.RUnlock() + return s.UnhealthyCount +} + func (s *endpointStatus) SetDisabled(disabled bool) { s.mux.Lock() defer s.mux.Unlock() @@ -55,6 +66,11 @@ func (s *endpointStatus) SetStatus(healthy bool, reason, message string) { s.Healthy = healthy s.Reason = reason s.Message = message + if healthy { + s.UnhealthyCount = 0 + } else { + s.UnhealthyCount++ + } } type EndpointInfo struct { @@ -71,9 +87,10 @@ type EndpointInfo struct { // http1 proxy round tripper for websocket PorxyUpgradeTransport proxy.UpgradeRequestRoundTripper - clientset kubernetes.Interface + clientset kubernetes.Interface + cancelableTs *transport.CancelableTransport - status endpointStatus + status *endpointStatus healthCheckFun EndpointHealthCheck healthCheckCh chan struct{} @@ -89,6 +106,80 @@ func (e *EndpointInfo) Clientset() kubernetes.Interface { return e.clientset } +func (e *EndpointInfo) createTransport() 
(*transport.CancelableTransport, http.RoundTripper, *kubernetes.Clientset, error) { + ts, err := newTransport(e.proxyConfig) + if err != nil { + klog.Errorf("failed to create http2 transport for cluster %s endpoint %s, err: %v", e.Cluster, e.Endpoint, err) + return nil, nil, nil, err + } + cancelableTs := transport.NewCancelableTransport(ts) + ts = cancelableTs + + proxyTs, err := rest.HTTPWrappersForConfig(e.proxyConfig, ts) + if err != nil { + klog.Errorf("failed to wrap http2 transport for cluster %s endpoint %s, err: %v", e.Cluster, e.Endpoint, err) + return nil, nil, nil, err + } + + clientsetConfig := *e.proxyConfig + clientsetConfig.Transport = ts // let client set use the same transport as proxy + clientsetConfig.TLSClientConfig = rest.TLSClientConfig{} + client, err := kubernetes.NewForConfig(&clientsetConfig) + if err != nil { + klog.Errorf("failed to create clientset for cluster %s endpoint %s, err: %v", e.Cluster, e.Endpoint, err) + return nil, nil, nil, err + } + + return cancelableTs, proxyTs, client, nil +} + +func (e *EndpointInfo) ResetTransport() error { + cancelableTs, ts, client, err := e.createTransport() + if err != nil { + return err + } + klog.Infof("set new transport %p for cluster %s endpoint: %s", cancelableTs, e.Cluster, e.Endpoint) + e.ProxyTransport = ts + e.clientset = client + cancelTs := e.cancelableTs + e.cancelableTs = cancelableTs + if cancelTs != nil { + klog.Infof("close transport %p for cluster %s endpoint: %s", cancelTs, e.Cluster, e.Endpoint) + cancelTs.Close() + } + return nil +} + +func newTransport(cfg *rest.Config) (http.RoundTripper, error) { + config, err := cfg.TransportConfig() + if err != nil { + return nil, err + } + tlsConfig, err := kts.TLSConfigFor(config) + if err != nil { + return nil, err + } + // The options didn't require a custom TLS config + if tlsConfig == nil && config.Dial == nil { + return http.DefaultTransport, nil + } + dial := config.Dial + if dial == nil { + dial = (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext + } + 
return utilnet.SetTransportDefaults(&http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: tlsConfig, + MaxIdleConnsPerHost: 25, + DialContext: dial, + DisableCompression: config.DisableCompression, + }), nil +} + func (e *EndpointInfo) SetDisabled(disabled bool) { if e.status.Disabled != disabled { e.status.Disabled = disabled @@ -100,13 +191,19 @@ func (e *EndpointInfo) IstDisabled() bool { return e.status.Disabled } +func (e *EndpointInfo) GetUnhealthyCount() int { + return e.status.GetUnhealthyCount() +} + func (e *EndpointInfo) UpdateStatus(healthy bool, reason, message string) { if !healthy { metrics.RecordUnhealthyUpstream(e.Cluster, e.Endpoint, reason) } + healthyChanged := e.status.Healthy != healthy + e.status.SetStatus(healthy, reason, message) + - if e.status.Healthy != healthy { + if healthyChanged { // healthy changed - e.status.SetStatus(healthy, reason, message) e.recordStatusChange() } } diff --git a/pkg/clusters/endpoint_test.go b/pkg/clusters/endpoint_test.go index 9bee9e1..2e560aa 100644 --- a/pkg/clusters/endpoint_test.go +++ b/pkg/clusters/endpoint_test.go @@ -21,13 +21,13 @@ import ( func TestEndpointInfo_ReadyAndReason(t *testing.T) { tests := []struct { name string - status endpointStatus + status *endpointStatus wantReady bool want string }{ { "ready", - endpointStatus{ + &endpointStatus{ Disabled: false, Healthy: true, }, @@ -36,7 +36,7 @@ func TestEndpointInfo_ReadyAndReason(t *testing.T) { }, { "disabled", - endpointStatus{ + &endpointStatus{ Disabled: true, Healthy: true, }, @@ -45,7 +45,7 @@ func TestEndpointInfo_ReadyAndReason(t *testing.T) { }, { "unhealthy", - endpointStatus{ + &endpointStatus{ Disabled: false, Healthy: false, Reason: "Timeout", @@ -71,3 +71,34 @@ func TestEndpointInfo_ReadyAndReason(t *testing.T) { }) } } + +func TestEndpointInfo_UnhealthyCount(t *testing.T) { + e := &EndpointInfo{ + Endpoint: "", + status: &endpointStatus{ + Disabled: true, + Healthy: true, + }, + } + if e.GetUnhealthyCount() != 0 { + t.Errorf("unhealthy count 
should be 0, actual: %d", e.GetUnhealthyCount()) + } + + for i := 0; i < 2; i++ { + e.UpdateStatus(false, "mock error", "mock error message") + if e.GetUnhealthyCount() != i+1 { + t.Errorf("unhealthy count should be %d, actual: %d", i+1, e.GetUnhealthyCount()) + } + } + e.UpdateStatus(true, "", "") + if e.GetUnhealthyCount() != 0 { + t.Errorf("unhealthy count should be 0, actual: %d", e.GetUnhealthyCount()) + } + + for i := 0; i < 5; i++ { + e.UpdateStatus(false, "mock error", "mock error message") + if e.GetUnhealthyCount() != i+1 { + t.Errorf("unhealthy count should be %d, actual: %d", i+1, e.GetUnhealthyCount()) + } + } +} diff --git a/pkg/gateway/controllers/upstream_controller.go b/pkg/gateway/controllers/upstream_controller.go index 5b4c654..bf15d49 100644 --- a/pkg/gateway/controllers/upstream_controller.go +++ b/pkg/gateway/controllers/upstream_controller.go @@ -360,5 +360,16 @@ func GatewayHealthCheck(e *clusters.EndpointInfo) (done bool) { } klog.Errorf("upstream health check failed, cluster=%q endpoint=%q reason=%q message=%q", e.Cluster, e.Endpoint, reason, message) e.UpdateStatus(false, reason, message) + + const ResetTransportThreshold = 3 + if err != nil && + strings.Contains(err.Error(), "Client.Timeout or context cancellation while reading body") && + e.GetUnhealthyCount() >= ResetTransportThreshold { + klog.Warningf("transport to endpoint %s hang", e.Endpoint) + if err := e.ResetTransport(); err != nil { + klog.Warningf("reset transport to endpoint %s error: %v", e.Endpoint, err) + } + } + return done } diff --git a/pkg/transport/cancelable_transport.go b/pkg/transport/cancelable_transport.go new file mode 100644 index 0000000..101a215 --- /dev/null +++ b/pkg/transport/cancelable_transport.go @@ -0,0 +1,37 @@ +package transport + +import ( + "context" + "net/http" +) + +type CancelableTransport struct { + inner http.RoundTripper + ctx context.Context + cancel context.CancelFunc +} + +func (ts *CancelableTransport) RoundTrip(r *http.Request) 
(*http.Response, error) { + reqCtx := r.Context() + ctx, cancel := context.WithCancel(reqCtx) + go func() { + defer cancel() + select { + case <-reqCtx.Done(): // req ctx is done + case <-ts.ctx.Done(): // transport is done + } + }() + r2 := r.Clone(ctx) + return ts.inner.RoundTrip(r2) +} +func (ts *CancelableTransport) Close() { + ts.cancel() +} +func NewCancelableTransport(inner http.RoundTripper) *CancelableTransport { + ctx, cancel := context.WithCancel(context.Background()) + return &CancelableTransport{ + inner: inner, + ctx: ctx, + cancel: cancel, + } +} diff --git a/pkg/transport/cancelable_transport_test.go b/pkg/transport/cancelable_transport_test.go new file mode 100644 index 0000000..01d63cf --- /dev/null +++ b/pkg/transport/cancelable_transport_test.go @@ -0,0 +1,75 @@ +package transport + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestCancelableTransportClose(t *testing.T) { + ts := httptest.NewServer(new(testServer)) + defer ts.Close() + ct := NewCancelableTransport(http.DefaultTransport) + client := http.Client{ + Transport: ct, + } + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("get error: %v", err) + } + resp.Body.Close() + time.AfterFunc(200*time.Millisecond, func() { + ct.Close() + }) + start := time.Now() + _, err = client.Get(ts.URL) + latency := time.Since(start) + if err == nil { + t.Errorf("should have error") + } + if latency > 300*time.Millisecond { + t.Errorf("should cancel request") + } +} + +func TestCancelableTransportCancel(t *testing.T) { + ct := NewCancelableTransport(new(mockTransport)) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req, err := http.NewRequestWithContext(ctx, "GET", "localhost", nil) + if err != nil { + t.Fatalf("new request error") + } + _, err = ct.RoundTrip(req) + if err == nil { + t.Errorf("should have error") + } + if err.Error() != "req context done" { + t.Errorf("unexpected error: %v", err) + } +} + +type mockTransport 
struct { +} + +// RoundTrip implements http.RoundTripper. +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + select { + case <-ctx.Done(): + return nil, fmt.Errorf("req context done") + default: + } + return nil, fmt.Errorf("mock error") +} + +type testServer struct { +} + +func (s *testServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + time.Sleep(500 * time.Millisecond) + w.WriteHeader(200) +}