diff --git a/rpcserver/jsonrpc_server.go b/rpcserver/jsonrpc_server.go index 3c31ab4..8a862f5 100644 --- a/rpcserver/jsonrpc_server.go +++ b/rpcserver/jsonrpc_server.go @@ -96,6 +96,9 @@ type JSONRPCHandlerOpts struct { ExtractOriginFromHeader bool // GET response content GetResponseContent []byte + // Custom handler for /readyz endpoint. If not nil then it is expected to write the response to the provided ResponseWriter. + // If the custom handler returns an error, the error message is written to the ResponseWriter with a 500 status code. + ReadyHandler func(w http.ResponseWriter, r *http.Request) error } // NewJSONRPCHandler creates JSONRPC http.Handler from the map that maps method names to method functions @@ -162,9 +165,27 @@ func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { stepStartAt := time.Now() - if r.Method != http.MethodPost { - // Respond with GET response content if it's set - if r.Method == http.MethodGet && len(h.GetResponseContent) > 0 { + // Some GET requests are allowed + if r.Method == http.MethodGet { + if r.URL.Path == "/livez" { + w.WriteHeader(http.StatusOK) + return + } else if r.URL.Path == "/readyz" { + if h.JSONRPCHandlerOpts.ReadyHandler == nil { + http.Error(w, "ready handler is not set", http.StatusNotFound) + incIncorrectRequest(h.ServerName) + return + } else { + // Handler is expected to write the Response + err := h.JSONRPCHandlerOpts.ReadyHandler(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + incInternalErrors(h.ServerName) + } + return + } + } else if len(h.GetResponseContent) > 0 { + // Static response for all other GET requests w.WriteHeader(http.StatusOK) _, err := w.Write(h.GetResponseContent) if err != nil { @@ -174,8 +195,10 @@ func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + } - // Responsd with "only POST method is allowed" + // From here we only accept POST requests with JSON body + if r.Method != http.MethodPost { http.Error(w, errMethodNotAllowed, http.StatusMethodNotAllowed) incIncorrectRequest(h.ServerName) return diff --git a/rpcserver/jsonrpc_server_test.go b/rpcserver/jsonrpc_server_test.go index 7f96bc9..4bb1450 100644 --- a/rpcserver/jsonrpc_server_test.go +++ b/rpcserver/jsonrpc_server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -120,3 +121,63 @@ func TestJSONRPCServerWithSignatureWithClient(t *testing.T) { require.NoError(t, err) require.Equal(t, 123, structResp.Field) } + +func TestJSONRPCServerDefaultLiveAndReady(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{}) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // /livez (200 by default) + request, err := http.NewRequest(http.MethodGet, "/livez", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "", rr.Body.String()) + + // /readyz (404 by default) + request, err = http.NewRequest(http.MethodGet, "/readyz", nil) + require.NoError(t, err) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, request) + require.Equal(t, http.StatusNotFound, rr.Code) +} + +func TestJSONRPCServerReadyzOK(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{ + ReadyHandler: func(w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("ready")) + return err + }, + }) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + request, err := http.NewRequest(http.MethodGet, "/readyz", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, request) + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "ready", rr.Body.String()) +} + +func TestJSONRPCServerReadyzError(t *testing.T) { + handler := testHandler(JSONRPCHandlerOpts{ + ReadyHandler: func(w http.ResponseWriter, r *http.Request) error { + return fmt.Errorf("not ready") + }, + }) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + request, err := http.NewRequest(http.MethodGet, "/readyz", nil) + require.NoError(t, err) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, request) + require.Equal(t, http.StatusInternalServerError, rr.Code) + fmt.Println(rr.Body.String()) + require.Equal(t, "not ready\n", rr.Body.String()) +}