Skip to content

Commit ff9a983

Browse files
committed
Pass SqlAuthenticationParameters in GenerateSspiClientContext
As part of this change, the SSPIContextProvider base class now iterates through all the server names similar to what NegotiateSSPIContextProvider did.
1 parent ccb77e8 commit ff9a983

File tree

3 files changed

+83
-37
lines changed

3 files changed

+83
-37
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ private void LoadSSPILibrary()
4949
}
5050
}
5151

52-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> _sniSpnBuffer)
52+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
5353
{
5454
#if NETFRAMEWORK
5555
SNIHandle handle = _physicalStateObj.Handle;
@@ -62,9 +62,9 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
6262
var sendLength = s_maxSSPILength;
6363
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);
6464

65-
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, _sniSpnBuffer[0]))
65+
if (0 != SniNativeWrapper.SNISecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.ServerName))
6666
{
67-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
67+
return false;
6868
}
6969

7070
if (sendLength > int.MaxValue)
@@ -73,6 +73,8 @@ protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
7373
}
7474

7575
outgoingBlobWriter.Advance((int)sendLength);
76+
77+
return true;
7678
}
7779
}
7880
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#if NET
22

33
using System;
4-
using System.Text;
5-
using System.Net.Security;
64
using System.Buffers;
5+
using System.Net.Security;
76

87
#nullable enable
98

@@ -12,34 +11,25 @@ namespace Microsoft.Data.SqlClient
1211
internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider
1312
{
1413
private NegotiateAuthentication? _negotiateAuth = null;
15-
16-
protected override void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverNames)
14+
15+
protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams)
1716
{
1817
NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials;
1918

20-
for (int i = 0; i < serverNames.Length; i++)
21-
{
22-
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = serverNames[i] });
23-
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
24-
25-
// Log session id, status code and the actual SPN used in the negotiation
26-
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
27-
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
28-
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
29-
{
30-
outgoingBlobWriter.Write(sendBuff);
31-
break; // Successful case, exit the loop with current SPN.
32-
}
33-
else
34-
{
35-
_negotiateAuth = null; // Reset _negotiateAuth to be generated again for next SPN.
36-
}
37-
}
19+
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.ServerName });
20+
var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!;
21+
22+
// Log session id, status code and the actual SPN used in the negotiation
23+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSSPIContextProvider),
24+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName);
3825

39-
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
26+
if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded)
4027
{
41-
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
28+
outgoingBlobWriter.Write(sendBuff);
29+
return true; // Successful case, exit the loop with current SPN.
4230
}
31+
32+
return false;
4333
}
4434
}
4535
}

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,80 @@ private protected virtual void Initialize()
2626
{
2727
}
2828

29-
protected abstract void GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> _sniSpnBuffer);
29+
protected abstract bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SqlAuthenticationParameters authParams);
3030

31-
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string sniSpnBuffer)
32-
=> SSPIData(receivedBuff, outgoingBlobWriter, new[] { sniSpnBuffer });
31+
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string serverName)
32+
{
33+
if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverName))
34+
{
35+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
36+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
37+
}
38+
}
3339

34-
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, string[] sniSpnBuffer)
40+
internal void SSPIData(ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> outgoingBlobWriter, ReadOnlySpan<string> serverNames)
3541
{
3642
using (TrySNIEventScope.Create(nameof(SSPIContextProvider)))
3743
{
38-
try
44+
foreach (var serverName in serverNames)
3945
{
40-
GenerateSspiClientContext(receivedBuff, outgoingBlobWriter, sniSpnBuffer);
41-
}
42-
catch (Exception e)
43-
{
44-
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
46+
if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverName))
47+
{
48+
return;
49+
}
4550
}
51+
52+
// If we've hit here, the SSPI context provider implementation failed to generate the SSPI context.
53+
SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT);
54+
}
55+
}
56+
57+
private bool RunGenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, string serverName)
58+
{
59+
var authParams = CreateSqlAuthParams(_parser.Connection, serverName);
60+
61+
try
62+
{
63+
#if NET8_0_OR_GREATER
64+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, SPN={3}", GetType().FullName,
65+
nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, serverName);
66+
#else
67+
SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName,
68+
nameof(GenerateSspiClientContext), serverName);
69+
#endif
70+
71+
return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams);
72+
}
73+
catch (Exception e)
74+
{
75+
//throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
76+
SSPIError(e.Message + Environment.NewLine + e.StackTrace, TdsEnums.GEN_CLIENT_CONTEXT);
77+
return false;
4678
}
4779
}
4880

81+
private static SqlAuthenticationParameters CreateSqlAuthParams(SqlInternalConnectionTds connection, string serverName)
82+
{
83+
var auth = new SqlAuthenticationParameters.Builder(
84+
authenticationMethod: connection.ConnectionOptions.Authentication,
85+
resource: null,
86+
authority: null,
87+
serverName: serverName,
88+
connection.ConnectionOptions.InitialCatalog);
89+
90+
if (connection.ConnectionOptions.UserID is { } userId)
91+
{
92+
auth.WithUserId(userId);
93+
}
94+
95+
if (connection.ConnectionOptions.Password is { } password)
96+
{
97+
auth.WithPassword(password);
98+
}
99+
100+
return auth;
101+
}
102+
49103
protected void SSPIError(string error, string procedure)
50104
{
51105
Debug.Assert(!ADP.IsEmpty(procedure), "TdsParser.SSPIError called with an empty or null procedure string");

0 commit comments

Comments
 (0)