Skip to content

Commit 8ca9f45

Browse files
committed
Expose SSPI context provider as public
This change: - Adds a property to SqlConnection to allow setting a provider - Plumbs that property into the TdsParser so that it can be used if set
1 parent 21a4311 commit 8ca9f45

File tree

12 files changed

+117
-31
lines changed

12 files changed

+117
-31
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<?xml version="1.0"?>
2+
<docs>
3+
<members name="SSPIContextProvider">
4+
<SSPIContextProvider>
5+
<summary>Provides the ability to customize SSPI context generation.</summary>
6+
</SSPIContextProvider>
7+
<AuthenticationParameters>
8+
<summary>Gets the authentication parameters associated with this connection.</summary>
9+
</AuthenticationParameters>
10+
<GenerateSspiClientContext>
11+
<summary>Generates a SSPI outgoing blob given the incoming blob.</summary>
12+
<param name="incomingBlob">Incoming blob</param>
13+
<param name="outgoingBlobWriter">Outgoing blob</param>
14+
</GenerateSspiClientContext>
15+
</members>
16+
</docs>

src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,8 @@ public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collect
871871
[System.ComponentModel.BrowsableAttribute(false)]
872872
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
873873
public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } }
874+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/SSPIContextProviderFactory/*' />
875+
public System.Func<SSPIContextProvider> SSPIContextProviderFactory { get { throw null; } set { } }
874876
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*'/>
875877
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
876878
public override string Database { get { throw null; } }
@@ -1899,6 +1901,15 @@ public sealed class SqlConfigurableRetryFactory
18991901
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConfigurableRetryFactory.xml' path='docs/members[@name="SqlConfigurableRetryFactory"]/CreateNoneRetryProvider/*' />
19001902
public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; }
19011903
}
1904+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/SSPIContextProvider/*'/>
1905+
public abstract class SSPIContextProvider
1906+
{
1907+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/AuthenticationParameters/*'/>
1908+
protected SqlAuthenticationParameters AuthenticationParameters { get { throw null; } }
1909+
1910+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/GenerateSspiClientContext/*'/>
1911+
protected abstract void GenerateSspiClientContext(System.ReadOnlySpan<byte> incomingBlob, System.Buffers.IBufferWriter<byte> outgoingBlobWriter);
1912+
}
19021913
}
19031914
namespace Microsoft.Data.SqlClient.Server
19041915
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
8888
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;
8989

9090
private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
91+
private Func<SSPIContextProvider> _sspiContextProviderFactory;
9192

9293
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
9394
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -646,7 +647,7 @@ public override string ConnectionString
646647
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
647648
}
648649
}
649-
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
650+
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProviderFactory));
650651
_connectionString = value; // Change _connectionString value only after value is validated
651652
CacheConnectionStringProperties();
652653
}
@@ -706,7 +707,7 @@ public string AccessToken
706707
}
707708

708709
// Need to call ConnectionString_Set to do proper pool group check
709-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
710+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProviderFactory: _sspiContextProviderFactory));
710711
_accessToken = value;
711712
}
712713
}
@@ -729,11 +730,22 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
729730
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
730731
}
731732

732-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
733+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value, sspiContextProviderFactory: _sspiContextProviderFactory));
733734
_accessTokenCallback = value;
734735
}
735736
}
736737

738+
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/SSPIContextProviderFactory/*' />
739+
public Func<SSPIContextProvider> SSPIContextProviderFactory
740+
{
741+
get { return _sspiContextProviderFactory; }
742+
set
743+
{
744+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: _accessTokenCallback, sspiContextProviderFactory: value));
745+
_sspiContextProviderFactory = value;
746+
}
747+
}
748+
737749
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
738750
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
739751
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
@@ -1028,7 +1040,7 @@ public SqlCredential Credential
10281040
_credential = value;
10291041

10301042
// Need to call ConnectionString_Set to do proper pool group check
1031-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
1043+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, _sspiContextProviderFactory));
10321044
}
10331045
}
10341046

@@ -1076,7 +1088,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
10761088
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
10771089
}
10781090

1079-
if(_accessTokenCallback != null)
1091+
if (_accessTokenCallback != null)
10801092
{
10811093
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
10821094
}
@@ -1098,7 +1110,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCa
10981110
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
10991111
}
11001112

1101-
if(_accessToken != null)
1113+
if (_accessToken != null)
11021114
{
11031115
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
11041116
}
@@ -2212,7 +2224,7 @@ public static void ChangePassword(string connectionString, string newPassword)
22122224
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
22132225
}
22142226

2215-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
2227+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
22162228

22172229
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
22182230
if (connectionOptions.IntegratedSecurity)
@@ -2261,7 +2273,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
22612273
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
22622274
}
22632275

2264-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
2276+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
22652277

22662278
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
22672279

@@ -2300,7 +2312,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
23002312
if (con != null)
23012313
con.Dispose();
23022314
}
2303-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
2315+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
23042316

23052317
SqlConnectionFactory.SingletonInstance.ClearPool(key);
23062318
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
9696
// This first connection is established to SqlExpress to get the instance name
9797
// of the UserInstance.
9898
SqlConnectionString sseopt = new SqlConnectionString(opt, opt.DataSource, userInstance: true, setEnlistValue: false);
99-
sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling);
99+
sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling, sspiContextProviderFactory: key.SSPIContextProviderFactory);
100100
// NOTE: Retrieve <UserInstanceName> here. This user instance name will be used below to connect to the Sql Express User Instance.
101101
instanceName = sseConnection.InstanceName;
102102

@@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
133133
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
134134
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
135135
}
136-
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
136+
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback, key.SSPIContextProviderFactory);
137137
}
138138

139139
protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
130130
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
131131
SqlFedAuthToken _fedAuthToken = null;
132132
internal byte[] _accessTokenInBytes;
133-
internal readonly Func<SqlAuthenticationParameters, CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;
133+
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
134+
internal readonly Func<SSPIContextProvider> _sspiContextProviderFactory;
134135

135136
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
136137
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
@@ -447,8 +448,8 @@ internal SqlInternalConnectionTds(
447448
bool applyTransientFaultHandling = false,
448449
string accessToken = null,
449450
DbConnectionPool pool = null,
450-
Func<SqlAuthenticationParameters, CancellationToken,
451-
Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)
451+
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
452+
Func<SSPIContextProvider> sspiContextProviderFactory = null) : base(connectionOptions)
452453

453454
{
454455
#if DEBUG
@@ -482,6 +483,7 @@ internal SqlInternalConnectionTds(
482483
}
483484

484485
_accessTokenCallback = accessTokenCallback;
486+
_sspiContextProviderFactory = sspiContextProviderFactory;
485487

486488
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
487489
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ internal void Connect(
422422
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
423423
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
424424
{
425-
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
425+
_authenticationProvider = Connection._sspiContextProviderFactory?.Invoke() ?? _physicalStateObj.CreateSSPIContextProvider();
426426
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication");
427427
}
428428

src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,8 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
800800
[System.ComponentModel.BrowsableAttribute(false)]
801801
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
802802
public Microsoft.Data.SqlClient.SqlCredential Credential { get { throw null; } set { } }
803+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/SSPIContextProviderFactory/*' />
804+
public System.Func<SSPIContextProvider> SSPIContextProviderFactory { get { throw null; } set { } }
803805
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*'/>
804806
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
805807
public override string Database { get { throw null; } }
@@ -1952,6 +1954,15 @@ public sealed class SqlConfigurableRetryFactory
19521954
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConfigurableRetryFactory.xml' path='docs/members[@name="SqlConfigurableRetryFactory"]/CreateNoneRetryProvider/*' />
19531955
public static SqlRetryLogicBaseProvider CreateNoneRetryProvider() { throw null; }
19541956
}
1957+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/SSPIContextProvider/*'/>
1958+
public abstract class SSPIContextProvider
1959+
{
1960+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/AuthenticationParameters/*'/>
1961+
protected SqlAuthenticationParameters AuthenticationParameters { get { throw null; } }
1962+
1963+
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SSPIContextProvider.xml' path='docs/members[@name="SSPIContextProvider"]/GenerateSspiClientContext/*'/>
1964+
protected abstract void GenerateSspiClientContext(System.ReadOnlySpan<byte> incomingBlob, System.Buffers.IBufferWriter<byte> outgoingBlobWriter);
1965+
}
19551966
}
19561967
namespace Microsoft.Data.SqlClient.Server
19571968
{

0 commit comments

Comments
 (0)