diff --git a/pkg/csi-common/driver.go b/pkg/csi-common/driver.go index d70c58cafde..9046b3483b4 100644 --- a/pkg/csi-common/driver.go +++ b/pkg/csi-common/driver.go @@ -94,7 +94,7 @@ func (d *CSIDriver) AddControllerServiceCapabilities(cl []csi.ControllerServiceC csc = append(csc, NewControllerServiceCapability(c)) } - d.Cap = csc + d.Cap = append(d.Cap, csc...) } func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RPC_Type) { @@ -103,7 +103,7 @@ func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RP klog.V(2).Infof("Enabling node service capability: %v", n.String()) nsc = append(nsc, NewNodeServiceCapability(n)) } - d.NSCap = nsc + d.NSCap = append(d.NSCap, nsc...) } func (d *CSIDriver) AddVolumeCapabilityAccessModes(vc []csi.VolumeCapability_AccessMode_Mode) []*csi.VolumeCapability_AccessMode { diff --git a/pkg/smb/controllerserver.go b/pkg/smb/controllerserver.go index b16d0b88eaa..ead2c14eeea 100644 --- a/pkg/smb/controllerserver.go +++ b/pkg/smb/controllerserver.go @@ -18,6 +18,8 @@ package smb import ( "context" + "crypto/sha256" + "encoding/hex" "fmt" "io/fs" "os" @@ -85,6 +87,17 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } secrets := req.GetSecrets() + username := strings.TrimSpace(secrets["username"]) + password := strings.TrimSpace(secrets["password"]) + if username != "" || password != "" { + hashKey := fmt.Sprintf("%s|%s", username, password) + hash := sha256.Sum256([]byte(hashKey)) + hashStr := hex.EncodeToString(hash[:8]) + smbVol.id = fmt.Sprintf("%s#cred=%s", getVolumeIDFromSmbVol(smbVol), hashStr) + } else { + smbVol.id = getVolumeIDFromSmbVol(smbVol) + } + createSubDir := len(secrets) > 0 if len(smbVol.uuid) > 0 { klog.V(2).Infof("create subdirectory(%s) if not exists", smbVol.subDir) diff --git a/pkg/smb/controllerserver_test.go b/pkg/smb/controllerserver_test.go index 64af250ef1a..17741470282 100644 --- a/pkg/smb/controllerserver_test.go +++ b/pkg/smb/controllerserver_test.go @@ -23,6 +23,7 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "testing" "github.com/container-storage-interface/spec/lib/go/csi" @@ -203,7 +204,11 @@ func TestCreateVolume(t *testing.T) { if !test.expectErr && err != nil { t.Errorf("test %q failed: %v", test.name, err) } - if !reflect.DeepEqual(resp, test.resp) { + if !test.expectErr && test.name == "valid defaults" { + if resp.Volume == nil || !strings.HasPrefix(resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=") { + t.Errorf("test %q failed: got volume ID %q, expected it to start with prefix %q", test.name, resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=") + } + } else if !reflect.DeepEqual(resp, test.resp) { t.Errorf("test %q failed: got resp %+v, expected %+v", test.name, resp, test.resp) } if !test.expectErr { diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index 0528003ea01..56e9f78b85f 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -40,7 +40,6 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) -// NodePublishVolume mount the volume from staging to target path func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) { volCap := req.GetVolumeCapability() if volCap == nil { @@ -51,8 +50,11 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request") } - target := req.GetTargetPath() - if len(target) == 0 { + // Strip cred hash suffix if present + cleanID := strings.SplitN(volumeID, "#cred=", 2)[0] + + targetPath := req.GetTargetPath() + if len(targetPath) == 0 { return nil, status.Error(codes.InvalidArgument, "Target path not provided") } @@ -60,18 +62,19 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu if context != nil && strings.EqualFold(context[ephemeralField], trueValue) { // ephemeral volume util.SetKeyValueInMap(context, secretNamespaceField, context[podNamespaceField]) - klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, target) + klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, targetPath) _, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{ - StagingTargetPath: target, + StagingTargetPath: targetPath, VolumeContext: context, VolumeCapability: volCap, - VolumeId: volumeID, + VolumeId: cleanID, }) return &csi.NodePublishVolumeResponse{}, err } - source := req.GetStagingTargetPath() - if len(source) == 0 { + // Get staging path + stagingPath := req.GetStagingTargetPath() + if len(stagingPath) == 0 { return nil, status.Error(codes.InvalidArgument, "Staging target not provided") } @@ -80,31 +83,31 @@ func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolu mountOptions = append(mountOptions, "ro") } - mnt, err := d.ensureMountPoint(target) + mnt, err := d.ensureMountPoint(targetPath) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", target, err) + return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", targetPath, err) } if mnt { - klog.V(2).Infof("NodePublishVolume: %s is already mounted", target) + klog.V(2).Infof("NodePublishVolume: %s is already mounted", targetPath) return &csi.NodePublishVolumeResponse{}, nil } - if err = preparePublishPath(target, d.mounter); err != nil { - return nil, fmt.Errorf("prepare publish failed for %s with error: %v", target, err) + if err = preparePublishPath(targetPath, d.mounter); err != nil { + return nil, fmt.Errorf("prepare publish failed for %s with error: %v", targetPath, err) } - klog.V(2).Infof("NodePublishVolume: mounting %s at %s with mountOptions: %v volumeID(%s)", source, target, mountOptions, volumeID) - if err := d.mounter.Mount(source, target, "", mountOptions); err != nil { - if removeErr := os.Remove(target); removeErr != nil { - return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr) + klog.V(2).Infof("NodePublishVolume: bind mounting %s to %s with options: %v", stagingPath, targetPath, mountOptions) + if err := d.mounter.Mount(stagingPath, targetPath, "", mountOptions); err != nil { + if removeErr := os.Remove(targetPath); removeErr != nil { + return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", targetPath, removeErr) } - return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", source, target, err) + return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", stagingPath, targetPath, err) } - klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", source, target, volumeID) + + klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", stagingPath, targetPath, volumeID) return &csi.NodePublishVolumeResponse{}, nil } -// NodeUnpublishVolume unmount the volume from the target path func (d *Driver) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) { volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -115,12 +118,28 @@ func (d *Driver) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpublishVo return nil, status.Error(codes.InvalidArgument, "Target path missing in request") } - klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath) - err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/) - if err != nil { + klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s from %s", volumeID, targetPath) + + notMnt, err := d.mounter.IsLikelyNotMountPoint(targetPath) + if err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", targetPath, err) + } + if notMnt { + klog.V(2).Infof("NodeUnpublishVolume: target %s is already unmounted", targetPath) + if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove target path %q: %v", targetPath, err) + } + return &csi.NodeUnpublishVolumeResponse{}, nil + } + + if err := d.mounter.Unmount(targetPath); err != nil { return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err) } - klog.V(2).Infof("NodeUnpublishVolume: unmount volume %s on %s successfully", volumeID, targetPath) + if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove target path %q after unmount: %v", targetPath, err) + } + + klog.V(2).Infof("NodeUnpublishVolume: successfully unmounted and removed %s for volume %s", targetPath, volumeID) return &csi.NodeUnpublishVolumeResponse{}, nil } @@ -142,8 +161,8 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } context := req.GetVolumeContext() - mountFlags := req.GetVolumeCapability().GetMount().GetMountFlags() - volumeMountGroup := req.GetVolumeCapability().GetMount().GetVolumeMountGroup() + mountFlags := volumeCapability.GetMount().GetMountFlags() + volumeMountGroup := volumeCapability.GetMount().GetVolumeMountGroup() secrets := req.GetSecrets() gidPresent := checkGidPresentInMountFlags(mountFlags) @@ -199,7 +218,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe mountFlags = strings.Split(ephemeralVolMountOptions, ",") } - // in guest login, username and password options are not needed requireUsernamePwdOption := !hasGuestMountOptions(mountFlags) if ephemeralVol && requireUsernamePwdOption { klog.V(2).Infof("NodeStageVolume: getting username and password from secret %s in namespace %s", secretName, secretNamespace) @@ -264,7 +282,6 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe if subDir != "" { // replace pv/pvc name namespace metadata in subDir subDir = replaceWithMap(subDir, subDirReplaceMap) - source = strings.TrimRight(source, "/") source = fmt.Sprintf("%s/%s", source, subDir) } @@ -281,7 +298,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe return &csi.NodeStageVolumeResponse{}, nil } -// NodeUnstageVolume unmount the volume from the staging path +// NodeUnstageVolume unmounts the volume from the staging path func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) { volumeID := req.GetVolumeId() if len(volumeID) == 0 { @@ -298,16 +315,36 @@ func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolume } defer d.volumeLocks.Release(lockKey) - klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint on %s with volume %s", stagingTargetPath, volumeID) - if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, volumeID); err != nil { - return nil, status.Errorf(codes.Internal, "failed to unmount staging target %q: %v", stagingTargetPath, err) + inUse, err := HasMountReferences(stagingTargetPath) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to check mount references: %v", err) + } + if inUse { + klog.V(2).Infof("NodeUnstageVolume: staging path %s is still in use by other mounts", stagingTargetPath) + return &csi.NodeUnstageVolumeResponse{}, nil } - if err := deleteKerberosCache(d.krb5CacheDirectory, volumeID); err != nil { - return nil, status.Errorf(codes.Internal, "failed to delete kerberos cache: %v", err) + notMnt, err := d.mounter.IsLikelyNotMountPoint(stagingTargetPath) + if err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", stagingTargetPath, err) + } + if notMnt { + klog.V(2).Infof("NodeUnstageVolume: staging path %s is already unmounted", stagingTargetPath) + if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove staging path %q: %v", stagingTargetPath, err) + } + return &csi.NodeUnstageVolumeResponse{}, nil + } + + klog.V(2).Infof("NodeUnstageVolume: unmounting %s for volume %s", stagingTargetPath, volumeID) + if err := d.mounter.Unmount(stagingTargetPath); err != nil { + return nil, status.Errorf(codes.Internal, "failed to unmount staging path %q: %v", stagingTargetPath, err) + } + if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) { + return nil, status.Errorf(codes.Internal, "failed to remove staging path %q after unmount: %v", stagingTargetPath, err) } - klog.V(2).Infof("NodeUnstageVolume: unmount volume %s on %s successfully", volumeID, stagingTargetPath) + klog.V(2).Infof("NodeUnstageVolume: successfully unmounted and cleaned up %s for volume %s", stagingTargetPath, volumeID) return &csi.NodeUnstageVolumeResponse{}, nil } diff --git a/pkg/smb/smb_common_darwin.go b/pkg/smb/smb_common_darwin.go index f4b4fcc3270..c24408da608 100644 --- a/pkg/smb/smb_common_darwin.go +++ b/pkg/smb/smb_common_darwin.go @@ -48,3 +48,8 @@ func prepareStagePath(path string, m *mount.SafeFormatAndMount) error { func Mkdir(m *mount.SafeFormatAndMount, name string, perm os.FileMode) error { return os.Mkdir(name, perm) } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + // Stubbed for Windows/macOS — cannot inspect bind mounts + return false, nil +} diff --git a/pkg/smb/smb_common_linux.go b/pkg/smb/smb_common_linux.go index c6b28fe394d..098ef216369 100644 --- a/pkg/smb/smb_common_linux.go +++ b/pkg/smb/smb_common_linux.go @@ -20,7 +20,10 @@ limitations under the License. package smb import ( + "bufio" + "fmt" "os" + "strings" mount "k8s.io/mount-utils" ) @@ -48,3 +51,23 @@ func prepareStagePath(_ string, _ *mount.SafeFormatAndMount) error { func Mkdir(_ *mount.SafeFormatAndMount, name string, perm os.FileMode) error { return os.Mkdir(name, perm) } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + f, err := os.Open("/proc/mounts") + if err != nil { + return false, fmt.Errorf("failed to open /proc/mounts: %v", err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) >= 2 { + mountPoint := fields[1] + if strings.HasPrefix(mountPoint, stagingTargetPath) && mountPoint != stagingTargetPath { + return true, nil + } + } + } + return false, nil +} diff --git a/pkg/smb/smb_common_windows.go b/pkg/smb/smb_common_windows.go index 61a86eeeff0..507bc00d92d 100644 --- a/pkg/smb/smb_common_windows.go +++ b/pkg/smb/smb_common_windows.go @@ -87,3 +87,8 @@ func Mkdir(m *mount.SafeFormatAndMount, name string, perm os.FileMode) error { } return fmt.Errorf("could not cast to csi proxy class") } + +func HasMountReferences(stagingTargetPath string) (bool, error) { + // Stubbed for Windows/macOS — cannot inspect bind mounts + return false, nil +}