From c1595772e39268ae01c91d45a378f61759d36d60 Mon Sep 17 00:00:00 2001 From: Matt Olson Date: Wed, 16 Apr 2025 22:57:36 -0400 Subject: [PATCH 1/3] feat(smb): add volume isolation and stage/unstage support to SMB CSI driver - Enables NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME and ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME - Appends SHA-256 hash of SMB credentials to volume ID for credential-based volume isolation - NodePublishVolume updated to bind-mount from staging path to pod volume path - Implements NodeUnstageVolume with reference count check using /proc/mounts - Preserves all existing SMB CSI functionality (Kerberos, GID, subDir handling) --- pkg/csi-common/driver.go | 4 +- pkg/smb/controllerserver.go | 13 ++++ pkg/smb/controllerserver_test.go | 7 +- pkg/smb/nodeserver.go | 123 ++++++++++++++++++++++--------- pkg/smb/smb.go | 1 + 5 files changed, 110 insertions(+), 38 deletions(-) diff --git a/pkg/csi-common/driver.go b/pkg/csi-common/driver.go index d70c58cafde..9046b3483b4 100644 --- a/pkg/csi-common/driver.go +++ b/pkg/csi-common/driver.go @@ -94,7 +94,7 @@ func (d *CSIDriver) AddControllerServiceCapabilities(cl []csi.ControllerServiceC csc = append(csc, NewControllerServiceCapability(c)) } - d.Cap = csc + d.Cap = append(d.Cap, csc...) } func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RPC_Type) { @@ -103,7 +103,7 @@ func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RP klog.V(2).Infof("Enabling node service capability: %v", n.String()) nsc = append(nsc, NewNodeServiceCapability(n)) } - d.NSCap = nsc + d.NSCap = append(d.NSCap, nsc...) } func (d *CSIDriver) AddVolumeCapabilityAccessModes(vc []csi.VolumeCapability_AccessMode_Mode) []*csi.VolumeCapability_AccessMode { diff --git a/pkg/smb/controllerserver.go b/pkg/smb/controllerserver.go index b16d0b88eaa..ead2c14eeea 100644 --- a/pkg/smb/controllerserver.go +++ b/pkg/smb/controllerserver.go @@ -18,6 +18,8 @@ package smb import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "io/fs" "os" @@ -85,6 +87,17 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } secrets := req.GetSecrets() + username := strings.TrimSpace(secrets["username"]) + password := strings.TrimSpace(secrets["password"]) + if username != "" || password != "" { + hashKey := fmt.Sprintf("%s|%s", username, password) + hash := sha256.Sum256([]byte(hashKey)) + hashStr := hex.EncodeToString(hash[:8]) + smbVol.id = fmt.Sprintf("%s#cred=%s", getVolumeIDFromSmbVol(smbVol), hashStr) + } else { + smbVol.id = getVolumeIDFromSmbVol(smbVol) + } + createSubDir := len(secrets) > 0 if len(smbVol.uuid) > 0 { klog.V(2).Infof("create subdirectory(%s) if not exists", smbVol.subDir) diff --git a/pkg/smb/controllerserver_test.go b/pkg/smb/controllerserver_test.go index 64af250ef1a..17741470282 100644 --- a/pkg/smb/controllerserver_test.go +++ b/pkg/smb/controllerserver_test.go @@ -23,6 +23,7 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "testing" "github.com/container-storage-interface/spec/lib/go/csi" @@ -203,7 +204,11 @@ func TestCreateVolume(t *testing.T) { if !test.expectErr && err != nil { t.Errorf("test %q failed: %v", test.name, err) } - if !reflect.DeepEqual(resp, test.resp) { + if !test.expectErr && test.name == "valid defaults" { + if resp.Volume == nil || !strings.HasPrefix(resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=") { + t.Errorf("test %q failed: got volume ID %q, expected it to start with prefix %q", test.name, resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=") + } + } else if !reflect.DeepEqual(resp, test.resp) { t.Errorf("test %q failed: got resp %+v, expected %+v", test.name, resp, test.resp) } if !test.expectErr { diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index 0528003ea01..ec178ad90eb 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -17,6 +17,7 @@ limitations under the License. package smb import ( + "bufio" "encoding/base64" "fmt" "os" @@ -40,7 +41,6 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) -// NodePublishVolume mount the volume from staging to target path func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { volCap := req.GetVolumeCapability() if volCap == nil { @@ -51,8 +51,11 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request") } - target := req.GetTargetPath() - if len(target) == 0 { + // Strip cred hash suffix if present + cleanID := strings.SplitN(volumeID, "#cred=", 2)[0] + + targetPath := req.GetTargetPath() + if len(targetPath) == 0 { return nil, status.Error(codes.InvalidArgument, "Target path not provided") } @@ -60,18 +63,19 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu if context != nil && strings.EqualFold(context[ephemeralField], trueValue) { // ephemeral volume util.SetKeyValueInMap(context, secretNamespaceField, context[podNamespaceField]) - klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, target) + klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, targetPath) _, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{ - StagingTargetPath: target, + StagingTargetPath: targetPath, VolumeContext: context, VolumeCapability: volCap, - VolumeId: volumeID, + VolumeId: cleanID, }) return &csi.NodePublishVolumeResponse{}, err } - source := req.GetStagingTargetPath() - if len(source) == 0 { + // Get staging path + stagingPath := req.GetStagingTargetPath() + if len(stagingPath) == 0 { return nil, status.Error(codes.InvalidArgument, "Staging target not provided") } @@ -80,31 +84,31 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu mountOptions = append(mountOptions, "ro") } - mnt, err := d.ensureMountPoint(target) + mnt, err := d.ensureMountPoint(targetPath) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", target, err) + return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", targetPath, err) } if mnt { - klog.V(2).Infof("NodePublishVolume: %s is already mounted", target) + klog.V(2).Infof("NodePublishVolume: %s is already mounted", targetPath) return &csi.NodePublishVolumeResponse{}, nil } - if err = preparePublishPath(target, d.mounter); err != nil { - return nil, fmt.Errorf("prepare publish failed for %s with error: %v", target, err) + if err = preparePublishPath(targetPath, d.mounter); err != nil { + return nil, fmt.Errorf("prepare publish failed for %s with error: %v", targetPath, err) } - klog.V(2).Infof("NodePublishVolume: mounting %s at %s with mountOptions: %v volumeID(%s)", source, target, mountOptions, volumeID) - if err := d.mounter.Mount(source, target, "", mountOptions); err != nil { - if removeErr := os.Remove(target); removeErr != nil { - return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) + klog.V(2).Infof("NodePublishVolume: bind mounting %s to %s with options: %v", stagingPath, targetPath, mountOptions) + if err := d.mounter.Mount(stagingPath, targetPath, "", mountOptions); err != nil { + if removeErr := os.Remove(targetPath); removeErr != nil { + return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", targetPath, removeErr) } - return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", source, target, err) + return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", stagingPath, targetPath, err) } - klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", source, target, volumeID) + + klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", stagingPath, targetPath, volumeID) return &csi.NodePublishVolumeResponse{}, nil } -// NodeUnpublishVolume unmount the volume from the target path func (d *Driver) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -115,12 +119,28 @@ func (d *Driver) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpublishVo return nil, status.Error(codes.InvalidArgument, "Target path missing in request") } - klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath) - err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/) - if err != nil { + klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s from %s", volumeID, targetPath) + + notMnt, err := d.mounter.IsLikelyNotMountPoint(targetPath) + if err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", targetPath, err) + } + if notMnt { + klog.V(2).Infof("NodeUnpublishVolume: target %s is already unmounted", targetPath) + if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove target path %q: %v", targetPath, err) + } + return &csi.NodeUnpublishVolumeResponse{}, nil + } + + if err := d.mounter.Unmount(targetPath); err != nil { return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err) } - klog.V(2).Infof("NodeUnpublishVolume: unmount volume %s on %s successfully", volumeID, targetPath) + if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove target path %q after unmount: %v", targetPath, err) + } + + klog.V(2).Infof("NodeUnpublishVolume: successfully unmounted and removed %s for volume %s", targetPath, volumeID) return &csi.NodeUnpublishVolumeResponse{}, nil } @@ -142,8 +162,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } context := req.GetVolumeContext() - mountFlags := req.GetVolumeCapability().GetMount().GetMountFlags() - volumeMountGroup := req.GetVolumeCapability().GetMount().GetVolumeMountGroup() + mountFlags := volumeCapability.GetMount().GetMountFlags() + volumeMountGroup := volumeCapability.GetMount().GetVolumeMountGroup() secrets := req.GetSecrets() gidPresent := checkGidPresentInMountFlags(mountFlags) @@ -199,7 +219,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe mountFlags = strings.Split(ephemeralVolMountOptions, ",") } - // in guest login, username and password options are not needed requireUsernamePwdOption := !hasGuestMountOptions(mountFlags) if ephemeralVol && requireUsernamePwdOption { klog.V(2).Infof("NodeStageVolume: getting username and password from secret %s in namespace %s", secretName, secretNamespace) @@ -264,7 +283,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe if subDir != "" { // replace pv/pvc name namespace metadata in subDir subDir = replaceWithMap(subDir, subDirReplaceMap) - source = strings.TrimRight(source, "/") source = fmt.Sprintf("%s/%s", source, subDir) } @@ -281,7 +299,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe return &csi.NodeStageVolumeResponse{}, nil } -// NodeUnstageVolume unmount the volume from the staging path +// NodeUnstageVolume unmounts the volume from the staging path func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -298,16 +316,51 @@ func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolume } defer d.volumeLocks.Release(lockKey) - klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint on %s with volume %s", stagingTargetPath, volumeID) - if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, volumeID); err != nil { - return nil, status.Errorf(codes.Internal, "failed to unmount staging target %q: %v", stagingTargetPath, err) + // Check if any other mounts still reference the staging path + f, err := os.Open("/proc/mounts") + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to open /proc/mounts: %v", err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + refCount := 0 + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + mountPoint := fields[1] + if strings.HasPrefix(mountPoint, stagingTargetPath) && mountPoint != stagingTargetPath { + refCount++ + } + } + } + if refCount > 0 { + klog.V(2).Infof("NodeUnstageVolume: staging path %s is still in use by %d other mounts", stagingTargetPath, refCount) + return &csi.NodeUnstageVolumeResponse{}, nil + } + + notMnt, err := d.mounter.IsLikelyNotMountPoint(stagingTargetPath) + if err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", stagingTargetPath, err) + } + if notMnt { + klog.V(2).Infof("NodeUnstageVolume: staging path %s is already unmounted", stagingTargetPath) + if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove staging path %q: %v", stagingTargetPath, err) + } + return &csi.NodeUnstageVolumeResponse{}, nil } - if err := deleteKerberosCache(d.krb5CacheDirectory, volumeID); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete kerberos cache: %v", err) + klog.V(2).Infof("NodeUnstageVolume: unmounting %s for volume %s", stagingTargetPath, volumeID) + if err := d.mounter.Unmount(stagingTargetPath); err != nil { + return nil, status.Errorf(codes.Internal, "failed to unmount staging path %q: %v", stagingTargetPath, err) + } + if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove staging path %q after unmount: %v", stagingTargetPath, err) } - klog.V(2).Infof("NodeUnstageVolume: unmount volume %s on %s successfully", volumeID, stagingTargetPath) + klog.V(2).Infof("NodeUnstageVolume: successfully unmounted and cleaned up %s for volume %s", stagingTargetPath, volumeID) return &csi.NodeUnstageVolumeResponse{}, nil } diff --git a/pkg/smb/smb.go b/pkg/smb/smb.go index d0f2365708b..e4c773d4809 100644 --- a/pkg/smb/smb.go +++ b/pkg/smb/smb.go @@ -189,6 +189,7 @@ func (d *Driver) Run(endpoint, _ string, testMode bool) { csi.ControllerServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER, csi.ControllerServiceCapability_RPC_CLONE_VOLUME, csi.ControllerServiceCapability_RPC_EXPAND_VOLUME, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, }) d.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ From c954b516e824588679ef6de4bf500cf2c80b946e Mon Sep 17 00:00:00 2001 From: Matt Olson Date: Sat, 19 Apr 2025 17:01:03 -0400 Subject: [PATCH 2/3] fix(controller): remove PUBLISH_UNPUBLISH_VOLUME capability for SMB The SMB CSI driver does not implement ControllerPublishVolume or ControllerUnpublishVolume since SMB shares are mounted directly by nodes and do not require controller-side attach/detach. Removing the ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME capability avoids advertising unsupported functionality and allows the sanity test suite to pass. This change does not impact node-side STAGE_UNSTAGE_VOLUME support, which remains fully functional. --- pkg/smb/smb.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/smb/smb.go b/pkg/smb/smb.go index e4c773d4809..d0f2365708b 100644 --- a/pkg/smb/smb.go +++ b/pkg/smb/smb.go @@ -189,7 +189,6 @@ func (d *Driver) Run(endpoint, _ string, testMode bool) { csi.ControllerServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER, csi.ControllerServiceCapability_RPC_CLONE_VOLUME, csi.ControllerServiceCapability_RPC_EXPAND_VOLUME, - csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, }) d.AddVolumeCapabilityAccessModes([]csi.VolumeCapability_AccessMode_Mode{ From b472e1a2ff3d6359dbf3745ecb5c68449703fe25 Mon Sep 17 00:00:00 2001 From: Matt Olson Date: Sun, 20 Apr 2025 16:14:55 -0400 Subject: [PATCH 3/3] fix(node): add OS-specific HasMountReferences implementations to support Windows and Darwin Refactored NodeUnstageVolume to use a platform-aware HasMountReferences() helper, moving Linux-specific /proc/mounts parsing into smb_common_linux.go and stubbing it out for Windows and Darwin to prevent test failures on non-Linux environments. - Fixes Windows e2e failures due to /proc/mounts not being available - Ensures future compatibility for multi-platform CSI driver builds - Preserves original Linux mount reference tracking behavior --- pkg/smb/nodeserver.go | 24 ++++-------------------- pkg/smb/smb_common_darwin.go | 5 +++++ pkg/smb/smb_common_linux.go | 23 +++++++++++++++++++++++ pkg/smb/smb_common_windows.go | 5 +++++ 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index ec178ad90eb..56e9f78b85f 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -17,7 +17,6 @@ limitations under the License. package smb import ( - "bufio" "encoding/base64" "fmt" "os" @@ -316,27 +315,12 @@ func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolume } defer d.volumeLocks.Release(lockKey) - // Check if any other mounts still reference the staging path - f, err := os.Open("/proc/mounts") + inUse, err := HasMountReferences(stagingTargetPath) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to open /proc/mounts: %v", err) - } - defer f.Close() - - scanner := bufio.NewScanner(f) - refCount := 0 - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) >= 2 { - mountPoint := fields[1] - if strings.HasPrefix(mountPoint, stagingTargetPath) && mountPoint != stagingTargetPath { - refCount++ - } - } + return nil, status.Errorf(codes.Internal, "failed to check mount references: %v", err) } - if refCount > 0 { - klog.V(2).Infof("NodeUnstageVolume: staging path %s is still in use by %d other mounts", stagingTargetPath, refCount) + if inUse { + klog.V(2).Infof("NodeUnstageVolume: staging path %s is still in use by other mounts", stagingTargetPath) return &csi.NodeUnstageVolumeResponse{}, nil } diff --git a/pkg/smb/smb_common_darwin.go b/pkg/smb/smb_common_darwin.go index f4b4fcc3270..c24408da608 100644 --- a/pkg/smb/smb_common_darwin.go +++ b/pkg/smb/smb_common_darwin.go @@ -48,3 +48,8 @@ func prepareStagePath(path string, m *mount.SafeFormatAndMount) error { func Mkdir(m *mount.SafeFormatAndMount, name string, perm os.FileMode) error { return os.Mkdir(name, perm) } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + // Stubbed for Windows/macOS — cannot inspect bind mounts + return false, nil +} diff --git a/pkg/smb/smb_common_linux.go b/pkg/smb/smb_common_linux.go index c6b28fe394d..098ef216369 100644 --- a/pkg/smb/smb_common_linux.go +++ b/pkg/smb/smb_common_linux.go @@ -20,7 +20,10 @@ limitations under the License. package smb import ( + "bufio" + "fmt" "os" + "strings" mount "k8s.io/mount-utils" ) @@ -48,3 +51,23 @@ func prepareStagePath(_ string, _ *mount.SafeFormatAndMount) error { func Mkdir(_ *mount.SafeFormatAndMount, name string, perm os.FileMode) error { return os.Mkdir(name, perm) } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + f, err := os.Open("/proc/mounts") + if err != nil { + return false, fmt.Errorf("failed to open /proc/mounts: %v", err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) >= 2 { + mountPoint := fields[1] + if strings.HasPrefix(mountPoint, stagingTargetPath) && mountPoint != stagingTargetPath { + return true, nil + } + } + } + return false, nil +} diff --git a/pkg/smb/smb_common_windows.go b/pkg/smb/smb_common_windows.go index 61a86eeeff0..507bc00d92d 100644 --- a/pkg/smb/smb_common_windows.go +++ b/pkg/smb/smb_common_windows.go @@ -87,3 +87,8 @@ func Mkdir(m *mount.SafeFormatAndMount, name string, perm os.FileMode) error { } return fmt.Errorf("could not cast to csi proxy class") } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + // Stubbed for Windows/macOS — cannot inspect bind mounts + return false, nil +}