diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
index 11602bd078..161326f6d1 100644
--- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
+++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
@@ -2133,6 +2133,19 @@ The following sample tries to open a connection to an invalid database to simula
Returns 0 if the connection is inactive on the client side.
+
+
+ Gets or sets the instance for customizing the SSPI context. If not set, the default for the platform will be used.
+
+
+ An instance.
+
+
+
+ The SspiContextProvider is a part of the connection pool key. Care should be taken when using this property to ensure the implementation returns a stable identity per resource.
+
+
+
Indicates the state of the during the most recent network operation performed on the connection.
diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml
new file mode 100644
index 0000000000..4623d86052
--- /dev/null
+++ b/doc/snippets/Microsoft.Data.SqlClient/SspiAuthenticationParameters.xml
@@ -0,0 +1,28 @@
+
+
+
+
+ Provides parameters used during SSPI authentication.
+
+
+ Creates an instance of the SspiAuthenticationParameters.
+ The name of the server.
+ The resource (often the server service principal name).
+
+
+ Gets the resource (often the server service principal name).
+
+
+ Gets the server name.
+
+
+ Gets or sets the user id if available.
+
+
+ Gets or sets the database name if available.
+
+
+ Gets or sets the password if available.
+
+
+
diff --git a/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml
new file mode 100644
index 0000000000..2ea58cb80b
--- /dev/null
+++ b/doc/snippets/Microsoft.Data.SqlClient/SspiContextProvider.xml
@@ -0,0 +1,20 @@
+
+
+
+
+ Provides the ability to customize SSPI context generation.
+
+
+ Creates an instance of the SSPIContextProvider.
+
+
+ Generates an SSPI outgoing blob given the incoming blob.
+ Incoming blob
+ Outgoing blob
+ Gets the authentication parameters associated with this connection.
+
+ true if the context was generated, otherwise false.
+
+
+
+
diff --git a/src/Microsoft.Data.SqlClient.sln b/src/Microsoft.Data.SqlClient.sln
index e4d29d999c..bf0cbe8cc2 100644
--- a/src/Microsoft.Data.SqlClient.sln
+++ b/src/Microsoft.Data.SqlClient.sln
@@ -149,6 +149,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient",
..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventArgs.xml
..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlRowUpdatingEventHandler.xml
..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SqlTransaction.xml
+ ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiAuthenticationParameters.xml
+ ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml = ..\doc\snippets\Microsoft.Data.SqlClient\SspiContextProvider.xml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Data.SqlClient.DataClassification", "Microsoft.Data.SqlClient.DataClassification", "{5D1F0032-7B0D-4FB6-A969-FCFB25C9EA1D}"
diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
index 55b68fd9b3..9b00807041 100644
--- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
@@ -931,6 +931,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
[System.ComponentModel.BrowsableAttribute(false)]
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } }
+ ///
+ public SspiContextProvider SspiContextProvider { get { throw null; } set { } }
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public override string Database { get { throw null; } }
@@ -1978,6 +1980,37 @@ public sealed class SqlConfigurableRetryFactory
///
public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; }
}
+ ///
+ public abstract class SspiContextProvider
+ {
+ ///
+ protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams);
+ }
+ ///
+ public sealed class SspiAuthenticationParameters
+ {
+ ///
+ public SspiAuthenticationParameters(string serverName, string resource)
+ {
+ ServerName = serverName;
+ Resource = resource;
+ }
+
+ ///
+ public string Resource { get; }
+
+ ///
+ public string ServerName { get; }
+
+ ///
+ public string UserId { get; set; }
+
+ ///
+ public string DatabaseName { get; set; }
+
+ ///
+ public string Password { get; set; }
+ }
}
namespace Microsoft.Data.SqlClient.Diagnostics
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
index 11e7bd71ad..56b7de3f34 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs
@@ -91,6 +91,7 @@ private static readonly Dictionary
private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders;
private Func> _accessTokenCallback;
+ private SspiContextProvider _sspiContextProvider;
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -648,7 +649,7 @@ public override string ConnectionString
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
- ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
+ ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
@@ -708,7 +709,7 @@ public string AccessToken
}
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProvider: null));
_accessToken = value;
}
}
@@ -731,11 +732,22 @@ public Func
+ public SspiContextProvider SspiContextProvider
+ {
+ get { return _sspiContextProvider; }
+ set
+ {
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value));
+ _sspiContextProvider = value;
+ }
+ }
+
///
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
@@ -1033,7 +1045,7 @@ public SqlCredential Credential
_credential = value;
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, sspiContextProvider: null));
}
}
@@ -2263,7 +2275,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity)
@@ -2312,7 +2324,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
@@ -2350,7 +2362,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
{
con?.Dispose();
}
- SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
+ SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);
SqlConnectionFactory.SingletonInstance.ClearPool(key);
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
index c75eb35a0e..8db440a5ac 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs
@@ -135,6 +135,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func> _accessTokenCallback;
+ internal readonly SspiContextProvider _sspiContextProvider;
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
@@ -460,8 +461,8 @@ internal SqlInternalConnectionTds(
bool applyTransientFaultHandling = false,
string accessToken = null,
IDbConnectionPool pool = null,
- Func> accessTokenCallback = null) : base(connectionOptions)
+ Func> accessTokenCallback = null,
+ SspiContextProvider sspiContextProvider = null) : base(connectionOptions)
{
#if DEBUG
if (reconnectSessionData != null)
@@ -514,6 +515,7 @@ internal SqlInternalConnectionTds(
}
_accessTokenCallback = accessTokenCallback;
+ _sspiContextProvider = sspiContextProvider;
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
index 79228a3449..b344db128d 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -413,7 +413,7 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
- _authenticationProvider = _physicalStateObj.CreateSspiContextProvider();
+ _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider();
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication");
}
diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
index d7f280ca33..bd1e17c2d9 100644
--- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
@@ -810,6 +810,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
[System.ComponentModel.BrowsableAttribute(false)]
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } }
+ ///
+ public SspiContextProvider SspiContextProvider { get { throw null; } set { } }
///
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public override string Database { get { throw null; } }
@@ -1959,6 +1961,37 @@ public sealed class SqlConfigurableRetryFactory
///
public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; }
}
+ ///
+ public abstract class SspiContextProvider
+ {
+ ///
+ protected abstract bool GenerateContext(System.ReadOnlySpan incomingBlob, System.Buffers.IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams);
+ }
+ ///
+ public sealed class SspiAuthenticationParameters
+ {
+ ///
+ public SspiAuthenticationParameters(string serverName, string resource)
+ {
+ ServerName = serverName;
+ Resource = resource;
+ }
+
+ ///
+ public string Resource { get; }
+
+ ///
+ public string ServerName { get; }
+
+ ///
+ public string UserId { get; set; }
+
+ ///
+ public string DatabaseName { get; set; }
+
+ ///
+ public string Password { get; set; }
+ }
}
namespace Microsoft.Data.SqlClient.Server
{
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
index c2ea5c3981..651f1f1c93 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs
@@ -89,6 +89,8 @@ private static readonly Dictionary
private Func> _accessTokenCallback;
+ private SspiContextProvider _sspiContextProvider;
+
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -575,6 +577,17 @@ internal int ConnectRetryInterval
get => ((SqlConnectionString)ConnectionOptions).ConnectRetryInterval;
}
+ ///
+ public SspiContextProvider SspiContextProvider
+ {
+ get { return _sspiContextProvider; }
+ set
+ {
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value));
+ _sspiContextProvider = value;
+ }
+ }
+
///
[DefaultValue("")]
#pragma warning disable 618 // ignore obsolete warning about RecommendedAsConfigurable to use SettingsBindableAttribute
@@ -643,7 +656,7 @@ public override string ConnectionString
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
- ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
+ ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
@@ -703,7 +716,7 @@ public string AccessToken
_accessToken = value;
// Need to call ConnectionString_Set to do proper pool group check
- ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null));
+ ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null, sspiContextProvider: null));
}
}
@@ -725,7 +738,7 @@ public Func> _accessTokenCallback;
+ internal readonly SspiContextProvider _sspiContextProvider;
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
@@ -472,8 +473,8 @@ internal SqlInternalConnectionTds(
bool applyTransientFaultHandling = false,
string accessToken = null,
IDbConnectionPool pool = null,
- Func> accessTokenCallback = null)
+ Func> accessTokenCallback = null,
+ SspiContextProvider sspiContextProvider = null)
: base(connectionOptions)
{
#if DEBUG
@@ -525,6 +526,7 @@ internal SqlInternalConnectionTds(
}
_accessTokenCallback = accessTokenCallback;
+ _sspiContextProvider = sspiContextProvider;
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
index f871da2c6c..ce32b6d873 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -411,7 +411,7 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
- _authenticationProvider = _physicalStateObj.CreateSspiContextProvider();
+ _authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider();
if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs
index 207c0a8e1a..31da3521df 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ConnectionPool/SqlConnectionPoolKey.cs
@@ -18,10 +18,12 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey
private readonly SqlCredential _credential;
private readonly string _accessToken;
private Func> _accessTokenCallback;
+ private SspiContextProvider _sspiContextProvider;
internal SqlCredential Credential => _credential;
internal string AccessToken => _accessToken;
internal Func> AccessTokenCallback => _accessTokenCallback;
+ internal SspiContextProvider SspiContextProvider => _sspiContextProvider;
internal override string ConnectionString
{
@@ -33,12 +35,18 @@ internal override string ConnectionString
}
}
- internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func> accessTokenCallback) : base(connectionString)
+ internal SqlConnectionPoolKey(
+ string connectionString,
+ SqlCredential credential,
+ string accessToken,
+ Func> accessTokenCallback,
+ SspiContextProvider sspiContextProvider) : base(connectionString)
{
Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time.");
_credential = credential;
_accessToken = accessToken;
_accessTokenCallback = accessTokenCallback;
+ _sspiContextProvider = sspiContextProvider;
CalculateHashCode();
}
@@ -47,6 +55,8 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key)
_credential = key.Credential;
_accessToken = key.AccessToken;
_accessTokenCallback = key._accessTokenCallback;
+ _sspiContextProvider = key._sspiContextProvider;
+
CalculateHashCode();
}
@@ -61,7 +71,8 @@ public override bool Equals(object obj)
&& _credential == key._credential
&& ConnectionString == key.ConnectionString
&& _accessTokenCallback == key._accessTokenCallback
- && string.CompareOrdinal(_accessToken, key._accessToken) == 0);
+ && string.CompareOrdinal(_accessToken, key._accessToken) == 0
+ && _sspiContextProvider == key._sspiContextProvider);
}
public override int GetHashCode()
@@ -94,6 +105,11 @@ private void CalculateHashCode()
_hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode();
}
}
+
+ if (_sspiContextProvider != null)
+ {
+ _hashValue = _hashValue * 17 + _sspiContextProvider.GetHashCode();
+ }
}
}
}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs
index 5935b149c8..1cc4af3e9c 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSspiContextProvider.cs
@@ -49,7 +49,7 @@ private void LoadSSPILibrary()
}
}
- protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams)
+ protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams)
{
#if NETFRAMEWORK
SNIHandle handle = _physicalStateObj.Handle;
@@ -62,7 +62,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo
var sendLength = s_maxSSPILength;
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);
- if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource))
+ if (SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource) != 0)
{
return false;
}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs
index 5dc52010b3..a255712238 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs
@@ -12,7 +12,7 @@ internal sealed class NegotiateSspiContextProvider : SspiContextProvider
{
private NegotiateAuthentication? _negotiateAuth = null;
- protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams)
+ protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams)
{
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;
@@ -21,7 +21,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo
// Log session id, status code and the actual SPN used in the negotiation
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSspiContextProvider),
- nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
+ nameof(GenerateContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
{
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs
index dce0858360..ad6c92853f 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiAuthenticationParameters.cs
@@ -2,22 +2,29 @@
namespace Microsoft.Data.SqlClient
{
- internal sealed class SspiAuthenticationParameters
+ ///
+ public sealed class SspiAuthenticationParameters
{
+ ///
public SspiAuthenticationParameters(string serverName, string resource)
{
ServerName = serverName;
Resource = resource;
}
+ ///
public string Resource { get; }
+ ///
public string ServerName { get; }
+ ///
public string? UserId { get; set; }
+ ///
public string? DatabaseName { get; set; }
+ ///
public string? Password { get; set; }
}
}
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs
index ff83422f10..7b339f285d 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs
@@ -6,12 +6,18 @@
namespace Microsoft.Data.SqlClient
{
- internal abstract class SspiContextProvider
+ ///
+ public abstract class SspiContextProvider
{
private TdsParser _parser = null!;
private ServerInfo _serverInfo = null!;
private protected TdsParserStateObject _physicalStateObj = null!;
+ ///
+ protected SspiContextProvider()
+ {
+ }
+
internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser)
{
_parser = parser;
@@ -25,13 +31,14 @@ private protected virtual void Initialize()
{
}
- protected abstract bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams);
+ ///
+ protected abstract bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams);
internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn)
{
using var _ = TrySNIEventScope.Create(nameof(SspiContextProvider));
- if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
+ if (!RunGenerateContext(receivedBuff, outgoingBlobWriter, serverSpn))
{
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
@@ -44,7 +51,7 @@ internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outg
foreach (var serverSpn in serverSpns)
{
- if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
+ if (RunGenerateContext(receivedBuff, outgoingBlobWriter, serverSpn))
{
return;
}
@@ -54,7 +61,7 @@ internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outg
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
}
- private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, string serverSpn)
+ private bool RunGenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, string serverSpn)
{
var options = _parser.Connection.ConnectionOptions;
var authParams = new SspiAuthenticationParameters(options.DataSource, serverSpn)
@@ -66,9 +73,9 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff
try
{
- SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), serverSpn);
+ SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateContext), serverSpn);
- return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams);
+ return GenerateContext(incomingBlob, outgoingBlobWriter, authParams);
}
catch (Exception e)
{
@@ -77,7 +84,7 @@ private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBuff
}
}
- protected void SSPIError(string error, string procedure)
+ private protected void SSPIError(string error, string procedure)
{
Debug.Assert(!string.IsNullOrEmpty(procedure), "TdsParser.SSPIError called with an empty or null procedure string");
Debug.Assert(!string.IsNullOrEmpty(error), "TdsParser.SSPIError called with an empty or null error string");
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
index ae4c7bc98d..00a86bdeac 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs
@@ -103,7 +103,7 @@ protected override DbConnectionInternal CreateConnection(
// This first connection is established to SqlExpress to get the instance name
// of the UserInstance.
SqlConnectionString sseopt = new SqlConnectionString(opt, opt.DataSource, userInstance: true, setEnlistValue: false);
- sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling);
+ sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling, sspiContextProvider: key.SspiContextProvider);
// NOTE: Retrieve here. This user instance name will be used below to connect to the Sql Express User Instance.
instanceName = sseConnection.InstanceName;
@@ -157,7 +157,8 @@ protected override DbConnectionInternal CreateConnection(
applyTransientFaultHandling,
key.AccessToken,
pool,
- key.AccessTokenCallback);
+ key.AccessTokenCallback,
+ key.SspiContextProvider);
}
protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs
index e043b2253c..af650a408e 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/IntegratedAuthenticationTest/IntegratedAuthenticationTest.cs
@@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
+using System.Buffers;
using Xunit;
namespace Microsoft.Data.SqlClient.ManualTesting.Tests
@@ -56,6 +58,39 @@ public static void IntegratedAuthenticationTest_ServerSPN()
TryOpenConnectionWithIntegratedAuthentication(builder.ConnectionString);
}
+ [ConditionalFact(nameof(IsIntegratedSecurityEnvironmentSet), nameof(AreConnectionStringsSetup))]
+ public static void CustomSspiContextGeneratorTest()
+ {
+ SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
+ builder.IntegratedSecurity = true;
+ Assert.True(DataTestUtility.ParseDataSource(builder.DataSource, out string hostname, out int port, out string instanceName));
+ // Build the SPN for the server we are connecting to
+ builder.ServerSPN = $"MSSQLSvc/{DataTestUtility.GetMachineFQDN(hostname)}";
+ if (!string.IsNullOrWhiteSpace(instanceName))
+ {
+ builder.ServerSPN += ":" + instanceName;
+ }
+
+ using SqlConnection conn = new(builder.ConnectionString)
+ {
+ SspiContextProvider = new TestSspiContextProvider(),
+ };
+
+ try
+ {
+ conn.Open();
+
+ Assert.Fail("Expected to use custom SSPI context provider");
+ }
+ catch (SspiTestException sspi)
+ {
+ Assert.Equal(sspi.AuthParams.ServerName, builder.DataSource);
+ Assert.Equal(sspi.AuthParams.DatabaseName, builder.InitialCatalog);
+ Assert.Equal(sspi.AuthParams.UserId, builder.UserID);
+ Assert.Equal(sspi.AuthParams.Password, builder.Password);
+ }
+ }
+
private static void TryOpenConnectionWithIntegratedAuthentication(string connectionString)
{
using (SqlConnection connection = new SqlConnection(connectionString))
@@ -63,5 +98,23 @@ private static void TryOpenConnectionWithIntegratedAuthentication(string connect
connection.Open();
}
}
+
+ private sealed class TestSspiContextProvider : SspiContextProvider
+ {
+ protected override bool GenerateContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams)
+ {
+ throw new SspiTestException(authParams);
+ }
+ }
+
+ private sealed class SspiTestException : Exception
+ {
+ public SspiTestException(SspiAuthenticationParameters authParams)
+ {
+ AuthParams = authParams;
+ }
+
+ public SspiAuthenticationParameters AuthParams { get; }
+ }
}
}