From 5f673862c5746e936af8daae949e44cefa4bbf80 Mon Sep 17 00:00:00 2001 From: Karel Suta Date: Fri, 13 Sep 2024 14:38:24 +0200 Subject: [PATCH] Ray cluster client redesign --- support/ray_api.go | 4 +- support/ray_cluster_client.go | 114 ++++++++++----------------- support/ray_cluster_client_helper.go | 44 +++++++++++ 3 files changed, 87 insertions(+), 75 deletions(-) create mode 100644 support/ray_cluster_client_helper.go diff --git a/support/ray_api.go b/support/ray_api.go index 1f04f6d..c088a7a 100644 --- a/support/ray_api.go +++ b/support/ray_api.go @@ -27,9 +27,9 @@ func GetRayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) *RayJ func WriteRayJobAPILogs(t Test, rayClient RayClusterClient, jobID string) { t.T().Helper() - logs, err := rayClient.GetJobLogs(jobID) + jobLogs, err := rayClient.GetJobLogs(jobID) t.Expect(err).NotTo(gomega.HaveOccurred()) - WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(logs)) + WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(jobLogs.Logs)) } func RayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) func(g gomega.Gomega) *RayJobDetailsResponse { diff --git a/support/ray_cluster_client.go b/support/ray_cluster_client.go index ca9cd1a..f92e36d 100644 --- a/support/ray_cluster_client.go +++ b/support/ray_cluster_client.go @@ -18,7 +18,6 @@ package support import ( "bytes" - "crypto/tls" "encoding/json" "fmt" "io" @@ -47,39 +46,30 @@ type RayJobLogsResponse struct { } type RayClusterClientConfig struct { - Address string - Client *http.Client - InsecureSkipVerify bool + Address string + Client *http.Client } var _ RayClusterClient = (*rayClusterClient)(nil) type rayClusterClient struct { - endpoint url.URL - httpClient *http.Client - bearerToken string + config RayClusterClientConfig } type RayClusterClient interface { CreateJob(job *RayJobSetup) (*RayJobResponse, error) GetJobDetails(jobID string) (*RayJobDetailsResponse, error) - GetJobLogs(jobID string) (string, error) - GetJobs() (*[]RayJobDetailsResponse, error) + GetJobLogs(jobID string) (*RayJobLogsResponse, error) + ListJobs() ([]RayJobDetailsResponse, error) } -func NewRayClusterClient(config RayClusterClientConfig, bearerToken string) (RayClusterClient, error) { - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: config.InsecureSkipVerify}, - Proxy: http.ProxyFromEnvironment, - } - config.Client = &http.Client{Transport: tr} +func NewRayClusterClient(config RayClusterClientConfig) (RayClusterClient, error) { endpoint, err := url.Parse(config.Address) if err != nil { - return nil, fmt.Errorf("invalid dashboard endpoint address") - } - rayClusterApiClient := &rayClusterClient{ - endpoint: *endpoint, httpClient: config.Client, bearerToken: bearerToken, + return nil, fmt.Errorf("invalid dashboard endpoint address: %s", endpoint) } + + rayClusterApiClient := &rayClusterClient{config} return rayClusterApiClient, nil } @@ -89,13 +79,15 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes return } - createJobURL := client.endpoint.String() + "/api/jobs/" + createJobURL := client.config.Address + "/api/jobs/" - resp, err := client.httpClient.Post(createJobURL, "application/json", bytes.NewReader(marshalled)) + resp, err := client.config.Client.Post(createJobURL, "application/json", bytes.NewReader(marshalled)) if err != nil { return } + defer resp.Body.Close() + respData, err := io.ReadAll(resp.Body) if err != nil { return @@ -109,95 +101,71 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes return } -func (client *rayClusterClient) GetJobs() (response *[]RayJobDetailsResponse, err error) { - getAllJobsDetailsURL := client.endpoint.String() + "/api/jobs/" +func (client *rayClusterClient) ListJobs() (response []RayJobDetailsResponse, err error) { + getAllJobsDetailsURL := client.config.Address + "/api/jobs/" - req, err := http.NewRequest(http.MethodGet, getAllJobsDetailsURL, nil) + resp, err := client.config.Client.Get(getAllJobsDetailsURL) if err != nil { - return nil, err - } - if client.bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+client.bearerToken) - } - resp, err := client.httpClient.Do(req) - if err != nil { - return nil, err + return } + defer resp.Body.Close() - if resp.StatusCode == 503 { - return nil, fmt.Errorf("service unavailable") - } + respData, err := io.ReadAll(resp.Body) if err != nil { - return nil, err + return } + if resp.StatusCode != 200 { return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData) } + err = json.Unmarshal(respData, &response) - if err != nil { - return nil, err - } - return response, nil + return } func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDetailsResponse, err error) { - getJobDetailsURL := client.endpoint.String() + "/api/jobs/" + jobID + getJobDetailsURL := client.config.Address + "/api/jobs/" + jobID - req, err := http.NewRequest(http.MethodGet, getJobDetailsURL, nil) + resp, err := client.config.Client.Get(getJobDetailsURL) if err != nil { - return nil, err - } - if client.bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+client.bearerToken) + return } - resp, err := client.httpClient.Do(req) - if err != nil { - return nil, err - } - if resp.StatusCode == 503 { - return nil, fmt.Errorf("service unavailable") - } + defer resp.Body.Close() respData, err := io.ReadAll(resp.Body) if err != nil { return } + if resp.StatusCode != 200 { return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData) } + err = json.Unmarshal(respData, &response) - if err != nil { - return nil, err - } - return response, nil + return } -func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error) { - getJobLogsURL := client.endpoint.String() + "/api/jobs/" + jobID + "/logs" - req, err := http.NewRequest(http.MethodGet, getJobLogsURL, nil) - if err != nil { - return "", err - } - if client.bearerToken != "" { - req.Header.Set("Authorization", "Bearer "+client.bearerToken) - } - resp, err := client.httpClient.Do(req) +func (client *rayClusterClient) GetJobLogs(jobID string) (response *RayJobLogsResponse, err error) { + getJobLogsURL := client.config.Address + "/api/jobs/" + jobID + "/logs" + + resp, err := client.config.Client.Get(getJobLogsURL) if err != nil { - return "", err + return } + defer resp.Body.Close() + respData, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return } if resp.StatusCode != 200 { - return "", fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData) + return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData) } - jobLogs := RayJobLogsResponse{} - err = json.Unmarshal(respData, &jobLogs) - return jobLogs.Logs, err + err = json.Unmarshal(respData, &response) + return } diff --git a/support/ray_cluster_client_helper.go b/support/ray_cluster_client_helper.go new file mode 100644 index 0000000..160c213 --- /dev/null +++ b/support/ray_cluster_client_helper.go @@ -0,0 +1,44 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package support + +import ( + "crypto/tls" + "net/http" + + . "github.com/onsi/gomega" + + "k8s.io/client-go/transport" +) + +func GetRayClusterClient(t Test, dashboardURL, bearerToken string) RayClusterClient { + t.T().Helper() + + // Skip TLS check to work on clusters with insecure certificates too + // Functionality intended just for testing purpose, DO NOT USE IN PRODUCTION + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + Proxy: http.ProxyFromEnvironment, + } + client, err := NewRayClusterClient(RayClusterClientConfig{ + Address: dashboardURL, + Client: &http.Client{Transport: transport.NewBearerAuthRoundTripper(bearerToken, tr)}, + }) + t.Expect(err).NotTo(HaveOccurred()) + + return client +}