Skip to content

Update Ray support functions so that Ray job API operations can skip TLS verification on insecure clusters #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 101 additions & 9 deletions support/ray_cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ package support

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"

. "github.com/onsi/gomega"
)

type RayJobSetup struct {
Expand All @@ -45,20 +49,32 @@ type RayJobLogsResponse struct {
Logs string `json:"logs"`
}

// RayClusterClientConfig carries the options NewRayClusterClient uses when
// building the HTTP client for the Ray dashboard.
type RayClusterClientConfig struct {
// SkipTlsVerification is copied into tls.Config.InsecureSkipVerify by
// NewRayClusterClient; when true the server certificate is not verified.
SkipTlsVerification bool
}

var _ RayClusterClient = (*rayClusterClient)(nil)

type rayClusterClient struct {
endpoint url.URL
endpoint url.URL
httpClient *http.Client
bearerToken string
}

// RayClusterClient is the set of Ray dashboard job operations used by the
// tests. All requests target paths under "/api/jobs/" of the dashboard endpoint.
type RayClusterClient interface {
// CreateJob submits the given job setup as JSON with a POST to "/api/jobs/".
CreateJob(job *RayJobSetup) (*RayJobResponse, error)
// GetJobDetails fetches "/api/jobs/<jobID>" and returns the decoded details.
GetJobDetails(jobID string) (*RayJobDetailsResponse, error)
// GetJobLogs fetches "/api/jobs/<jobID>/logs" and returns the "logs" field.
GetJobLogs(jobID string) (string, error)
// GetJobs fetches "/api/jobs/" and returns the decoded JSON array of jobs.
GetJobs() ([]map[string]interface{}, error)
// WaitForJobStatus polls GetJobDetails until the job status is "SUCCEEDED"
// or "FAILED" and returns the last status observed.
WaitForJobStatus(test Test, jobID string) string
}

func NewRayClusterClient(dashboardEndpoint url.URL) RayClusterClient {
return &rayClusterClient{endpoint: dashboardEndpoint}
// NewRayClusterClient builds a RayClusterClient that talks to the Ray dashboard
// at dashboardEndpoint.
//
// The underlying HTTP client honours the proxy settings from the environment
// and skips server certificate verification when config.SkipTlsVerification is
// true. bearerToken is stored on the client; an empty token means no
// Authorization header is added by the methods that check it.
func NewRayClusterClient(dashboardEndpoint url.URL, config RayClusterClientConfig, bearerToken string) RayClusterClient {
	tlsSettings := &tls.Config{InsecureSkipVerify: config.SkipTlsVerification}

	httpClient := &http.Client{
		Transport: &http.Transport{
			Proxy:           http.ProxyFromEnvironment,
			TLSClientConfig: tlsSettings,
		},
	}

	return &rayClusterClient{
		endpoint:    dashboardEndpoint,
		httpClient:  httpClient,
		bearerToken: bearerToken,
	}
}

func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobResponse, err error) {
Expand All @@ -68,7 +84,8 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
}

createJobURL := client.endpoint.String() + "/api/jobs/"
resp, err := http.Post(createJobURL, "application/json", bytes.NewReader(marshalled))

resp, err := client.httpClient.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
if err != nil {
return
}
Expand All @@ -86,11 +103,51 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
return
}

// GetJobs lists the jobs known to the Ray dashboard by issuing a GET request
// to the "/api/jobs/" endpoint.
//
// When a bearer token is configured it is sent in the Authorization header.
// The response body is expected to be a JSON array of objects; each object is
// decoded into a generic map.
//
// An error is returned when the request cannot be built or sent, when the
// dashboard answers with any status other than 200 OK, when the body cannot be
// read, or when the body is not a JSON array of objects.
func (client *rayClusterClient) GetJobs() ([]map[string]interface{}, error) {
	getAllJobsDetailsURL := client.endpoint.String() + "/api/jobs/"

	req, err := http.NewRequest(http.MethodGet, getAllJobsDetailsURL, nil)
	if err != nil {
		return nil, err
	}
	if client.bearerToken != "" {
		req.Header.Set("Authorization", "Bearer "+client.bearerToken)
	}

	resp, err := client.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Keep the dedicated 503 message so callers matching on it keep working.
	if resp.StatusCode == http.StatusServiceUnavailable {
		return nil, fmt.Errorf("service unavailable")
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Any other non-200 answer is an error too; without this check an error
	// payload would be parsed as if it were the job list.
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Jobs, response body: %s", resp.StatusCode, body)
	}

	var result []map[string]interface{}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, err
	}
	return result, nil
}

func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDetailsResponse, err error) {
getJobDetailsURL := client.endpoint.String() + "/api/jobs/" + jobID
resp, err := http.Get(getJobDetailsURL)

req, err := http.NewRequest(http.MethodGet, getJobDetailsURL, nil)
if err != nil {
return
return nil, err
}
if client.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
}

resp, err := client.httpClient.Do(req)
if err != nil {
return nil, err
}

respData, err := io.ReadAll(resp.Body)
Expand All @@ -108,14 +165,21 @@ func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDet

func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error) {
getJobLogsURL := client.endpoint.String() + "/api/jobs/" + jobID + "/logs"
resp, err := http.Get(getJobLogsURL)
req, err := http.NewRequest(http.MethodGet, getJobLogsURL, nil)
if err != nil {
return
return "", err
}
if client.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
}
resp, err := client.httpClient.Do(req)
if err != nil {
return "", err
}

respData, err := io.ReadAll(resp.Body)
if err != nil {
return
return "", err
}

if resp.StatusCode != 200 {
Expand All @@ -126,3 +190,31 @@ func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error
err = json.Unmarshal(respData, &jobLogs)
return jobLogs.Logs, err
}

// WaitForJobStatus polls GetJobDetails for jobID every 3 seconds, for up to
// TestTimeoutDouble, until the reported status is "SUCCEEDED" or "FAILED".
//
// Each status change is printed to standard output. An error from
// GetJobDetails is handed to test.Expect, and the terminal-status check is
// asserted through test.Eventually (presumably Gomega-backed assertions on the
// project's Test type; confirm against its definition).
//
// The return value is the last status observed.
func (client *rayClusterClient) WaitForJobStatus(test Test, jobID string) string {
	var status string
	fmt.Printf("Waiting for job to be Succeeded...\n")

	// poll records the latest job status and logs it whenever it changes.
	poll := func() string {
		details, err := client.GetJobDetails(jobID)
		test.Expect(err).ToNot(HaveOccurred())

		current := details.Status
		switch {
		case current == "SUCCEEDED" || current == "FAILED":
			// Terminal state: always report it.
			fmt.Printf("JobStatus : %s\n", current)
			status = current
		case current != status:
			// Intermediate state: report only transitions.
			fmt.Printf("JobStatus : %s...\n", current)
			status = current
		}
		return status
	}

	test.Eventually(poll, TestTimeoutDouble, 3*time.Second).
		Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")

	switch status {
	case "SUCCEEDED":
		fmt.Printf("Job succeeded !\n")
	default:
		fmt.Printf("Job failed !\n")
	}
	return status
}
1 change: 1 addition & 0 deletions support/support.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
TestTimeoutShort = 1 * time.Minute
TestTimeoutMedium = 2 * time.Minute
TestTimeoutLong = 5 * time.Minute
TestTimeoutDouble = 10 * time.Minute
TestTimeoutGpuProvisioning = 30 * time.Minute
)

Expand Down