diff --git a/support/defaults.go b/support/defaults.go index 72d1091..c165b22 100644 --- a/support/defaults.go +++ b/support/defaults.go @@ -5,6 +5,7 @@ package support // *********************** const ( - RayVersion = "2.35.0" - RayImage = "quay.io/modh/ray:2.35.0-py39-cu121" + RayVersion = "2.35.0" + RayImage = "quay.io/modh/ray:2.35.0-py39-cu121" + RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61" ) diff --git a/support/environment.go b/support/environment.go index 6ed07a6..3de25e6 100644 --- a/support/environment.go +++ b/support/environment.go @@ -27,6 +27,7 @@ const ( CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION" CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE" + CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE" CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE" // The testing output directory, to write output files into. @@ -78,6 +79,10 @@ func GetRayImage() string { return lookupEnvOrDefault(CodeFlareTestRayImage, RayImage) } +func GetRayROCmImage() string { + return lookupEnvOrDefault(CodeFlareTestRayROCmImage, RayROCmImage) +} + func GetPyTorchImage() string { return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime") }