Skip to content

fix(xunix): improve handling of gpu library mounts #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 105 additions & 8 deletions integration/gpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/envbox/integration/integrationtest"
Expand Down Expand Up @@ -41,8 +42,7 @@ func TestDocker_Nvidia(t *testing.T) {
)

// Assert that we can run nvidia-smi in the inner container.
_, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "nvidia-smi")
require.NoError(t, err, "failed to run nvidia-smi in the inner container")
assertInnerNvidiaSMI(ctx, t, ctID)
})

t.Run("Redhat", func(t *testing.T) {
Expand All @@ -52,16 +52,23 @@ func TestDocker_Nvidia(t *testing.T) {

// Start the envbox container.
ctID := startEnvboxCmd(ctx, t, integrationtest.RedhatImage, "root",
"-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib64",
"-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib",
"--env", "CODER_ADD_GPU=true",
"--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib64",
"--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib",
"--runtime=nvidia",
"--gpus=all",
)

// Assert that we can run nvidia-smi in the inner container.
_, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "nvidia-smi")
require.NoError(t, err, "failed to run nvidia-smi in the inner container")
assertInnerNvidiaSMI(ctx, t, ctID)

// Make sure dnf still works. This checks for a regression due to
// gpuExtraRegex matching `libglib.so` in the outer container.
// This had a dependency on `libpcre.so.3` which would cause dnf to fail.
out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "dnf")
if !assert.NoError(t, err, "failed to run dnf in the inner container") {
t.Logf("dnf output:\n%s", strings.TrimSpace(out))
}
})

t.Run("InnerUsrLibDirOverride", func(t *testing.T) {
Expand All @@ -79,11 +86,58 @@ func TestDocker_Nvidia(t *testing.T) {
"--gpus=all",
)

// Assert that the libraries end up in the expected location in the inner container.
out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "ls", "-l", "/usr/lib/coder")
// Assert that the libraries end up in the expected location in the inner
// container.
out, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "ls", "-1", "/usr/lib/coder")
require.NoError(t, err, "inner usr lib dir override failed")
require.Regexp(t, `(?i)(libgl|nvidia|vulkan|cuda)`, out)
})

t.Run("EmptyHostUsrLibDir", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
emptyUsrLibDir := t.TempDir()

// Start the envbox container.
ctID := startEnvboxCmd(ctx, t, integrationtest.UbuntuImage, "root",
"-v", emptyUsrLibDir+":/var/coder/usr/lib",
"--env", "CODER_ADD_GPU=true",
"--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib",
"--runtime=nvidia",
"--gpus=all",
)

ofs := outerFiles(ctx, t, ctID, "/usr/lib/x86_64-linux-gnu/libnv*")
// Assert invariant: the outer container has the files we expect.
require.NotEmpty(t, ofs, "failed to list outer container files")
// Assert that expected files are available in the inner container.
assertInnerFiles(ctx, t, ctID, "/usr/lib/x86_64-linux-gnu/libnv*", ofs...)
assertInnerNvidiaSMI(ctx, t, ctID)
})
Comment on lines +102 to +123
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: this tests that we can get by with no extra files in CODER_USR_LIB_DIR


t.Run("CUDASample", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

// Start the envbox container.
ctID := startEnvboxCmd(ctx, t, integrationtest.CUDASampleImage, "root",
"-v", "/usr/lib/x86_64-linux-gnu:/var/coder/usr/lib",
"--env", "CODER_ADD_GPU=true",
"--env", "CODER_USR_LIB_DIR=/var/coder/usr/lib",
"--runtime=nvidia",
"--gpus=all",
)

// Assert that we can run nvidia-smi in the inner container.
assertInnerNvidiaSMI(ctx, t, ctID)

// Assert that /tmp/vectorAdd runs successfully in the inner container.
_, err := execContainerCmd(ctx, t, ctID, "docker", "exec", "workspace_cvm", "/tmp/vectorAdd")
require.NoError(t, err, "failed to run /tmp/vectorAdd in the inner container")
})
}

// dockerRuntimes returns the list of container runtimes available on the host.
Expand All @@ -101,6 +155,49 @@ func dockerRuntimes(t *testing.T) []string {
return strings.Split(raw, "\n")
}

// outerFiles returns the list of files in the outer container matching the
// given pattern. It does this by running `ls -1` in the outer container.
func outerFiles(ctx context.Context, t *testing.T, containerID, pattern string) []string {
t.Helper()
// We need to use /bin/sh -c to avoid the shell interpreting the glob.
out, err := execContainerCmd(ctx, t, containerID, "/bin/sh", "-c", "ls -1 "+pattern)
require.NoError(t, err, "failed to list outer container files")
files := strings.Split(strings.TrimSpace(out), "\n")
slices.Sort(files)
return files
}

// assertInnerFiles checks that all the files matching the given pattern exist in the
// inner container.
func assertInnerFiles(ctx context.Context, t *testing.T, containerID, pattern string, expected ...string) {
t.Helper()

// Get the list of files in the inner container.
// We need to use /bin/sh -c to avoid the shell interpreting the glob.
out, err := execContainerCmd(ctx, t, containerID, "docker", "exec", "workspace_cvm", "/bin/sh", "-c", "ls -1 "+pattern)
require.NoError(t, err, "failed to list inner container files")
innerFiles := strings.Split(strings.TrimSpace(out), "\n")

// Check that the expected files exist in the inner container.
missingFiles := make([]string, 0)
for _, expectedFile := range expected {
if !slices.Contains(innerFiles, expectedFile) {
missingFiles = append(missingFiles, expectedFile)
}
}
require.Empty(t, missingFiles, "missing files in inner container: %s", strings.Join(missingFiles, ", "))
}

// assertInnerNvidiaSMI checks that nvidia-smi runs successfully in the inner
// container.
func assertInnerNvidiaSMI(ctx context.Context, t *testing.T, containerID string) {
t.Helper()
// Assert that we can run nvidia-smi in the inner container.
out, err := execContainerCmd(ctx, t, containerID, "docker", "exec", "workspace_cvm", "nvidia-smi")
require.NoError(t, err, "failed to run nvidia-smi in the inner container")
require.Contains(t, out, "NVIDIA-SMI", "nvidia-smi output does not contain NVIDIA-SMI")
}

// startEnvboxCmd starts the envbox container with the given arguments.
// Ideally we would use ory/dockertest for this, but it doesn't support
// specifying the runtime. We have alternatively used the docker client library,
Expand Down
3 changes: 3 additions & 0 deletions integration/integrationtest/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
UbuntuImage = "gcr.io/coder-dev-1/sreya/ubuntu-coder"
// Redhat UBI9 image as of 2025-03-05
RedhatImage = "registry.access.redhat.com/ubi9/ubi:9.5"
// CUDASampleImage is a CUDA sample image from NVIDIA's container registry.
// It contains a binary /tmp/vectorAdd which can be run to test the CUDA setup.
CUDASampleImage = "nvcr.io/nvidia/k8s/cuda-sample:vectoradd-cuda10.2"

// RegistryImage is used to assert that we add certs
// correctly to the docker daemon when pulling an image
Expand Down
92 changes: 88 additions & 4 deletions xunix/gpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"regexp"
"slices"
"sort"
"strings"

Expand All @@ -17,9 +18,9 @@ import (
)

var (
gpuMountRegex = regexp.MustCompile("(?i)(nvidia|vulkan|cuda)")
gpuExtraRegex = regexp.MustCompile("(?i)(libgl|nvidia|vulkan|cuda)")
gpuEnvRegex = regexp.MustCompile("(?i)nvidia")
gpuMountRegex = regexp.MustCompile(`(?i)(nvidia|vulkan|cuda)`)
gpuExtraRegex = regexp.MustCompile(`(?i)(libgl(e|sx|\.)|nvidia|vulkan|cuda)`)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: modified this regex to hopefully match the right things and not the wrong things

gpuEnvRegex = regexp.MustCompile(`(?i)nvidia`)
sharedObjectRegex = regexp.MustCompile(`\.so(\.[0-9\.]+)?$`)
)

Expand All @@ -39,6 +40,7 @@ func GPUEnvs(ctx context.Context) []string {

func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []mount.MountPoint, error) {
var (
afs = GetFS(ctx)
mounter = Mounter(ctx)
devices = []Device{}
binds = []mount.MountPoint{}
Expand All @@ -64,6 +66,22 @@ func GPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]Device, []m

// If it's not in /dev treat it as a bind mount.
binds = append(binds, m)
// We also want to find any symlinks that point to the target.
// This is important for the nvidia driver as it mounts the driver
// files with the driver version appended to the end, and creates
// symlinks that point to the actual files.
links, err := SameDirSymlinks(afs, m.Path)
if err != nil {
log.Error(ctx, "find symlinks", slog.F("path", m.Path), slog.Error(err))
} else {
for _, link := range links {
log.Debug(ctx, "found symlink", slog.F("link", link), slog.F("target", m.Path))
binds = append(binds, mount.MountPoint{
Path: link,
Opts: []string{"ro"},
})
}
}
}
}

Expand Down Expand Up @@ -104,7 +122,11 @@ func usrLibGPUs(ctx context.Context, log slog.Logger, usrLibDir string) ([]mount
return nil
}

if !sharedObjectRegex.MatchString(path) || !gpuExtraRegex.MatchString(path) {
if !gpuExtraRegex.MatchString(path) {
return nil
}

if !sharedObjectRegex.MatchString(path) {
Comment on lines +125 to +129
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: this makes the control flow a little easier to read; I accidentally removed this but it still performs an important task

return nil
}

Expand Down Expand Up @@ -176,6 +198,68 @@ func recursiveSymlinks(afs FS, mountpoint string, path string) ([]string, error)
return paths, nil
}

// SameDirSymlinks returns all links in the same directory as `target` that
// point to target, either indirectly or directly. Only symlinks in the same
// directory as `target` are considered.
func SameDirSymlinks(afs FS, target string) ([]string, error) {
var (
found = make([]string, 0)
maxIterations = 10 // arbitrary upper limit to prevent infinite loops
)
for range maxIterations {
foundThisTime := false
fis, err := afero.ReadDir(afs, filepath.Dir(target))
if err != nil {
return nil, xerrors.Errorf("read dir %q: %w", filepath.Dir(target), err)
}
for _, fi := range fis {
// Ignore the target itself.
if fi.Name() == filepath.Base(target) {
continue
}
// Ignore non-symlinks.
if fi.Mode()&os.ModeSymlink == 0 {
continue
}
// Get the target of the symlink.
link, err := afs.Readlink(filepath.Join(filepath.Dir(target), fi.Name()))
if err != nil {
return nil, xerrors.Errorf("readlink %q: %w", fi.Name(), err)
}
// Make the link absolute.
if !filepath.IsAbs(link) {
link = filepath.Join(filepath.Dir(target), link)
}
// Ignore symlinks that point outside of target's directory.
if filepath.Dir(link) != filepath.Dir(target) {
continue
}

// Check if the symlink points to to the target, or if it points
// to one of the symlinks we've already found.
if link != target {
if !slices.Contains(found, link) {
continue
}
}

// Have we already seen this target?
fullPath := filepath.Join(filepath.Dir(target), fi.Name())
if slices.Contains(found, fullPath) {
continue
}

found = append(found, filepath.Join(filepath.Dir(target), fi.Name()))
foundThisTime = true
}
// If we didn't find any symlinks this time, we're done.
if !foundThisTime {
break
}
}
return found, nil
}

// TryUnmountProcGPUDrivers unmounts any GPU-related mounts under /proc as it causes
// issues when creating any container in some cases. Errors encountered while
// unmounting are treated as non-fatal.
Expand Down
Loading
Loading