diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml index 11602bd078..161326f6d1 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml @@ -2133,6 +2133,19 @@ The following sample tries to open a connection to an invalid database to simula Returns 0 if the connection is inactive on the client side. + + + Gets or sets the instance for customizing the SSPI context. If not set, the default for the platform will be used. + + + An instance. + + + + The SspiContextProvider is a part of the connection pool key. Care should be taken when using this property to ensure the implementation returns a stable identity per resource. + + + Indicates the state of the during the most recent network operation performed on the connection. diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml new file mode 100644 index 0000000000..4623d86052 --- /dev/null +++ b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml @@ -0,0 +1,28 @@ + + + + + Provides parameters used during SSPI authentication. + + + Creates an instance of the SspiAuthenticationParameters. + The name of the server. + The resource (often the server service principal name). + + + Gets the resource (often the server service principal name). + + + Gets the server name. + + + Gets or sets the user id if available. + + + Gets or sets the database name if available. + + + Gets or sets the password if available. + + + diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml new file mode 100644 index 0000000000..2ea58cb80b --- /dev/null +++ b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml @@ -0,0 +1,20 @@ + + + + + Provides the ability to customize SSPI context generation. + + + Creates an instance of the SSPIContextProvider. + + + Generates an SSPI outgoing blob given the incoming blob. + Incoming blob + Outgoing blob + Gets the authentication parameters associated with this connection. + + true if the context was generated, otherwise false. + + + + diff --git a/src/Microsoft.Data.SqlClient.sln b/src/Microsoft.Data.SqlClient.sln index e4d29d999c..bf0cbe8cc2 100644 --- a/src/Microsoft.Data.SqlClient.sln +++ b/src/Microsoft.Data.SqlClient.sln @@ -149,6 +149,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient", ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml ..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml + ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml + ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient.DataClassification", "Microsoft.Data.SqlClient.DataClassification", "{5D1F0032-7B0D-4FB6-A969-FCFB25C9EA1D}" diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index 55b68fd9b3..9b00807041 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -931,6 +931,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } } + /// + public SspiContextProvider SspiContextProvider { get { throw null; } set { } } /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public override string Database { get { throw null; } } @@ -1978,6 +1980,37 @@ public sealed class SqlConfigurableRetryFactory /// public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; } } + /// + public abstract class SspiContextProvider + { + /// + protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + } + /// + public sealed class SspiAuthenticationParameters + { + /// + public SspiAuthenticationParameters(string serverName, string resource) + { + ServerName = serverName; + Resource = resource; + } + + /// + public string Resource { get; } + + /// + public string ServerName { get; } + + /// + public string UserId { get; set; } + + /// + public string DatabaseName { get; set; } + + /// + public string Password { get; set; } + } } namespace Microsoft.Data.SqlClient.Diagnostics { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 11e7bd71ad..56b7de3f34 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -91,6 +91,7 @@ private static readonly Dictionary private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -648,7 +649,7 @@ public override string ConnectionString CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions); } } - ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider)); _connectionString = value; // Change _connectionString value only after value is validated CacheConnectionStringProperties(); } @@ -708,7 +709,7 @@ public string AccessToken } // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProvider: null)); _accessToken = value; } } @@ -731,11 +732,22 @@ public Func + public SspiContextProvider SspiContextProvider + { + get { return _sspiContextProvider; } + set + { + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value)); + _sspiContextProvider = value; + } + } + /// [ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)] [ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)] @@ -1033,7 +1045,7 @@ public SqlCredential Credential _credential = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, sspiContextProvider: null)); } } @@ -2263,7 +2275,7 @@ public static void ChangePassword(string connectionString, string newPassword) throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); if (connectionOptions.IntegratedSecurity) @@ -2312,7 +2324,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key); @@ -2350,7 +2362,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString { con?.Dispose(); } - SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null); + SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null); SqlConnectionFactory.SingletonInstance.ClearPool(key); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index c75eb35a0e..8db440a5ac 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -135,6 +135,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa SqlFedAuthToken _fedAuthToken = null; internal byte[] _accessTokenInBytes; internal readonly Func> _accessTokenCallback; + internal readonly SspiContextProvider _sspiContextProvider; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; @@ -460,8 +461,8 @@ internal SqlInternalConnectionTds( bool applyTransientFaultHandling = false, string accessToken = null, IDbConnectionPool pool = null, - Func> accessTokenCallback = null) : base(connectionOptions) + Func> accessTokenCallback = null, + SspiContextProvider sspiContextProvider = null) : base(connectionOptions) { #if DEBUG if (reconnectSessionData != null) @@ -514,6 +515,7 @@ internal SqlInternalConnectionTds( } _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 79228a3449..b344db128d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -413,7 +413,7 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSspiContextProvider(); + _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider(); SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication"); } diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index d7f280ca33..bd1e17c2d9 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -810,6 +810,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } } + /// + public SspiContextProvider SspiContextProvider { get { throw null; } set { } } /// [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public override string Database { get { throw null; } } @@ -1959,6 +1961,37 @@ public sealed class SqlConfigurableRetryFactory /// public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; } } + /// + public abstract class SspiContextProvider + { + /// + protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + } + /// + public sealed class SspiAuthenticationParameters + { + /// + public SspiAuthenticationParameters(string serverName, string resource) + { + ServerName = serverName; + Resource = resource; + } + + /// + public string Resource { get; } + + /// + public string ServerName { get; } + + /// + public string UserId { get; set; } + + /// + public string DatabaseName { get; set; } + + /// + public string Password { get; set; } + } } namespace Microsoft.Data.SqlClient.Server { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index c2ea5c3981..651f1f1c93 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -89,6 +89,8 @@ private static readonly Dictionary private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; + internal bool HasColumnEncryptionKeyStoreProvidersRegistered => _customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0; @@ -575,6 +577,17 @@ internal int ConnectRetryInterval get => ((SqlConnectionString)ConnectionOptions).ConnectRetryInterval; } + /// + public SspiContextProvider SspiContextProvider + { + get { return _sspiContextProvider; } + set + { + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value)); + _sspiContextProvider = value; + } + } + /// [DefaultValue("")] #pragma warning disable 618 // ignore obsolete warning about RecommendedAsConfigurable to use SettingsBindableAttribute @@ -643,7 +656,7 @@ public override string ConnectionString CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions); } } - ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback)); + ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider)); _connectionString = value; // Change _connectionString value only after value is validated CacheConnectionStringProperties(); } @@ -703,7 +716,7 @@ public string AccessToken _accessToken = value; // Need to call ConnectionString_Set to do proper pool group check - ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null)); + ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null, sspiContextProvider: null)); } } @@ -725,7 +738,7 @@ public Func> _accessTokenCallback; + internal readonly SspiContextProvider _sspiContextProvider; private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper; @@ -472,8 +473,8 @@ internal SqlInternalConnectionTds( bool applyTransientFaultHandling = false, string accessToken = null, IDbConnectionPool pool = null, - Func> accessTokenCallback = null) + Func> accessTokenCallback = null, + SspiContextProvider sspiContextProvider = null) : base(connectionOptions) { #if DEBUG @@ -525,6 +526,7 @@ internal SqlInternalConnectionTds( } _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; _activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper(); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index f871da2c6c..ce32b6d873 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -411,7 +411,7 @@ internal void Connect(ServerInfo serverInfo, // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - _authenticationProvider = _physicalStateObj.CreateSspiContextProvider(); + _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider(); if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs index 207c0a8e1a..31da3521df 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs @@ -18,10 +18,12 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey private readonly SqlCredential _credential; private readonly string _accessToken; private Func> _accessTokenCallback; + private SspiContextProvider _sspiContextProvider; internal SqlCredential Credential => _credential; internal string AccessToken => _accessToken; internal Func> AccessTokenCallback => _accessTokenCallback; + internal SspiContextProvider SspiContextProvider => _sspiContextProvider; internal override string ConnectionString { @@ -33,12 +35,18 @@ internal override string ConnectionString } } - internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func> accessTokenCallback) : base(connectionString) + internal SqlConnectionPoolKey( + string connectionString, + SqlCredential credential, + string accessToken, + Func> accessTokenCallback, + SspiContextProvider sspiContextProvider) : base(connectionString) { Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time."); _credential = credential; _accessToken = accessToken; _accessTokenCallback = accessTokenCallback; + _sspiContextProvider = sspiContextProvider; CalculateHashCode(); } @@ -47,6 +55,8 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key) _credential = key.Credential; _accessToken = key.AccessToken; _accessTokenCallback = key._accessTokenCallback; + _sspiContextProvider = key._sspiContextProvider; + CalculateHashCode(); } @@ -61,7 +71,8 @@ public override bool Equals(object obj) && _credential == key._credential && ConnectionString == key.ConnectionString && _accessTokenCallback == key._accessTokenCallback - && string.CompareOrdinal(_accessToken, key._accessToken) == 0); + && string.CompareOrdinal(_accessToken, key._accessToken) == 0 + && _sspiContextProvider == key._sspiContextProvider); } public override int GetHashCode() @@ -94,6 +105,11 @@ private void CalculateHashCode() _hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode(); } } + + if (_sspiContextProvider != null) + { + _hashValue = _hashValue * 17 + _sspiContextProvider.GetHashCode(); + } } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs index 5935b149c8..1cc4af3e9c 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs @@ -49,7 +49,7 @@ private void LoadSSPILibrary() } } - protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; @@ -62,7 +62,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo var sendLength = s_maxSSPILength; var outBuff = outgoingBlobWriter.GetSpan((int)sendLength); - if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource)) + if (SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource) != 0) { return false; } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs index 5dc52010b3..a255712238 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs @@ -12,7 +12,7 @@ internal sealed class NegotiateSspiContextProvider : SspiContextProvider { private NegotiateAuthentication? _negotiateAuth = null; - protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials; @@ -21,7 +21,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo // Log session id, status code and the actual SPN used in the negotiation SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSspiContextProvider), - nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName); + nameof(GenerateContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName); if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs index dce0858360..ad6c92853f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs @@ -2,22 +2,29 @@ namespace Microsoft.Data.SqlClient { - internal sealed class SspiAuthenticationParameters + /// + public sealed class SspiAuthenticationParameters { + /// public SspiAuthenticationParameters(string serverName, string resource) { ServerName = serverName; Resource = resource; } + /// public string Resource { get; } + /// public string ServerName { get; } + /// public string? UserId { get; set; } + /// public string? DatabaseName { get; set; } + /// public string? Password { get; set; } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs index ff83422f10..7b339f285d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs @@ -6,12 +6,18 @@ namespace Microsoft.Data.SqlClient { - internal abstract class SspiContextProvider + /// + public abstract class SspiContextProvider { private TdsParser _parser = null!; private ServerInfo _serverInfo = null!; private protected TdsParserStateObject _physicalStateObj = null!; + /// + protected SspiContextProvider() + { + } + internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser) { _parser = parser; @@ -25,13 +31,14 @@ private protected virtual void Initialize() { } - protected abstract bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); + /// + protected abstract bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn) { using var _ = TrySNIEventScope.Create(nameof(SspiContextProvider)); - if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn)) + if (!RunGenerateContext(receivedBuff, outgoingBlobWriter, serverSpn)) { // If we've hit here, the SSPI context provider implementation failed to generate the SSPI context. SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT); @@ -44,7 +51,7 @@ internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outg foreach (var serverSpn in serverSpns) { - if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn)) + if (RunGenerateContext(receivedBuff, outgoingBlobWriter, serverSpn)) { return; } @@ -54,7 +61,7 @@ internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outg SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT); } - private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, string serverSpn) + private bool RunGenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, string serverSpn) { var options = _parser.Connection.ConnectionOptions; var authParams = new SspiAuthenticationParameters(options.DataSource, serverSpn) @@ -66,9 +73,9 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff try { - SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), serverSpn); + SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateContext), serverSpn); - return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams); + return GenerateContext(incomingBlob, outgoingBlobWriter, authParams); } catch (Exception e) { @@ -77,7 +84,7 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff } } - protected void SSPIError(string error, string procedure) + private protected void SSPIError(string error, string procedure) { Debug.Assert(!string.IsNullOrEmpty(procedure), "TdsParser.SSPIError called with an empty or null procedure string"); Debug.Assert(!string.IsNullOrEmpty(error), "TdsParser.SSPIError called with an empty or null error string"); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs index ae4c7bc98d..00a86bdeac 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs @@ -103,7 +103,7 @@ protected override DbConnectionInternal CreateConnection( // This first connection is established to SqlExpress to get the instance name // of the UserInstance. SqlConnectionString sseopt = new SqlConnectionString(opt, opt.DataSource, userInstance: true, setEnlistValue: false); - sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling); + sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling, sspiContextProvider: key.SspiContextProvider); // NOTE: Retrieve here. This user instance name will be used below to connect to the Sql Express User Instance. instanceName = sseConnection.InstanceName; @@ -157,7 +157,8 @@ protected override DbConnectionInternal CreateConnection( applyTransientFaultHandling, key.AccessToken, pool, - key.AccessTokenCallback); + key.AccessTokenCallback, + key.SspiContextProvider); } protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs index e043b2253c..af650a408e 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Buffers; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -56,6 +58,39 @@ public static void IntegratedAuthenticationTest_ServerSPN() TryOpenConnectionWithIntegratedAuthentication(builder.ConnectionString); } + [ConditionalFact(nameof(IsIntegratedSecurityEnvironmentSet), nameof(AreConnectionStringsSetup))] + public static void CustomSspiContextGeneratorTest() + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + builder.IntegratedSecurity = true; + Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out int port, out string instanceName)); + // Build the SPN for the server we are connecting to + builder.ServerSPN = $"MSSQLSvc/{DataTestUtility.GetMachineFQDN(hostname)}"; + if (!string.IsNullOrWhiteSpace(instanceName)) + { + builder.ServerSPN += ":" + instanceName; + } + + using SqlConnection conn = new(builder.ConnectionString) + { + SspiContextProvider = new TestSspiContextProvider(), + }; + + try + { + conn.Open(); + + Assert.Fail("Expected to use custom SSPI context provider"); + } + catch (SspiTestException sspi) + { + Assert.Equal(sspi.AuthParams.ServerName, builder.DataSource); + Assert.Equal(sspi.AuthParams.DatabaseName, builder.InitialCatalog); + Assert.Equal(sspi.AuthParams.UserId, builder.UserID); + Assert.Equal(sspi.AuthParams.Password, builder.Password); + } + } + private static void TryOpenConnectionWithIntegratedAuthentication(string connectionString) { using (SqlConnection connection = new SqlConnection(connectionString)) @@ -63,5 +98,23 @@ private static void TryOpenConnectionWithIntegratedAuthentication(string connect connection.Open(); } } + + private sealed class TestSspiContextProvider : SspiContextProvider + { + protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) + { + throw new SspiTestException(authParams); + } + } + + private sealed class SspiTestException : Exception + { + public SspiTestException(SspiAuthenticationParameters authParams) + { + AuthParams = authParams; + } + + public SspiAuthenticationParameters AuthParams { get; } + } } }