Skip to content

Expose SspiAuthenticationParameters on SspiContextProvider #2454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2435,13 +2435,10 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
try
{
var authParamsBuilder = new SqlAuthenticationParameters.Builder(
authenticationMethod: ConnectionOptions.Authentication,
resource: fedAuthInfo.spn,
authority: fedAuthInfo.stsurl,
serverName: ConnectionOptions.DataSource,
databaseName: ConnectionOptions.InitialCatalog)
.WithConnectionId(_clientConnectionId)
.WithConnectionTimeout(ConnectionOptions.ConnectTimeout);
connection: this,
resource: fedAuthInfo.spn,
authority: fedAuthInfo.stsurl);

switch (ConnectionOptions.Authentication)
{
case SqlAuthenticationMethod.ActiveDirectoryIntegrated:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -15,11 +14,11 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Transactions;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.ConnectionPool;
using Microsoft.Identity.Client;
using System.Transactions;


namespace Microsoft.Data.SqlClient
Expand Down Expand Up @@ -137,7 +136,7 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func<SqlAuthenticationParameters, CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;

Expand Down Expand Up @@ -1651,12 +1650,12 @@ private void OpenLoginEnlist(TimeoutTimer timeout, SqlConnectionString connectio
else
{
_timeoutErrorInternal.SetFailoverScenario(false); // not a failover scenario
LoginNoFailover(dataSource,
newPassword,
newSecurePassword,
LoginNoFailover(dataSource,
newPassword,
newSecurePassword,
redirectedUserInstance,
connectionOptions,
credential,
connectionOptions,
credential,
timeout);
}

Expand Down Expand Up @@ -2625,7 +2624,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)

if (_newDbConnectionPoolAuthenticationContext != null)
{
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
Expand Down Expand Up @@ -2739,13 +2738,10 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
try
{
var authParamsBuilder = new SqlAuthenticationParameters.Builder(
authenticationMethod: ConnectionOptions.Authentication,
connection: this,
resource: fedAuthInfo.spn,
authority: fedAuthInfo.stsurl,
serverName: ConnectionOptions.DataSource,
databaseName: ConnectionOptions.InitialCatalog)
.WithConnectionId(_clientConnectionId)
.WithConnectionTimeout(ConnectionOptions.ConnectTimeout);
authority: fedAuthInfo.stsurl);

switch (ConnectionOptions.Authentication)
{
case SqlAuthenticationMethod.ActiveDirectoryIntegrated:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private void LoadSSPILibrary()
}
}

protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
{
#if NETFRAMEWORK
SNIHandle handle = _physicalStateObj.Handle;
Expand All @@ -62,9 +62,9 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
var sendLength = s_maxSSPILength;
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);

if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, serverSpns[0]))
if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.ServerName))
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
return false;
}

if (sendLength > int.MaxValue)
Expand All @@ -73,6 +73,8 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
}

outgoingBlobWriter.Advance((int)sendLength);

return true;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#if NET
#if NET

using System;
using System.Net.Security;
using System.Buffers;
using System.Net.Security;

#nullable enable

Expand All @@ -12,33 +12,24 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider
{
private NegotiateAuthentication? _negotiateAuth = null;

protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
{
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;

for (int i = 0; i < serverSpns.Length; i++)
{
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverSpns[i] });
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;

// Log session id, status code and the actual SPN used in the negotiation
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
{
outgoingBlobWriter.Write(sendBuff);
break; // Successful case, exit the loop with current SPN.
}
else
{
_negotiateAuth = null; // Reset _negotiateAuth to be generated again for next SPN.
}
}
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.ServerName });
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;

// Log session id, status code and the actual SPN used in the negotiation
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);

if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
outgoingBlobWriter.Write(sendBuff);
return true;
}

return false;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,71 @@ private protected virtual void Initialize()
{
}

protected abstract void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns);
protected abstract bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams);

internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
=> SSPIData(receivedBuff, outgoingBlobWriter, new[] { serverSpn });
{
using var _ = TrySNIEventScope.Create(nameof(SSPIContextProvider));

if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
{
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
}
}

internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string[] serverSpns)
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverSpns)
{
using (TrySNIEventScope.Create(nameof(SSPIContextProvider)))
using var _ = TrySNIEventScope.Create(nameof(SSPIContextProvider));

foreach (var serverSpn in serverSpns)
{
try
{
GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpns);
}
catch (Exception e)
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn))
{
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
return;
}
}

// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
}

private bool RunGenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, string serverSpn)
{
var authParams = CreateSqlAuthParams(_parser.Connection, serverSpn);

try
{
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), serverSpn);

return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams);
}
catch (Exception e)
{
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
return false;
}
}

private static SqlAuthenticationParameters CreateSqlAuthParams(SqlInternalConnectionTds connection, string serverSpn)
{
var auth = new SqlAuthenticationParameters.Builder(
connection: connection,
resource: serverSpn,
authority: null);


if (connection.ConnectionOptions.UserID is { } userId)
{
auth.WithUserId(userId);
}

if (connection.ConnectionOptions.Password is { } password)
{
auth.WithPassword(password);
}

return auth;
}

protected void SSPIError(string error, string procedure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ protected SqlAuthenticationParameters(
string authority,
string userId,
string password,
Guid connectionId,
Guid connectionId,
int connectionTimeout)
{
AuthenticationMethod = authenticationMethod;
Expand Down Expand Up @@ -149,11 +149,13 @@ public Builder WithConnectionTimeout(int timeout)
return this;
}

internal Builder(SqlAuthenticationMethod authenticationMethod, string resource, string authority, string serverName, string databaseName)
internal Builder(SqlInternalConnectionTds connection, string resource, string authority)
{
_authenticationMethod = authenticationMethod;
_serverName = serverName;
_databaseName = databaseName;
_authenticationMethod = connection.ConnectionOptions.Authentication;
_serverName = connection.ConnectionOptions.DataSource;
_databaseName = connection.ConnectionOptions.InitialCatalog;
_connectionTimeout = connection.ConnectionOptions.ConnectTimeout;
_connectionId = connection.ClientConnectionId;
_resource = resource;
_authority = authority;
}
Expand Down
Loading